diff --git a/.changeset/clean-brooms-fold.md b/.changeset/clean-brooms-fold.md
new file mode 100644
index 00000000000..50d6a0b6bac
--- /dev/null
+++ b/.changeset/clean-brooms-fold.md
@@ -0,0 +1,8 @@
+---
+'@ai-sdk/provider-utils': patch
+'@ai-sdk/provider': patch
+'@ai-sdk/openai': patch
+'ai': patch
+---
+
+Prototype Raw Response
diff --git a/packages/core/core/generate-text/stream-text.ts b/packages/core/core/generate-text/stream-text.ts
index e52134a350f..61ea39aa537 100644
--- a/packages/core/core/generate-text/stream-text.ts
+++ b/packages/core/core/generate-text/stream-text.ts
@@ -21,6 +21,7 @@ import { runToolsTransformation } from './run-tools-transformation';
 import { TokenUsage } from './token-usage';
 import { ToToolCall } from './tool-call';
 import { ToToolResult } from './tool-result';
+import { LanguageModelV1CallWarning } from '@ai-sdk/provider';
 
 /**
 Generate a text and call tools for a given prompt using a language model.
@@ -110,6 +111,64 @@ The tools that the model can call. The model needs to support calling tools.
   });
 }
 
+export async function streamResponse<TOOLS extends Record<string, ExperimentalTool>>({
+  model,
+  tools,
+  system,
+  prompt,
+  messages,
+  maxRetries,
+  abortSignal,
+  ...settings
+}: CallSettings &
+  Prompt & {
+    /**
+The language model to use.
+     */
+    model: LanguageModel;
+
+    /**
+The tools that the model can call. The model needs to support calling tools.
+     */
+    tools?: TOOLS;
+  }): Promise<{
+  stream: ReadableStream<Uint8Array>;
+  warnings: LanguageModelV1CallWarning[] | undefined;
+  rawResponse:
+    | {
+        headers?: Record<string, string>;
+      }
+    | undefined;
+}> {
+  const retry = retryWithExponentialBackoff({ maxRetries });
+  const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
+  const { stream, warnings, rawResponse } = await retry(() => {
+    if (!model.doRawStream) {
+      throw new Error('The model does not support raw streaming.');
+    }
+    return model.doRawStream({
+      mode: {
+        type: 'regular',
+        tools:
+          tools == null
+            ? undefined
+            : Object.entries(tools).map(([name, tool]) => ({
+                type: 'function',
+                name,
+                description: tool.description,
+                parameters: convertZodToJSONSchema(tool.parameters),
+              })),
+      },
+      ...prepareCallSettings(settings),
+      inputFormat: validatedPrompt.type,
+      prompt: convertToLanguageModelPrompt(validatedPrompt),
+      abortSignal,
+    });
+  });
+
+  return { stream, warnings, rawResponse };
+}
+
 export type TextStreamPart<TOOLS extends Record<string, ExperimentalTool>> =
   | {
       type: 'text-delta';
diff --git a/packages/openai/src/openai-chat-language-model.ts b/packages/openai/src/openai-chat-language-model.ts
index 521b61af12f..a1949ce911e 100644
--- a/packages/openai/src/openai-chat-language-model.ts
+++ b/packages/openai/src/openai-chat-language-model.ts
@@ -8,6 +8,7 @@
 } from '@ai-sdk/provider';
 import {
   ParseResult,
+  createEventSourcePassThroughHandler,
   createEventSourceResponseHandler,
   createJsonResponseHandler,
   generateId,
@@ -188,6 +189,45 @@
     };
   }
 
+  async doRawStream(
+    options: Parameters<LanguageModelV1['doStream']>[0],
+  ): Promise<
+    Omit<Awaited<ReturnType<LanguageModelV1['doStream']>>, 'stream'> & {
+      stream: ReadableStream<Uint8Array>;
+    }
+  > {
+    const args = this.getArgs(options);
+
+    const { responseHeaders, value: responseBody } = await postJsonToApi({
+      url: `${this.config.baseURL}/chat/completions`,
+      headers: this.config.headers(),
+      body: {
+        ...args,
+        stream: true,
+
+        // only include stream_options when in strict compatibility mode:
+        stream_options:
+          this.config.compatibility === 'strict'
+            ? { include_usage: true }
+            : undefined,
+      },
+      failedResponseHandler: openaiFailedResponseHandler,
+      successfulResponseHandler: createEventSourcePassThroughHandler(
+        openaiChatChunkSchema,
+      ),
+      abortSignal: options.abortSignal,
+    });
+
+    const { messages: rawPrompt, ...rawSettings } = args;
+
+    return {
+      stream: responseBody,
+      rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
+      warnings: [],
+    };
+  }
+
   async doStream(
     options: Parameters<LanguageModelV1['doStream']>[0],
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
diff --git a/packages/provider-utils/src/response-handler.ts b/packages/provider-utils/src/response-handler.ts
index 63b5b5a8dad..f61001c80f3 100644
--- a/packages/provider-utils/src/response-handler.ts
+++ b/packages/provider-utils/src/response-handler.ts
@@ -118,6 +118,21 @@
     };
   };
 
+export const createEventSourcePassThroughHandler =
+  <T>(chunkSchema: ZodSchema<T>): ResponseHandler<ReadableStream<Uint8Array>> =>
+  async ({ response }: { response: Response }) => {
+    const responseHeaders = extractResponseHeaders(response);
+
+    if (response.body == null) {
+      throw new EmptyResponseBodyError({});
+    }
+
+    return {
+      responseHeaders,
+      value: response.body,
+    };
+  };
+
 export const createJsonResponseHandler =
   <T>(responseSchema: ZodSchema<T>): ResponseHandler<T> =>
   async ({ response, url, requestBodyValues }) => {
diff --git a/packages/provider/src/language-model/v1/language-model-v1.ts b/packages/provider/src/language-model/v1/language-model-v1.ts
index 3d3c09a4928..4cbc73cfa72 100644
--- a/packages/provider/src/language-model/v1/language-model-v1.ts
+++ b/packages/provider/src/language-model/v1/language-model-v1.ts
@@ -145,6 +145,39 @@ Response headers.
 
     warnings?: LanguageModelV1CallWarning[];
   }>;
+
+  doRawStream?: (options: LanguageModelV1CallOptions) => PromiseLike<{
+    stream: ReadableStream<Uint8Array>;
+
+    /**
+Raw prompt and setting information for observability provider integration.
+     */
+    rawCall: {
+      /**
+Raw prompt after expansion and conversion to the format that the
+provider uses to send the information to their API.
+       */
+      rawPrompt: unknown;
+
+      /**
+Raw settings that are used for the API call. Includes provider-specific
+settings.
+       */
+      rawSettings: Record<string, unknown>;
+    };
+
+    /**
+Optional raw response data.
+     */
+    rawResponse?: {
+      /**
+Response headers.
+      */
+      headers?: Record<string, string>;
+    };
+
+    warnings?: LanguageModelV1CallWarning[];
+  }>;
 };
 
 export type LanguageModelV1StreamPart =
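A minimal consumer-side sketch of the prototype raw-streaming API introduced in this diff. It is illustrative only: the `streamResponse` export path, the `openai.chat(...)` provider setup, and the model id are assumptions, since the prototype lives in `packages/core` and may not be re-exported publicly.

```ts
// Hypothetical usage of the prototype raw-streaming API from this diff.
// Assumptions: `streamResponse` is re-exported from the `ai` package, and
// the `@ai-sdk/openai` provider exposes `openai.chat(...)` as in SDK examples.
import { streamResponse } from 'ai';
import { openai } from '@ai-sdk/openai';

async function main() {
  // Unlike the parsed stream of streamText, streamResponse returns the
  // provider's raw SSE bytes plus response headers and call warnings.
  const { stream, warnings, rawResponse } = await streamResponse({
    model: openai.chat('gpt-3.5-turbo'),
    prompt: 'Write a haiku about streams.',
  });

  console.log('headers:', rawResponse?.headers);
  console.log('warnings:', warnings);

  // stream is a ReadableStream<Uint8Array>; decode it chunk by chunk.
  const reader = stream.getReader();
  const decoder = new TextDecoder();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    process.stdout.write(decoder.decode(value, { stream: true }));
  }
}

main().catch(console.error);
```

Because the pass-through handler never re-encodes chunks, the bytes are the provider's own SSE frames, so the stream could also be forwarded verbatim (headers included) when proxying OpenAI-compatible responses.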