Skip to content

Commit

Permalink
fix: ax ai provider
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 28, 2024
1 parent 8969119 commit b87bf02
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions src/ai-sdk-provider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,46 +155,40 @@ function prepareToolsAndToolChoice(
mode: Readonly<
Parameters<LanguageModelV1['doGenerate']>[0]['mode'] & { type: 'regular' }
>
) {
): Pick<AxChatRequest, 'functions' | 'functionCall'> {
// when the tools array is empty, change it to undefined to prevent errors:
const tools = mode.tools?.length ? mode.tools : undefined;

if (tools == null) {
return { tools: undefined, tool_choice: undefined };
if (!tools) {
return {};
}

const mappedTools = tools.map((tool) => ({
type: 'function',
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters
}
const functions = tools.map((f) => ({
name: f.name,
description: f.description ?? '',
parameters: f.parameters
}));

const toolChoice = mode.toolChoice;

if (toolChoice == null) {
return { tools: mappedTools, tool_choice: undefined };
if (!toolChoice) {
return { functions };
}

const type = toolChoice.type;

switch (type) {
case 'auto':
return { functions, functionCall: 'auto' };
case 'none':
return { tools: mappedTools, tool_choice: type };
return { functions, functionCall: 'none' };
case 'required':
return { tools: mappedTools, tool_choice: 'any' };

// mistral does not support tool mode directly,
// so we filter the tools and force the tool choice through 'any'
return { functions, functionCall: 'required' };
case 'tool':
return {
tools: mappedTools.filter(
(tool) => tool.function.name === toolChoice.toolName
),
tool_choice: 'any'
functions,
functionCall: {
type: 'function',
function: { name: toolChoice.toolName }
}
};
default: {
const _exhaustiveCheck: never = type;
Expand Down Expand Up @@ -322,7 +316,7 @@ function createChatRequest({
topP,
frequencyPenalty,
presencePenalty
//seed
//seed,
}: Readonly<Parameters<LanguageModelV1['doGenerate']>[0]>): {
req: AxChatRequest;
warnings: LanguageModelV1CallWarning[];
Expand Down Expand Up @@ -354,8 +348,15 @@ function createChatRequest({
}

case 'object-tool': {
const tool = {
type: 'function',
function: {
name: mode.tool.name,
params: mode.tool.parameters
}
};
return {
req: { ...req, ...mode.tool },
req: { ...req, ...tool },
warnings
};
}
Expand Down

0 comments on commit b87bf02

Please sign in to comment.