diff --git a/package.json b/package.json
index 14637f2ade..ff79321eb9 100644
--- a/package.json
+++ b/package.json
@@ -18,7 +18,7 @@
     "prettier": "^2.8.8",
     "prettier-plugin-tailwindcss": "^0.3.0",
     "ts-jest": "^29.1.0",
-    "turbo": "^1.10.7"
+    "turbo": "^1.10.6"
   },
   "packageManager": "pnpm@7.15.0",
   "name": "llamascript"
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index b3aebadce6..9d7af413d6 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,9 +1,4 @@
-import {
-  BaseChatModel,
-  BaseMessage,
-  ChatOpenAI,
-  LLMResult,
-} from "./LanguageModel";
+import { BaseChatModel, BaseMessage, ChatOpenAI } from "./LanguageModel";
 import { TextNode } from "./Node";
 import {
   SimplePrompt,
@@ -15,6 +10,8 @@ import { BaseQueryEngine } from "./QueryEngine";
 import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
 
 interface ChatEngine {
   chatRepl(): void;
@@ -30,7 +27,7 @@ export class SimpleChatEngine implements ChatEngine {
 
   constructor(init?: Partial<SimpleChatEngine>) {
     this.chatHistory = init?.chatHistory ?? [];
-    this.llm = init?.llm ?? new ChatOpenAI();
+    this.llm = init?.llm ?? new ChatOpenAI({ model: "gpt-3.5-turbo" });
   }
 
   chatRepl() {
@@ -125,7 +122,8 @@ export class ContextChatEngine implements ChatEngine {
     chatHistory?: BaseMessage[];
   }) {
     this.retriever = init.retriever;
-    this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k");
+    this.chatModel =
+      init.chatModel ?? new ChatOpenAI({ model: "gpt-3.5-turbo-16k" });
     this.chatHistory = init?.chatHistory ?? [];
   }
 
@@ -136,7 +134,15 @@ export class ContextChatEngine implements ChatEngine {
 
   async achat(message: string, chatHistory?: BaseMessage[] | undefined) {
     chatHistory = chatHistory ?? this.chatHistory;
-    const sourceNodesWithScore = await this.retriever.aretrieve(message);
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+    const sourceNodesWithScore = await this.retriever.aretrieve(
+      message,
+      parentEvent
+    );
 
     const systemMessage: BaseMessage = {
       content: contextSystemPrompt({
@@ -149,10 +155,10 @@
     chatHistory.push({ content: message, type: "human" });
 
-    const response = await this.chatModel.agenerate([
-      systemMessage,
-      ...chatHistory,
-    ]);
+    const response = await this.chatModel.agenerate(
+      [systemMessage, ...chatHistory],
+      parentEvent
+    );
     const text = response.generations[0][0].text;
 
     chatHistory.push({ content: text, type: "ai" });
diff --git a/packages/core/src/GlobalsHelper.ts b/packages/core/src/GlobalsHelper.ts
index afd4549bae..58a154d33a 100644
--- a/packages/core/src/GlobalsHelper.ts
+++ b/packages/core/src/GlobalsHelper.ts
@@ -1,3 +1,6 @@
+import { Event, EventTag, EventType } from "./callbacks/CallbackManager";
+import { v4 as uuidv4 } from "uuid";
+
 class GlobalsHelper {
   defaultTokenizer: ((text: string) => string[]) | null = null;
 
@@ -13,6 +16,24 @@ class GlobalsHelper {
     };
     return this.defaultTokenizer;
   }
+
+  createEvent({
+    parentEvent,
+    type,
+    tags,
+  }: {
+    parentEvent?: Event;
+    type: EventType;
+    tags?: EventTag[];
+  }): Event {
+    return {
+      id: uuidv4(),
+      type,
+      // inherit parent tags if tags not set
+      tags: tags || parentEvent?.tags,
+      parentId: parentEvent?.id,
+    };
+  }
 }
 
 export const globalsHelper = new GlobalsHelper();
diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index da91123119..93a47cd806 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,30 +1,50 @@
 import { ChatOpenAI } from "./LanguageModel";
 import { SimplePrompt } from "./Prompt";
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
 
 // TODO change this to LLM class
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
   apredict(
     prompt: string | SimplePrompt,
-    input?: Record<string, string>
+    input?: Record<string, string>,
+    parentEvent?: Event
   ): Promise<string>;
-  // stream(prompt: string, options: any): Promise;
 }
 
 // TODO change this to LLM class
 export class ChatGPTLLMPredictor implements BaseLLMPredictor {
-  llm: string;
+  model: string;
   retryOnThrottling: boolean;
   languageModel: ChatOpenAI;
+  callbackManager?: CallbackManager;
 
   constructor(
-    llm: string = "gpt-3.5-turbo",
-    retryOnThrottling: boolean = true
+    props:
+      | {
+          model?: string;
+          retryOnThrottling?: boolean;
+          callbackManager?: CallbackManager;
+          languageModel?: ChatOpenAI;
+        }
+      | undefined = undefined
  ) {
-    this.llm = llm;
+    const {
+      model = "gpt-3.5-turbo",
+      retryOnThrottling = true,
+      callbackManager,
+      languageModel,
+    } = props || {};
+    this.model = model;
+    this.callbackManager = callbackManager;
     this.retryOnThrottling = retryOnThrottling;
-    this.languageModel = new ChatOpenAI(this.llm);
+    this.languageModel =
+      languageModel ??
+      new ChatOpenAI({
+        model: this.model,
+        callbackManager: this.callbackManager,
+      });
   }
 
   async getLlmMetadata() {
@@ -33,22 +53,22 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
 
   async apredict(
     prompt: string | SimplePrompt,
-    input?: Record<string, string>
+    input?: Record<string, string>,
+    parentEvent?: Event
   ): Promise<string> {
     if (typeof prompt === "string") {
-      const result = await this.languageModel.agenerate([
-        {
-          content: prompt,
-          type: "human",
-        },
-      ]);
+      const result = await this.languageModel.agenerate(
+        [
+          {
+            content: prompt,
+            type: "human",
+          },
+        ],
+        parentEvent
+      );
       return result.generations[0][0].text;
     } else {
       return this.apredict(prompt(input ?? {}));
     }
   }
-
-  // async stream(prompt: string, options: any) {
-  //   console.log("stream");
-  // }
 }
diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LanguageModel.ts
index 6e8a82ee14..9f9c676e91 100644
--- a/packages/core/src/LanguageModel.ts
+++ b/packages/core/src/LanguageModel.ts
@@ -1,8 +1,9 @@
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
+import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream";
 import {
   ChatCompletionRequestMessageRoleEnum,
-  Configuration,
+  CreateChatCompletionRequest,
   OpenAISession,
-  OpenAIWrapper,
   getOpenAISession,
 } from "./openai";
 
@@ -25,7 +26,7 @@ export interface LLMResult {
 }
 
 export interface BaseChatModel extends BaseLanguageModel {
-  agenerate(messages: BaseMessage[]): Promise<LLMResult>;
+  agenerate(messages: BaseMessage[], parentEvent?: Event): Promise<LLMResult>;
 }
 
 export class ChatOpenAI implements BaseChatModel {
@@ -36,11 +37,18 @@ export class ChatOpenAI implements BaseChatModel {
   maxRetries: number = 6;
   n: number = 1;
   maxTokens?: number;
-  session: OpenAISession;
+  callbackManager?: CallbackManager;
 
-  constructor(model: string = "gpt-3.5-turbo") {
+  constructor({
+    model = "gpt-3.5-turbo",
+    callbackManager,
+  }: {
+    model: string;
+    callbackManager?: CallbackManager;
+  }) {
     this.model = model;
+    this.callbackManager = callbackManager;
     this.session = getOpenAISession();
   }
 
@@ -61,8 +69,11 @@ export class ChatOpenAI implements BaseChatModel {
     }
   }
 
-  async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
-    const { data } = await this.session.openai.createChatCompletion({
+  async agenerate(
+    messages: BaseMessage[],
+    parentEvent?: Event
+  ): Promise<LLMResult> {
+    const baseRequestParams: CreateChatCompletionRequest = {
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
@@ -71,8 +82,29 @@
       messages: messages.map((message) => ({
         role: ChatOpenAI.mapMessageType(message.type),
         content: message.content,
       })),
-    });
+    };
+
+    if (this.callbackManager?.onLLMStream) {
+      const response = await this.session.openai.createChatCompletion(
+        {
+          ...baseRequestParams,
+          stream: true,
+        },
+        { responseType: "stream" }
+      );
+      const fullResponse = await aHandleOpenAIStream({
+        response,
+        onLLMStream: this.callbackManager.onLLMStream,
+        parentEvent,
+      });
+      return { generations: [[{ text: fullResponse }]] };
+    }
+
+    const response = await this.session.openai.createChatCompletion(
+      baseRequestParams
+    );
+    const { data } = response;
     const content = data.choices[0].message?.content ?? "";
 
     return { generations: [[{ text: content }]] };
   }
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index bddb0ea999..543305cadf 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -7,11 +7,13 @@ import {
 import { Response } from "./Response";
 import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
+import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { QueryEngineTool, ToolMetadata } from "./Tool";
 
 export interface BaseQueryEngine {
-  aquery(query: string): Promise<Response>;
+  aquery(query: string, parentEvent?: Event): Promise<Response>;
 }
 
 export class RetrieverQueryEngine implements BaseQueryEngine {
@@ -20,12 +22,19 @@
   constructor(retriever: BaseRetriever) {
     this.retriever = retriever;
-    this.responseSynthesizer = new ResponseSynthesizer();
+    const serviceContext: ServiceContext | undefined =
+      this.retriever.getServiceContext();
+    this.responseSynthesizer = new ResponseSynthesizer({ serviceContext });
   }
 
-  async aquery(query: string) {
-    const nodes = await this.retriever.aretrieve(query);
-    return this.responseSynthesizer.asynthesize(query, nodes);
+  async aquery(query: string, parentEvent?: Event) {
+    const _parentEvent: Event = parentEvent || {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+    const nodes = await this.retriever.aretrieve(query, _parentEvent);
+    return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
   }
 }
 
@@ -64,7 +73,10 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     const questionGen = init.questionGen ?? new LLMQuestionGenerator();
     const responseSynthesizer =
       init.responseSynthesizer ??
-      new ResponseSynthesizer(new CompactAndRefine(serviceContext));
+      new ResponseSynthesizer({
+        responseBuilder: new CompactAndRefine(serviceContext),
+        serviceContext,
+      });
 
     return new SubQuestionQueryEngine({
       questionGen,
@@ -78,21 +90,41 @@
       this.metadatas,
       query
     );
+
+    // groups final retrieval+synthesis operation
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+
+    // groups all sub-queries
+    const subQueryParentEvent: Event = {
+      id: uuidv4(),
+      parentId: parentEvent.id,
+      type: "wrapper",
+      tags: ["intermediate"],
+    };
+
     const subQNodes = await Promise.all(
-      subQuestions.map((subQ) => this.aquerySubQ(subQ))
+      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
     );
+
     const nodes = subQNodes
       .filter((node) => node !== null)
       .map((node) => node as NodeWithScore);
-    return this.responseSynthesizer.asynthesize(query, nodes);
+    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
   }
 
-  private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> {
+  private async aquerySubQ(
+    subQ: SubQuestion,
+    parentEvent?: Event
+  ): Promise<NodeWithScore | null> {
     try {
       const question = subQ.subQuestion;
       const queryEngine = this.queryEngines[subQ.toolName];
-      const response = await queryEngine.aquery(question);
+      const response = await queryEngine.aquery(question, parentEvent);
       const responseText = response.response;
       const nodeText = `Sub question: ${question}\nResponse: ${responseText}}`;
       const node = new TextNode({ text: nodeText });
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index f00cbf93fb..4aeedb55e7 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -1,4 +1,4 @@
-import { ChatGPTLLMPredictor } from "./LLMPredictor";
+import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./LLMPredictor";
 import { MetadataMode, NodeWithScore } from "./Node";
 import {
   SimplePrompt,
@@ -8,28 +8,38 @@ import { getBiggestPrompt } from "./PromptHelper";
 import { Response } from "./Response";
 import { ServiceContext } from "./ServiceContext";
+import { Event } from "./callbacks/CallbackManager";
 
 interface BaseResponseBuilder {
-  agetResponse(query: string, textChunks: string[]): Promise<string>;
+  agetResponse(
+    query: string,
+    textChunks: string[],
+    parentEvent?: Event
+  ): Promise<string>;
 }
 
 export class SimpleResponseBuilder implements BaseResponseBuilder {
-  llmPredictor: ChatGPTLLMPredictor;
+  llmPredictor: BaseLLMPredictor;
   textQATemplate: SimplePrompt;
 
-  constructor() {
-    this.llmPredictor = new ChatGPTLLMPredictor();
+  constructor(serviceContext?: ServiceContext) {
+    this.llmPredictor =
+      serviceContext?.llmPredictor ?? new ChatGPTLLMPredictor();
     this.textQATemplate = defaultTextQaPrompt;
   }
 
-  async agetResponse(query: string, textChunks: string[]): Promise<string> {
+  async agetResponse(
+    query: string,
+    textChunks: string[],
+    parentEvent?: Event
+  ): Promise<string> {
     const input = {
       query,
       context: textChunks.join("\n\n"),
     };
 
     const prompt = this.textQATemplate(input);
-    return this.llmPredictor.apredict(prompt, {});
+    return this.llmPredictor.apredict(prompt, {}, parentEvent);
   }
 }
 
@@ -178,23 +188,42 @@ export class TreeSummarize implements BaseResponseBuilder {
   }
 }
 
-export function getResponseBuilder(): BaseResponseBuilder {
-  return new SimpleResponseBuilder();
+export function getResponseBuilder(
+  serviceContext?: ServiceContext
+): SimpleResponseBuilder {
+  return new SimpleResponseBuilder(serviceContext);
 }
 
 // TODO replace with Logan's new response_sythesizers/factory.py
 export class ResponseSynthesizer {
   responseBuilder: BaseResponseBuilder;
-
-  constructor(responseBuilder?: BaseResponseBuilder) {
-    this.responseBuilder = responseBuilder ?? getResponseBuilder();
+  serviceContext?: ServiceContext;
+
+  constructor({
+    responseBuilder,
+    serviceContext,
+  }: {
+    responseBuilder?: BaseResponseBuilder;
+    serviceContext?: ServiceContext;
+  } = {}) {
+    this.serviceContext = serviceContext;
+    this.responseBuilder =
+      responseBuilder ?? getResponseBuilder(this.serviceContext);
   }
 
-  async asynthesize(query: string, nodes: NodeWithScore[]) {
+  async asynthesize(
+    query: string,
+    nodes: NodeWithScore[],
+    parentEvent?: Event
+  ) {
     let textChunks: string[] = nodes.map((node) =>
       node.node.getContent(MetadataMode.NONE)
     );
-    const response = await this.responseBuilder.agetResponse(query, textChunks);
+    const response = await this.responseBuilder.agetResponse(
+      query,
+      textChunks,
+      parentEvent
+    );
     return new Response(
       response,
       nodes.map((node) => node.node)
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index 61928ff2c3..428be6bb68 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -1,6 +1,8 @@
 import { VectorStoreIndex } from "./BaseIndex";
+import { globalsHelper } from "./GlobalsHelper";
 import { NodeWithScore } from "./Node";
 import { ServiceContext } from "./ServiceContext";
+import { Event } from "./callbacks/CallbackManager";
 import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
 import {
   VectorStoreQuery,
@@ -8,7 +10,8 @@ } from "./storage/vectorStore/types";
 
 export interface BaseRetriever {
-  aretrieve(query: string): Promise<NodeWithScore[]>;
+  aretrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
+  getServiceContext(): ServiceContext;
 }
 
 export class VectorIndexRetriever implements BaseRetriever {
@@ -21,7 +24,10 @@
     this.serviceContext = this.index.serviceContext;
   }
 
-  async aretrieve(query: string): Promise<NodeWithScore[]> {
+  async aretrieve(
+    query: string,
+    parentEvent?: Event
+  ): Promise<NodeWithScore[]> {
     const queryEmbedding =
       await this.serviceContext.embedModel.aGetQueryEmbedding(query);
 
@@ -41,6 +47,21 @@
       });
     }
+    if (this.serviceContext.callbackManager.onRetrieve) {
+      this.serviceContext.callbackManager.onRetrieve({
+        query,
+        nodes: nodesWithScores,
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
+      });
+    }
+
     return nodesWithScores;
   }
+
+  getServiceContext(): ServiceContext {
+    return this.serviceContext;
+  }
 }
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 9df16f9dc4..58570afd16 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,32 +1,37 @@
 import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
 import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
-import { BaseLanguageModel } from "./LanguageModel";
+import { ChatOpenAI } from "./LanguageModel";
 import { NodeParser, SimpleNodeParser } from "./NodeParser";
 import { PromptHelper } from "./PromptHelper";
+import { CallbackManager } from "./callbacks/CallbackManager";
 
 export interface ServiceContext {
   llmPredictor: BaseLLMPredictor;
   promptHelper: PromptHelper;
   embedModel: BaseEmbedding;
   nodeParser: NodeParser;
+  callbackManager: CallbackManager;
   // llamaLogger: any;
-  // callbackManager: any;
 }
 
 export interface ServiceContextOptions {
   llmPredictor?: BaseLLMPredictor;
-  llm?: BaseLanguageModel;
+  llm?: ChatOpenAI;
   promptHelper?: PromptHelper;
   embedModel?: BaseEmbedding;
   nodeParser?: NodeParser;
+  callbackManager?: CallbackManager;
   // NodeParser arguments
   chunkSize?: number;
   chunkOverlap?: number;
 }
 
 export function serviceContextFromDefaults(options?: ServiceContextOptions) {
+  const callbackManager = options?.callbackManager ?? new CallbackManager();
   const serviceContext: ServiceContext = {
-    llmPredictor: options?.llmPredictor ?? new ChatGPTLLMPredictor(),
+    llmPredictor:
+      options?.llmPredictor ??
+      new ChatGPTLLMPredictor({ callbackManager, languageModel: options?.llm }),
     embedModel: options?.embedModel ?? new OpenAIEmbedding(),
     nodeParser:
       options?.nodeParser ??
@@ -35,6 +40,7 @@ export function serviceContextFromDefaults(options?: ServiceContextOptions) {
         chunkOverlap: options?.chunkOverlap,
       }),
     promptHelper: options?.promptHelper ?? new PromptHelper(),
+    callbackManager,
   };
 
   return serviceContext;
@@ -57,5 +63,8 @@ export function serviceContextFromServiceContext(
   if (options.nodeParser) {
     newServiceContext.nodeParser = options.nodeParser;
   }
+  if (options.callbackManager) {
+    newServiceContext.callbackManager = options.callbackManager;
+  }
   return newServiceContext;
 }
diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts
new file mode 100644
index 0000000000..a35e061c10
--- /dev/null
+++ b/packages/core/src/callbacks/CallbackManager.ts
@@ -0,0 +1,72 @@
+import { ChatCompletionResponseMessageRoleEnum } from "openai";
+import { NodeWithScore } from "../Node";
+
+/*
+  An event is a wrapper that groups related operations.
+  For example, during retrieve and synthesize,
+  a parent event wraps both operations, and each operation has its own
+  event. In this case, both sub-events will share a parentId.
+*/
+
+export type EventTag = "intermediate" | "final";
+export type EventType = "retrieve" | "llmPredict" | "wrapper";
+export interface Event {
+  id: string;
+  type: EventType;
+  tags?: EventTag[];
+  parentId?: string;
+}
+
+interface BaseCallbackResponse {
+  event: Event;
+}
+
+export interface StreamToken {
+  id: string;
+  object: string;
+  created: number;
+  model: string;
+  choices: {
+    index: number;
+    delta: {
+      content?: string;
+      role?: ChatCompletionResponseMessageRoleEnum;
+    };
+    finish_reason: string | null;
+  }[];
+}
+
+export interface StreamCallbackResponse extends BaseCallbackResponse {
+  index: number;
+  isDone?: boolean;
+  token?: StreamToken;
+}
+
+export interface RetrievalCallbackResponse extends BaseCallbackResponse {
+  query: string;
+  nodes: NodeWithScore[];
+}
+
+interface CallbackManagerMethods {
+  /*
+    onLLMStream is called when a token is streamed from the LLM. Defining this
+    callback automatically sets the stream: true flag on the OpenAI
+    createChatCompletion request.
+  */
+  onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
+  /*
+    onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
+    This callback allows you to handle the retrieved nodes even if the synthesizer
+    is still running.
+  */
+  onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
+}
+
+export class CallbackManager implements CallbackManagerMethods {
+  onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
+  onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
+
+  constructor(handlers?: CallbackManagerMethods) {
+    this.onLLMStream = handlers?.onLLMStream;
+    this.onRetrieve = handlers?.onRetrieve;
+  }
+}
diff --git a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts b/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts
new file mode 100644
index 0000000000..b477806086
--- /dev/null
+++ b/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts
@@ -0,0 +1,64 @@
+import { globalsHelper } from "../../GlobalsHelper";
+import { StreamCallbackResponse, Event } from "../CallbackManager";
+import { StreamToken } from "../CallbackManager";
+
+export async function aHandleOpenAIStream({
+  response,
+  onLLMStream,
+  parentEvent,
+}: {
+  response: any;
+  onLLMStream: (data: StreamCallbackResponse) => void;
+  parentEvent?: Event;
+}): Promise<string> {
+  const event = globalsHelper.createEvent({
+    parentEvent,
+    type: "llmPredict",
+  });
+  const stream = __astreamCompletion(response.data as any);
+  let index = 0;
+  let cumulativeText = "";
+  for await (const message of stream) {
+    const token: StreamToken = JSON.parse(message);
+    const { content = "", role = "assistant" } = token?.choices[0]?.delta ?? {};
+    // ignore the first token
+    if (!content && role === "assistant" && index === 0) {
+      continue;
+    }
+    cumulativeText += content;
+    onLLMStream?.({ event, index, token });
+    index++;
+  }
+  onLLMStream?.({ event, index, isDone: true });
+  return cumulativeText;
+}
+
+/*
+  sources:
+  - https://github.com/openai/openai-node/issues/18#issuecomment-1372047643
+  - https://github.com/openai/openai-node/issues/18#issuecomment-1595805163
+*/
+async function* __astreamCompletion(data: string[]) {
+  yield* __alinesToText(__achunksToLines(data));
+}
+
+async function* __alinesToText(linesAsync: string | void | any) {
+  for await (const line of linesAsync) {
+    yield line.substring("data :".length);
+  }
+}
+
+async function* __achunksToLines(chunksAsync: string[]) {
+  let previous = "";
+  for await (const chunk of chunksAsync) {
+    const bufferChunk = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
+    previous += bufferChunk;
+    let eolIndex;
+    while ((eolIndex = previous.indexOf("\n")) >= 0) {
+      const line = previous.slice(0, eolIndex + 1).trimEnd();
+      if (line === "data: [DONE]") break;
+      if (line.startsWith("data: ")) yield line;
+      previous = previous.slice(eolIndex + 1);
+    }
+  }
+}
diff --git a/packages/core/src/index/list/ListIndexRetriever.ts b/packages/core/src/index/list/ListIndexRetriever.ts
index 33e7420755..15b6d9c2e8 100644
--- a/packages/core/src/index/list/ListIndexRetriever.ts
+++ b/packages/core/src/index/list/ListIndexRetriever.ts
@@ -10,6 +10,8 @@ } from "./utils";
 import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
 import _ from "lodash";
+import { globalsHelper } from "../../GlobalsHelper";
+import { Event } from "../../callbacks/CallbackManager";
 
 /**
  * Simple retriever for ListIndex that returns all nodes
@@ -21,13 +23,33 @@
     this.index = index;
   }
 
-  async aretrieve(query: string): Promise<NodeWithScore[]> {
+  async aretrieve(
+    query: string,
+    parentEvent?: Event
+  ): Promise<NodeWithScore[]> {
     const nodeIds = this.index.indexStruct.nodes;
     const nodes = await this.index.docStore.getNodes(nodeIds);
-    return nodes.map((node) => ({
+    const result = nodes.map((node) => ({
       node: node,
       score: 1,
     }));
+
+    if (this.index.serviceContext.callbackManager.onRetrieve) {
+      this.index.serviceContext.callbackManager.onRetrieve({
+        query,
+        nodes: result,
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
+      });
+    }
+
+    return result;
+  }
+
+  getServiceContext(): ServiceContext {
+    return this.index.serviceContext;
   }
 }
 
@@ -59,7 +81,10 @@
     this.serviceContext = serviceContext || index.serviceContext;
   }
 
-  async aretrieve(query: string): Promise<NodeWithScore[]> {
+  async aretrieve(
+    query: string,
+    parentEvent?: Event
+  ): Promise<NodeWithScore[]> {
     const nodeIds = this.index.indexStruct.nodes;
     const results: NodeWithScore[] = [];
 
@@ -91,6 +116,22 @@
       results.push(...nodeWithScores);
     }
+
+    if (this.serviceContext.callbackManager.onRetrieve) {
+      this.serviceContext.callbackManager.onRetrieve({
+        query,
+        nodes: results,
+        event: globalsHelper.createEvent({
+          parentEvent,
+          type: "retrieve",
+        }),
+      });
+    }
+
     return results;
   }
+
+  getServiceContext(): ServiceContext {
+    return this.serviceContext;
+  }
 }
diff --git a/packages/core/src/storage/docStore/utils.ts b/packages/core/src/storage/docStore/utils.ts
index a7329df67e..3df1c1b303 100644
--- a/packages/core/src/storage/docStore/utils.ts
+++ b/packages/core/src/storage/docStore/utils.ts
@@ -23,7 +23,6 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode {
       hash: dataDict.hash,
     });
   } else if (docType === ObjectType.TEXT) {
-    console.log({ dataDict });
     doc = new TextNode({
       text: dataDict.text,
       id_: dataDict.id_,
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
new file mode 100644
index 0000000000..26bcbee2f7
--- /dev/null
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -0,0 +1,208 @@
+import { VectorStoreIndex } from "../BaseIndex";
+import { OpenAIEmbedding } from "../Embedding";
+import { ChatOpenAI } from "../LanguageModel";
+import { Document } from "../Node";
+import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
+import {
+  CallbackManager,
+  RetrievalCallbackResponse,
+  StreamCallbackResponse,
+} from "../callbacks/CallbackManager";
+import { ListIndex } from "../index/list";
+import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI";
+
+// Mock the OpenAI getOpenAISession function during testing
+jest.mock("../openai", () => {
+  return {
+    getOpenAISession: jest.fn().mockImplementation(() => null),
+  };
+});
+
+describe("CallbackManager: onLLMStream and onRetrieve", () => {
+  let serviceContext: ServiceContext;
+  let streamCallbackData: StreamCallbackResponse[] = [];
+  let retrieveCallbackData: RetrievalCallbackResponse[] = [];
+  let document: Document;
+
+  beforeAll(async () => {
+    document = new Document({ text: "Author: My name is Paul Graham" });
+    const callbackManager = new CallbackManager({
+      onLLMStream: (data) => {
+        streamCallbackData.push(data);
+      },
+      onRetrieve: (data) => {
+        retrieveCallbackData.push(data);
+      },
+    });
+
+    const languageModel = new ChatOpenAI({
+      model: "gpt-3.5-turbo",
+      callbackManager,
+    });
+    mockLlmGeneration({ languageModel, callbackManager });
+
+    const embedModel = new OpenAIEmbedding();
+    mockEmbeddingModel(embedModel);
+
+    serviceContext = serviceContextFromDefaults({
+      callbackManager,
+      llm: languageModel,
+      embedModel,
+    });
+  });
+
+  beforeEach(() => {
+    streamCallbackData = [];
+    retrieveCallbackData = [];
+  });
+
+  afterAll(() => {
+    jest.clearAllMocks();
+  });
+
+  test("For VectorStoreIndex w/ a SimpleResponseBuilder", async () => {
+    const vectorStoreIndex = await VectorStoreIndex.fromDocuments(
+      [document],
+      undefined,
+      serviceContext
+    );
+    const queryEngine = vectorStoreIndex.asQueryEngine();
+    const query = "What is the author's name?";
+    const response = await queryEngine.aquery(query);
+    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
+    expect(streamCallbackData).toEqual([
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 0,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 1,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 2,
+        isDone: true,
+      },
+    ]);
+    expect(retrieveCallbackData).toEqual([
+      {
+        query: query,
+        nodes: expect.any(Array),
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "retrieve",
+          tags: ["final"],
+        },
+      },
+    ]);
+    // both retrieval and streaming should have
+    // the same parent event
+    expect(streamCallbackData[0].event.parentId).toBe(
+      retrieveCallbackData[0].event.parentId
+    );
+  });
+
+  test("For ListIndex w/ a ListIndexRetriever", async () => {
+    const listIndex = await ListIndex.fromDocuments(
+      [document],
+      undefined,
+      serviceContext
+    );
+    const queryEngine = listIndex.asQueryEngine();
+    const query = "What is the author's name?";
+    const response = await queryEngine.aquery(query);
+    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
+    expect(streamCallbackData).toEqual([
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 0,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 1,
+        token: {
+          id: "id",
+          object: "object",
+          created: 1,
+          model: "model",
+          choices: expect.any(Array),
+        },
+      },
+      {
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "llmPredict",
+          tags: ["final"],
+        },
+        index: 2,
+        isDone: true,
+      },
+    ]);
+    expect(retrieveCallbackData).toEqual([
+      {
+        query: query,
+        nodes: expect.any(Array),
+        event: {
+          id: expect.any(String),
+          parentId: expect.any(String),
+          type: "retrieve",
+          tags: ["final"],
+        },
+      },
+    ]);
+    // both retrieval and streaming should have
+    // the same parent event
+    expect(streamCallbackData[0].event.parentId).toBe(
+      retrieveCallbackData[0].event.parentId
+    );
+  });
+});
diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts
new file mode 100644
index 0000000000..67631a9acd
--- /dev/null
+++ b/packages/core/src/tests/utility/mockOpenAI.ts
@@ -0,0 +1,72 @@
+import { OpenAIEmbedding } from "../../Embedding";
+import { globalsHelper } from "../../GlobalsHelper";
+import { BaseMessage, ChatOpenAI } from "../../LanguageModel";
+import { CallbackManager, Event } from "../../callbacks/CallbackManager";
+
+export function mockLlmGeneration({
+  languageModel,
+  callbackManager,
+}: {
+  languageModel: ChatOpenAI;
+  callbackManager: CallbackManager;
+}) {
+  jest
+    .spyOn(languageModel, "agenerate")
+    .mockImplementation(
+      async (messages: BaseMessage[], parentEvent?: Event) => {
+        const text = "MOCK_TOKEN_1-MOCK_TOKEN_2";
+        const event = globalsHelper.createEvent({
+          parentEvent,
+          type: "llmPredict",
+        });
+        if (callbackManager?.onLLMStream) {
+          const chunks = text.split("-");
+          for (let i = 0; i < chunks.length; i++) {
+            const chunk = chunks[i];
+            callbackManager?.onLLMStream({
+              event,
+              index: i,
+              token: {
+                id: "id",
+                object: "object",
+                created: 1,
+                model: "model",
+                choices: [
+                  {
+                    index: 0,
+                    delta: {
+                      content: chunk,
+                    },
+                    finish_reason: null,
+                  },
+                ],
+              },
+            });
+          }
+          callbackManager?.onLLMStream({
+            event,
+            index: chunks.length,
+            isDone: true,
+          });
+        }
+        return new Promise((resolve) => {
+          resolve({
+            generations: [[{ text }]],
+          });
+        });
+      }
+    );
+}
+
+export function mockEmbeddingModel(embedModel: OpenAIEmbedding) {
+  jest.spyOn(embedModel, "aGetTextEmbedding").mockImplementation(async (x) => {
+    return new Promise((resolve) => {
+      resolve([1, 0, 0, 0, 0, 0]);
+    });
+  });
+  jest.spyOn(embedModel, "aGetQueryEmbedding").mockImplementation(async (x) => {
+    return new Promise((resolve) => {
+      resolve([0, 1, 0, 0, 0, 0]);
+    });
+  });
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 0bd001b095..7fc78117e7 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -36,7 +36,7 @@ importers:
       specifier: ^29.1.0
       version: 29.1.0(@babel/core@7.22.5)(jest@29.5.0)(typescript@4.9.5)
     turbo:
-      specifier: ^1.10.7
+      specifier: ^1.10.6
       version: 1.10.7
 
 apps/docs: