Merge pull request #8 from run-llama/stream_responses

[Feature] CallbackManager with onLLMStream and onRetrieve

yisding committed Jul 10, 2023
2 parents 1e51009 + 0bec460 commit 2f468ab
Showing 15 changed files with 701 additions and 74 deletions.
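In code terms, the PR threads an optional Event through retrieval and LLM calls and reports progress through a CallbackManager. A minimal consumer sketch of the idea; the CallbackManager constructor and handler payload shapes (typed as `any` here) are assumptions, since callbacks/CallbackManager.ts is not among the diffs shown below:

```ts
import { CallbackManager } from "./callbacks/CallbackManager";

// Hypothetical wiring; this PR establishes that onLLMStream/onRetrieve
// handlers exist, but not their exact payload types.
const callbackManager = new CallbackManager({
  // Invoked once per streamed chunk while an LLM call runs with stream: true.
  onLLMStream: (data: any) => console.log("stream chunk:", data),
  // Invoked when a retriever finishes fetching nodes for a query.
  onRetrieve: (data: any) => console.log("retrieve done:", data),
});
```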
package.json (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@
     "prettier": "^2.8.8",
     "prettier-plugin-tailwindcss": "^0.3.0",
     "ts-jest": "^29.1.0",
-    "turbo": "^1.10.7"
+    "turbo": "^1.10.6"
   },
   "packageManager": "[email protected]",
   "name": "llamascript"
packages/core/src/ChatEngine.ts (32 changes: 19 additions & 13 deletions)

@@ -1,9 +1,4 @@
-import {
-  BaseChatModel,
-  BaseMessage,
-  ChatOpenAI,
-  LLMResult,
-} from "./LanguageModel";
+import { BaseChatModel, BaseMessage, ChatOpenAI } from "./LanguageModel";
 import { TextNode } from "./Node";
 import {
   SimplePrompt,
@@ -15,6 +10,8 @@ import { BaseQueryEngine } from "./QueryEngine";
 import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
 
 interface ChatEngine {
   chatRepl(): void;
@@ -30,7 +27,7 @@ export class SimpleChatEngine implements ChatEngine {
 
   constructor(init?: Partial<SimpleChatEngine>) {
     this.chatHistory = init?.chatHistory ?? [];
-    this.llm = init?.llm ?? new ChatOpenAI();
+    this.llm = init?.llm ?? new ChatOpenAI({ model: "gpt-3.5-turbo" });
   }
 
   chatRepl() {
@@ -125,7 +122,8 @@ export class ContextChatEngine implements ChatEngine {
     chatHistory?: BaseMessage[];
   }) {
     this.retriever = init.retriever;
-    this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k");
+    this.chatModel =
+      init.chatModel ?? new ChatOpenAI({ model: "gpt-3.5-turbo-16k" });
     this.chatHistory = init?.chatHistory ?? [];
   }
 
@@ -136,7 +134,15 @@ export class ContextChatEngine implements ChatEngine {
   async achat(message: string, chatHistory?: BaseMessage[] | undefined) {
     chatHistory = chatHistory ?? this.chatHistory;
 
-    const sourceNodesWithScore = await this.retriever.aretrieve(message);
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+    const sourceNodesWithScore = await this.retriever.aretrieve(
+      message,
+      parentEvent
+    );
 
     const systemMessage: BaseMessage = {
       content: contextSystemPrompt({
@@ -149,10 +155,10 @@
 
     chatHistory.push({ content: message, type: "human" });
 
-    const response = await this.chatModel.agenerate([
-      systemMessage,
-      ...chatHistory,
-    ]);
+    const response = await this.chatModel.agenerate(
+      [systemMessage, ...chatHistory],
+      parentEvent
+    );
     const text = response.generations[0][0].text;
 
     chatHistory.push({ content: text, type: "ai" });
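The net effect on ContextChatEngine: one "wrapper" event now spans both the retrieval and the chat completion. A hedged usage sketch, where the retriever argument is a placeholder for any BaseRetriever implementation and achat is assumed to resolve to a Response whose text lives on `.response`, consistent with the QueryEngine usage further below:

```ts
import { ContextChatEngine } from "./ChatEngine";
import { BaseRetriever } from "./Retriever";

async function chatOnce(retriever: BaseRetriever, question: string) {
  // achat mints a "wrapper" event tagged ["final"] and passes it to both
  // retriever.aretrieve and chatModel.agenerate, so callbacks can correlate them.
  const engine = new ContextChatEngine({ retriever });
  const response = await engine.achat(question);
  console.log(response.response);
}
```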
packages/core/src/GlobalsHelper.ts (21 changes: 21 additions & 0 deletions)

@@ -1,3 +1,6 @@
+import { Event, EventTag, EventType } from "./callbacks/CallbackManager";
+import { v4 as uuidv4 } from "uuid";
+
 class GlobalsHelper {
   defaultTokenizer: ((text: string) => string[]) | null = null;
 
@@ -13,6 +16,24 @@ class GlobalsHelper {
     };
     return this.defaultTokenizer;
   }
+
+  createEvent({
+    parentEvent,
+    type,
+    tags,
+  }: {
+    parentEvent?: Event;
+    type: EventType;
+    tags?: EventTag[];
+  }): Event {
+    return {
+      id: uuidv4(),
+      type,
+      // inherit parent tags if tags not set
+      tags: tags || parentEvent?.tags,
+      parentId: parentEvent?.id,
+    };
+  }
 }
 
 export const globalsHelper = new GlobalsHelper();
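A short sketch of the new helper in use; the "retrieve" event type is an assumption (only "wrapper" appears in the diffs here), while the tag-inheritance behavior follows directly from createEvent above:

```ts
import { globalsHelper } from "./GlobalsHelper";

// Parent event for a top-level operation, tagged "final" as the engines do.
const parentEvent = globalsHelper.createEvent({
  type: "wrapper",
  tags: ["final"],
});

// Child event: no tags passed, so it inherits ["final"] from the parent and
// records parentEvent.id as its parentId, forming a traceable event tree.
const childEvent = globalsHelper.createEvent({
  parentEvent,
  type: "retrieve", // assumed EventType value, not confirmed by this diff
});
```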
packages/core/src/LLMPredictor.ts (56 changes: 38 additions & 18 deletions)

@@ -1,30 +1,50 @@
 import { ChatOpenAI } from "./LanguageModel";
 import { SimplePrompt } from "./Prompt";
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
 
 // TODO change this to LLM class
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
   apredict(
     prompt: string | SimplePrompt,
-    input?: Record<string, string>
+    input?: Record<string, string>,
+    parentEvent?: Event
   ): Promise<string>;
   // stream(prompt: string, options: any): Promise<any>;
 }
 
 // TODO change this to LLM class
 export class ChatGPTLLMPredictor implements BaseLLMPredictor {
-  llm: string;
+  model: string;
   retryOnThrottling: boolean;
   languageModel: ChatOpenAI;
+  callbackManager?: CallbackManager;
 
   constructor(
-    llm: string = "gpt-3.5-turbo",
-    retryOnThrottling: boolean = true
+    props:
+      | {
+          model?: string;
+          retryOnThrottling?: boolean;
+          callbackManager?: CallbackManager;
+          languageModel?: ChatOpenAI;
+        }
+      | undefined = undefined
   ) {
-    this.llm = llm;
+    const {
+      model = "gpt-3.5-turbo",
+      retryOnThrottling = true,
+      callbackManager,
+      languageModel,
+    } = props || {};
+    this.model = model;
+    this.callbackManager = callbackManager;
     this.retryOnThrottling = retryOnThrottling;
-
-    this.languageModel = new ChatOpenAI(this.llm);
+    this.languageModel =
+      languageModel ??
+      new ChatOpenAI({
+        model: this.model,
+        callbackManager: this.callbackManager,
+      });
   }
 
   async getLlmMetadata() {
@@ -33,22 +53,22 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
 
   async apredict(
     prompt: string | SimplePrompt,
-    input?: Record<string, string>
+    input?: Record<string, string>,
+    parentEvent?: Event
   ): Promise<string> {
     if (typeof prompt === "string") {
-      const result = await this.languageModel.agenerate([
-        {
-          content: prompt,
-          type: "human",
-        },
-      ]);
+      const result = await this.languageModel.agenerate(
+        [
+          {
+            content: prompt,
+            type: "human",
+          },
+        ],
+        parentEvent
+      );
       return result.generations[0][0].text;
     } else {
       return this.apredict(prompt(input ?? {}));
     }
   }
-
-  // async stream(prompt: string, options: any) {
-  //   console.log("stream");
-  // }
 }
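With the positional parameters gone, construction looks like this; a small sketch using nothing beyond what the constructor above accepts:

```ts
import { ChatGPTLLMPredictor } from "./LLMPredictor";
import { CallbackManager } from "./callbacks/CallbackManager";

// No arguments: defaults to model "gpt-3.5-turbo" with retryOnThrottling = true.
const defaultPredictor = new ChatGPTLLMPredictor();

// Explicit options; the CallbackManager is forwarded to the ChatOpenAI instance
// the predictor builds, which enables streaming via onLLMStream.
function makePredictor(callbackManager?: CallbackManager) {
  return new ChatGPTLLMPredictor({
    model: "gpt-3.5-turbo-16k",
    retryOnThrottling: false,
    callbackManager,
  });
}
```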
packages/core/src/LanguageModel.ts (48 changes: 40 additions & 8 deletions)

@@ -1,8 +1,9 @@
+import { CallbackManager, Event } from "./callbacks/CallbackManager";
+import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream";
 import {
   ChatCompletionRequestMessageRoleEnum,
-  Configuration,
+  CreateChatCompletionRequest,
   OpenAISession,
-  OpenAIWrapper,
   getOpenAISession,
 } from "./openai";
 
@@ -36,11 +37,18 @@ export class ChatOpenAI implements BaseChatModel {
   maxRetries: number = 6;
   n: number = 1;
   maxTokens?: number;
-
   session: OpenAISession;
+  callbackManager?: CallbackManager;
 
-  constructor(model: string = "gpt-3.5-turbo") {
+  constructor({
+    model = "gpt-3.5-turbo",
+    callbackManager,
+  }: {
+    model: string;
+    callbackManager?: CallbackManager;
+  }) {
     this.model = model;
+    this.callbackManager = callbackManager;
     this.session = getOpenAISession();
   }
 
@@ -61,8 +69,11 @@
     }
   }
 
-  async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
-    const { data } = await this.session.openai.createChatCompletion({
+  async agenerate(
+    messages: BaseMessage[],
+    parentEvent?: Event
+  ): Promise<LLMResult> {
+    const baseRequestParams: CreateChatCompletionRequest = {
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
@@ -71,8 +82,29 @@
         role: ChatOpenAI.mapMessageType(message.type),
         content: message.content,
       })),
-    });
+    };
+
+    if (this.callbackManager?.onLLMStream) {
+      const response = await this.session.openai.createChatCompletion(
+        {
+          ...baseRequestParams,
+          stream: true,
+        },
+        { responseType: "stream" }
+      );
+      const fullResponse = await aHandleOpenAIStream({
+        response,
+        onLLMStream: this.callbackManager.onLLMStream,
+        parentEvent,
+      });
+      return { generations: [[{ text: fullResponse }]] };
+    }
+
+    const response = await this.session.openai.createChatCompletion(
+      baseRequestParams
+    );
 
+    const { data } = response;
     const content = data.choices[0].message?.content ?? "";
     return { generations: [[{ text: content }]] };
   }
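The branch above makes streaming opt-in by configuration rather than by call site: attach a CallbackManager whose onLLMStream is set, and every agenerate call streams while still resolving to the full text. A sketch using only the signatures shown in this diff:

```ts
import { BaseMessage, ChatOpenAI } from "./LanguageModel";
import { CallbackManager } from "./callbacks/CallbackManager";

async function generateWithStreaming(callbackManager: CallbackManager) {
  // With onLLMStream set on the manager, agenerate sends stream: true,
  // forwards chunks through aHandleOpenAIStream, and still returns the
  // concatenated text in result.generations[0][0].text.
  const llm = new ChatOpenAI({ model: "gpt-3.5-turbo", callbackManager });
  const messages: BaseMessage[] = [{ content: "Say hello.", type: "human" }];
  const result = await llm.agenerate(messages);
  return result.generations[0][0].text;
}
```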
packages/core/src/QueryEngine.ts (52 changes: 42 additions & 10 deletions)

@@ -7,11 +7,13 @@ import {
 import { Response } from "./Response";
 import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
+import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { QueryEngineTool, ToolMetadata } from "./Tool";
 
 export interface BaseQueryEngine {
-  aquery(query: string): Promise<Response>;
+  aquery(query: string, parentEvent?: Event): Promise<Response>;
 }
 
 export class RetrieverQueryEngine implements BaseQueryEngine {
@@ -20,12 +22,19 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
 
   constructor(retriever: BaseRetriever) {
     this.retriever = retriever;
-    this.responseSynthesizer = new ResponseSynthesizer();
+    const serviceContext: ServiceContext | undefined =
+      this.retriever.getServiceContext();
+    this.responseSynthesizer = new ResponseSynthesizer({ serviceContext });
   }
 
-  async aquery(query: string) {
-    const nodes = await this.retriever.aretrieve(query);
-    return this.responseSynthesizer.asynthesize(query, nodes);
+  async aquery(query: string, parentEvent?: Event) {
+    const _parentEvent: Event = parentEvent || {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+    const nodes = await this.retriever.aretrieve(query, _parentEvent);
+    return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
   }
 }
 
@@ -64,7 +73,10 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
     const questionGen = init.questionGen ?? new LLMQuestionGenerator();
     const responseSynthesizer =
       init.responseSynthesizer ??
-      new ResponseSynthesizer(new CompactAndRefine(serviceContext));
+      new ResponseSynthesizer({
+        responseBuilder: new CompactAndRefine(serviceContext),
+        serviceContext,
+      });
 
     return new SubQuestionQueryEngine({
       questionGen,
@@ -78,21 +90,41 @@
       this.metadatas,
       query
     );
+
+    // groups final retrieval+synthesis operation
+    const parentEvent: Event = {
+      id: uuidv4(),
+      type: "wrapper",
+      tags: ["final"],
+    };
+
+    // groups all sub-queries
+    const subQueryParentEvent: Event = {
+      id: uuidv4(),
+      parentId: parentEvent.id,
+      type: "wrapper",
+      tags: ["intermediate"],
+    };
+
     const subQNodes = await Promise.all(
-      subQuestions.map((subQ) => this.aquerySubQ(subQ))
+      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
    );
+
     const nodes = subQNodes
       .filter((node) => node !== null)
       .map((node) => node as NodeWithScore);
-    return this.responseSynthesizer.asynthesize(query, nodes);
+    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
   }
 
-  private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> {
+  private async aquerySubQ(
+    subQ: SubQuestion,
+    parentEvent?: Event
+  ): Promise<NodeWithScore | null> {
     try {
       const question = subQ.subQuestion;
       const queryEngine = this.queryEngines[subQ.toolName];
 
-      const response = await queryEngine.aquery(question);
+      const response = await queryEngine.aquery(question, parentEvent);
       const responseText = response.response;
       const nodeText = `Sub question: ${question}\nResponse: ${responseText}}`;
       const node = new TextNode({ text: nodeText });
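End to end, a caller does not need to know about events at all; aquery creates a "final"-tagged wrapper when none is supplied. A hedged sketch, where the retriever is a placeholder for any BaseRetriever and `response.response` follows the usage in aquerySubQ above:

```ts
import { RetrieverQueryEngine } from "./QueryEngine";
import { BaseRetriever } from "./Retriever";

async function ask(retriever: BaseRetriever, question: string) {
  // The engine pulls the retriever's ServiceContext into its synthesizer,
  // and aquery tags retrieval and synthesis with one shared wrapper event.
  const engine = new RetrieverQueryEngine(retriever);
  const response = await engine.aquery(question);
  return response.response;
}
```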
(Diffs for the remaining 9 of the 15 changed files are not shown.)
