From 64f661f34de1e97e335ac1032a2064b185ec97d8 Mon Sep 17 00:00:00 2001 From: dosco <832235+dosco@users.noreply.github.com> Date: Thu, 4 Jul 2024 00:49:05 -0700 Subject: [PATCH] fix: Accessing Stream Chunks (Streamed generation) #36 --- src/ax/dsp/generate.ts | 75 +++++++++++++++++++------------------- src/examples/streaming2.ts | 4 +- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/ax/dsp/generate.ts b/src/ax/dsp/generate.ts index 77838b6..def0146 100644 --- a/src/ax/dsp/generate.ts +++ b/src/ax/dsp/generate.ts @@ -27,6 +27,7 @@ import { import { type extractionState, extractValues, + streamingExtractFinalValue, streamingExtractValues, ValidationError } from './extract.js'; @@ -293,8 +294,8 @@ export class AxGenerate< await this.processFunctions(funcs, mem, sessionId, traceId); } - // streamingExtractFinalValue(values, xstate, content); - // assertAssertions(this.asserts, values); + streamingExtractFinalValue(values, xstate, content); + assertAssertions(this.asserts, values); return { ...values } as unknown as OUT; } @@ -361,9 +362,9 @@ export class AxGenerate< const userMsg = { role: 'user' as const, content: prompt }; mem.add(userMsg, options?.sessionId); - for (let i = 0; i < maxRetries; i++) { - try { - for (let n = 0; n < maxSteps; n++) { + multiStepLoop: for (let n = 0; n < maxSteps; n++) { + for (let i = 0; i < maxRetries; i++) { + try { const { sessionId, traceId, @@ -385,48 +386,46 @@ export class AxGenerate< const lastMemItem = mem.getLast(sessionId); - if (lastMemItem?.role !== 'function') { - assertRequiredFields(this.signature, output); - this.trace = { ...output }; - return output; + if (lastMemItem?.role === 'function') { + continue multiStepLoop; } - } - throw new Error('Could not complete task within maximum allowed steps'); - } catch (e) { - let extraFields; - span?.recordAxSpanException(e as Error); - - if (e instanceof ValidationError) { - extraFields = e.getFixingInstructions(); - err = e; - } else if (e instanceof AxAssertionError) { - const e1 = e as AxAssertionError; - extraFields = e1.getFixingInstructions(this.signature); - err = e; - } else { - throw e; - } - if (extraFields) { - const content = this.pt.renderExtraFields(extraFields); - const userMsg = { - role: 'user' as const, - content - }; + assertRequiredFields(this.signature, output); + this.trace = { ...output }; + return output; + } catch (e) { + let extraFields; + span?.recordAxSpanException(e as Error); + + if (e instanceof ValidationError) { + extraFields = e.getFixingInstructions(); + err = e; + } else if (e instanceof AxAssertionError) { + const e1 = e as AxAssertionError; + extraFields = e1.getFixingInstructions(this.signature); + err = e; + } else { + throw e; + } + + if (extraFields) { + const content = this.pt.renderExtraFields(extraFields); + mem.add({ role: 'user' as const, content }, options?.sessionId); - mem.add(userMsg, options?.sessionId); - if (options?.debug) { - console.log('Error Correction:', content); + if (options?.debug) { + console.log('Error Correction:', content); + } } } } - } + if (err instanceof AxAssertionError && err.getOptional()) { + return err.getValue() as OUT; + } - if (err instanceof AxAssertionError && err.getOptional()) { - return err.getValue() as OUT; + throw new Error(`Unable to fix validation error: ${err?.message}`); } - throw new Error(`Unable to fix validation error: ${err?.message}`); + throw new Error('Could not complete task within maximum allowed steps'); } public override async forward( diff --git a/src/examples/streaming2.ts b/src/examples/streaming2.ts index 55deb15..4218969 100644 --- a/src/examples/streaming2.ts +++ b/src/examples/streaming2.ts @@ -36,9 +36,9 @@ gen.addStreamingAssert( // run the program with streaming enabled const res = await gen.forward( { - question: 'Provide a list of optimizations to speedup LLM inference.' + question: 'Provide a list of 3 optimizations to speedup LLM inference.' }, - { stream: true, debug: true } + { stream: true } ); console.log('>', res);