Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add doRawStream #1639

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/clean-brooms-fold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/provider': patch
'@ai-sdk/openai': patch
'ai': patch
---

Add experimental `doRawStream` support for accessing the provider's raw, unparsed response stream (prototype).
59 changes: 59 additions & 0 deletions packages/core/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { runToolsTransformation } from './run-tools-transformation';
import { TokenUsage } from './token-usage';
import { ToToolCall } from './tool-call';
import { ToToolResult } from './tool-result';
import { LanguageModelV1CallWarning } from '@ai-sdk/provider';

/**
Generate a text and call tools for a given prompt using a language model.
Expand Down Expand Up @@ -110,6 +111,64 @@ The tools that the model can call. The model needs to support calling tools.
});
}

/**
Streams the raw (unparsed) provider response for a given prompt using a
language model. Unlike `streamText`, the returned stream contains the
provider's raw response bytes; no stream-part parsing is performed.

@returns The raw byte stream, any call warnings, and optional raw response
metadata (e.g. headers).

@throws Error if the model does not implement `doRawStream`.
 */
export async function streamResponse<TOOLS extends Record<string, CoreTool>>({
  model,
  tools,
  system,
  prompt,
  messages,
  maxRetries,
  abortSignal,
  ...settings
}: CallSettings &
  Prompt & {
    /**
The language model to use.
     */
    model: LanguageModel;

    /**
The tools that the model can call. The model needs to support calling tools.
     */
    tools?: TOOLS;
  }): Promise<{
  stream: ReadableStream;
  warnings: LanguageModelV1CallWarning[] | undefined;
  rawResponse:
    | {
        headers?: Record<string, string>;
      }
    | undefined;
}> {
  // Capture the optional method once; checking before entering the retry
  // loop avoids retrying a deterministic "unsupported" failure with
  // exponential backoff, and gives the compiler a stable non-null binding.
  const doRawStream = model.doRawStream;
  if (doRawStream == null) {
    throw new Error('The model does not support raw streaming.');
  }

  const retry = retryWithExponentialBackoff({ maxRetries });
  const validatedPrompt = getValidatedPrompt({ system, prompt, messages });

  const { stream, warnings, rawResponse } = await retry(() =>
    // call through the model so a `this`-dependent implementation still works
    doRawStream.call(model, {
      mode: {
        type: 'regular',
        tools:
          tools == null
            ? undefined
            : Object.entries(tools).map(([name, tool]) => ({
                type: 'function',
                name,
                description: tool.description,
                parameters: convertZodToJSONSchema(tool.parameters),
              })),
      },
      ...prepareCallSettings(settings),
      inputFormat: validatedPrompt.type,
      prompt: convertToLanguageModelPrompt(validatedPrompt),
      abortSignal,
    }),
  );

  return { stream, warnings, rawResponse };
}

export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
| {
type: 'text-delta';
Expand Down
40 changes: 40 additions & 0 deletions packages/openai/src/openai-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
} from '@ai-sdk/provider';
import {
ParseResult,
createEventSourcePassThroughHandler,
createEventSourceResponseHandler,
createJsonResponseHandler,
generateId,
Expand Down Expand Up @@ -188,6 +189,45 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
};
}

async doRawStream(
options: Parameters<LanguageModelV1['doStream']>[0],
): Promise<
Omit<Awaited<ReturnType<LanguageModelV1['doStream']>>, 'stream'> & {
stream: ReadableStream<Uint8Array>;
}
> {
const args = this.getArgs(options);

const { responseHeaders, value: responseBody } = await postJsonToApi({
url: `${this.config.baseURL}/chat/completions`,
headers: this.config.headers(),
body: {
...args,
stream: true,

// only include stream_options when in strict compatibility mode:
stream_options:
this.config.compatibility === 'strict'
? { include_usage: true }
: undefined,
},
failedResponseHandler: openaiFailedResponseHandler,
successfulResponseHandler: createEventSourcePassThroughHandler(
openaiChatChunkSchema,
),
abortSignal: options.abortSignal,
});

const { messages: rawPrompt, ...rawSettings } = args;

return {
stream: responseBody,
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings: [],
};
}

async doStream(
options: Parameters<LanguageModelV1['doStream']>[0],
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
Expand Down
15 changes: 15 additions & 0 deletions packages/provider-utils/src/response-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ export const createEventSourceResponseHandler =
};
};

/**
Creates a response handler that passes the event-source response body through
unmodified as a raw byte stream.

The chunk schema parameter is accepted only so the signature mirrors
`createEventSourceResponseHandler`; no per-chunk parsing or validation is
performed in pass-through mode, hence the underscore prefix.

@throws EmptyResponseBodyError if the response has no body.
 */
export const createEventSourcePassThroughHandler =
  <T>(
    _chunkSchema: ZodSchema<T>, // intentionally unused: pass-through does not validate chunks
  ): ResponseHandler<ReadableStream<Uint8Array>> =>
  async ({ response }: { response: Response }) => {
    const responseHeaders = extractResponseHeaders(response);

    // A streaming endpoint must always supply a body.
    if (response.body == null) {
      throw new EmptyResponseBodyError({});
    }

    return {
      responseHeaders,
      value: response.body,
    };
  };

export const createJsonResponseHandler =
<T>(responseSchema: ZodSchema<T>): ResponseHandler<T> =>
async ({ response, url, requestBodyValues }) => {
Expand Down
33 changes: 33 additions & 0 deletions packages/provider/src/language-model/v1/language-model-v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,39 @@ Response headers.

warnings?: LanguageModelV1CallWarning[];
}>;

  /**
Streams the provider's raw response body bytes (e.g. the server-sent-events
wire format) without parsing them into stream parts. Optional — providers
that do not implement raw streaming omit it, so callers must check for its
presence before invoking.
   */
  doRawStream?: (options: LanguageModelV1CallOptions) => PromiseLike<{
    /**
The raw, unparsed response body as a byte stream.
     */
    stream: ReadableStream<Uint8Array>;

    /**
Raw prompt and setting information for observability provider integration.
     */
    rawCall: {
      /**
Raw prompt after expansion and conversion to the format that the
provider uses to send the information to their API.
       */
      rawPrompt: unknown;

      /**
Raw settings that are used for the API call. Includes provider-specific
settings.
       */
      rawSettings: Record<string, unknown>;
    };

    /**
Optional raw response data.
     */
    rawResponse?: {
      /**
Response headers.
       */
      headers?: Record<string, string>;
    };

    /**
Warnings generated for the call (see LanguageModelV1CallWarning).
     */
    warnings?: LanguageModelV1CallWarning[];
  }>;
};

export type LanguageModelV1StreamPart =
Expand Down
Loading