Skip to content

Commit

Permalink
fix: Accessing Stream Chunks (Streamed generation) #36
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 4, 2024
1 parent 9b4e1cc commit 64f661f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 40 deletions.
75 changes: 37 additions & 38 deletions src/ax/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
import {
type extractionState,
extractValues,
streamingExtractFinalValue,
streamingExtractValues,
ValidationError
} from './extract.js';
Expand Down Expand Up @@ -293,8 +294,8 @@ export class AxGenerate<
await this.processFunctions(funcs, mem, sessionId, traceId);
}

// streamingExtractFinalValue(values, xstate, content);
// assertAssertions(this.asserts, values);
streamingExtractFinalValue(values, xstate, content);
assertAssertions(this.asserts, values);

return { ...values } as unknown as OUT;
}
Expand Down Expand Up @@ -361,9 +362,9 @@ export class AxGenerate<
const userMsg = { role: 'user' as const, content: prompt };
mem.add(userMsg, options?.sessionId);

for (let i = 0; i < maxRetries; i++) {
try {
for (let n = 0; n < maxSteps; n++) {
multiStepLoop: for (let n = 0; n < maxSteps; n++) {
for (let i = 0; i < maxRetries; i++) {
try {
const {
sessionId,
traceId,
Expand All @@ -385,48 +386,46 @@ export class AxGenerate<

const lastMemItem = mem.getLast(sessionId);

if (lastMemItem?.role !== 'function') {
assertRequiredFields(this.signature, output);
this.trace = { ...output };
return output;
if (lastMemItem?.role === 'function') {
continue multiStepLoop;
}
}
throw new Error('Could not complete task within maximum allowed steps');
} catch (e) {
let extraFields;
span?.recordAxSpanException(e as Error);

if (e instanceof ValidationError) {
extraFields = e.getFixingInstructions();
err = e;
} else if (e instanceof AxAssertionError) {
const e1 = e as AxAssertionError;
extraFields = e1.getFixingInstructions(this.signature);
err = e;
} else {
throw e;
}

if (extraFields) {
const content = this.pt.renderExtraFields(extraFields);
const userMsg = {
role: 'user' as const,
content
};
assertRequiredFields(this.signature, output);
this.trace = { ...output };
return output;
} catch (e) {
let extraFields;
span?.recordAxSpanException(e as Error);

if (e instanceof ValidationError) {
extraFields = e.getFixingInstructions();
err = e;
} else if (e instanceof AxAssertionError) {
const e1 = e as AxAssertionError;
extraFields = e1.getFixingInstructions(this.signature);
err = e;
} else {
throw e;
}

if (extraFields) {
const content = this.pt.renderExtraFields(extraFields);
mem.add({ role: 'user' as const, content }, options?.sessionId);

mem.add(userMsg, options?.sessionId);
if (options?.debug) {
console.log('Error Correction:', content);
if (options?.debug) {
console.log('Error Correction:', content);
}
}
}
}
}
if (err instanceof AxAssertionError && err.getOptional()) {
return err.getValue() as OUT;
}

if (err instanceof AxAssertionError && err.getOptional()) {
return err.getValue() as OUT;
throw new Error(`Unable to fix validation error: ${err?.message}`);
}

throw new Error(`Unable to fix validation error: ${err?.message}`);
throw new Error('Could not complete task within maximum allowed steps');
}

public override async forward(
Expand Down
4 changes: 2 additions & 2 deletions src/examples/streaming2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ gen.addStreamingAssert(
// run the program with streaming enabled
const res = await gen.forward(
{
question: 'Provide a list of optimizations to speedup LLM inference.'
question: 'Provide a list of 3 optimizations to speedup LLM inference.'
},
{ stream: true, debug: true }
{ stream: true }
);

console.log('>', res);

0 comments on commit 64f661f

Please sign in to comment.