GPT4All API: Add Streaming to the Completions endpoint #1129

Merged
108 changes: 75 additions & 33 deletions gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py
@@ -1,6 +1,9 @@
+import json
+
 from fastapi import APIRouter, Depends, Response, Security, status
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Iterable, AsyncIterable
 import logging
 from uuid import uuid4
 from api_v1.settings import settings
@@ -10,6 +13,7 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
 class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
     logprobs: float
     finish_reason: str
 
+
 class CompletionUsage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
 
+
 class CompletionResponse(BaseModel):
     id: str
     object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
     usage: CompletionUsage
 
 
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = 'text_completion'
+    created: int
+    model: str
+    choices: List[CompletionChoice]
+
+
 router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
 
 
+def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
+    """
+    Streams a GPT4All output to the client.
+
+    Args:
+        output: The output of GPT4All.generate(), which is an iterable of tokens.
+        base_response: The base response object, which is cloned and modified for each token.
+
+    Returns:
+        A generator of CompletionStreamResponse objects, serialized to JSON in Server-Sent Events format.
+    """
+    for token in output:
+        chunk = base_response.copy()
+        chunk.choices = [dict(CompletionChoice(
+            text=token,
+            index=0,
+            logprobs=-1,
+            finish_reason=''
+        ))]
+        yield f"data: {json.dumps(dict(chunk))}\n\n"
+
+
 @router.post("/", response_model=CompletionResponse)
 async def completions(request: CompletionRequest):
     '''
     Completes a GPT4All model response.
     '''
 
-    # global model
-    if request.stream:
-        raise NotImplementedError("Streaming is not yet implements")
-
     model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
 
     output = model.generate(prompt=request.prompt,
-                            n_predict = request.max_tokens,
-                            top_k = 20,
-                            top_p = request.top_p,
-                            temp=request.temperature,
-                            n_batch = 1024,
-                            repeat_penalty = 1.2,
-                            repeat_last_n = 10,
-                            context_erase = 0)
-
-
-    return CompletionResponse(
-        id=str(uuid4()),
-        created=time.time(),
-        model=request.model,
-        choices=[dict(CompletionChoice(
-            text=output,
-            index=0,
-            logprobs=-1,
-            finish_reason='stop'
-        ))],
-        usage={
-            'prompt_tokens': 0, #TODO how to compute this?
-            'completion_tokens': 0,
-            'total_tokens': 0
-        }
-    )
+                            n_predict=request.max_tokens,
+                            streaming=request.stream,
+                            top_k=20,
+                            top_p=request.top_p,
+                            temp=request.temperature,
+                            n_batch=1024,
+                            repeat_penalty=1.2,
+                            repeat_last_n=10)
+
+    # If streaming, we need to return a StreamingResponse
+    if request.stream:
+        base_chunk = CompletionStreamResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[]
+        )
+        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
+                                 media_type="text/event-stream")
+    else:
+        return CompletionResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[dict(CompletionChoice(
+                text=output,
+                index=0,
+                logprobs=-1,
+                finish_reason='stop'
+            ))],
+            usage={
+                'prompt_tokens': 0, #TODO how to compute this?
+                'completion_tokens': 0,
+                'total_tokens': 0
+            }
+        )
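
For reference, each chunk emitted by stream_completion above is a Server-Sent Events frame: a "data: " prefix, a JSON-serialized CompletionStreamResponse, and a blank line. Below is a minimal client-side sketch of consuming that stream; the host/port and the use of the requests library are assumptions for illustration, not part of this diff.

import json
import requests  # assumed HTTP client, not part of this diff

# Host/port are assumptions; adjust to wherever gpt4all-api is running.
API_BASE = "http://localhost:4891/v1"

payload = {
    "model": "gpt4all-j-v1.3-groovy",
    "prompt": "Who is Michael Jordan?",
    "max_tokens": 50,
    "stream": True,
}

with requests.post(f"{API_BASE}/completions", json=payload, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: {...}" followed by a blank line.
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)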
22 changes: 19 additions & 3 deletions gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
@@ -23,13 +23,29 @@ def test_completion():
     assert len(response['choices'][0]['text']) > len(prompt)
     print(response)
 
+
+def test_streaming_completion():
+    model = "gpt4all-j-v1.3-groovy"
+    prompt = "Who is Michael Jordan?"
+    tokens = []
+    for resp in openai.Completion.create(
+            model=model,
+            prompt=prompt,
+            max_tokens=50,
+            temperature=0.28,
+            top_p=0.95,
+            n=1,
+            echo=True,
+            stream=True):
+        tokens.append(resp.choices[0].text)
+
+    assert (len(tokens) > 0)
+    assert (len("".join(tokens)) > len(prompt))
+
 # def test_chat_completions():
 #     model = "gpt4all-j-v1.3-groovy"
 #     prompt = "Who is Michael Jordan?"
 #     response = openai.ChatCompletion.create(
 #         model=model,
 #         messages=[]
 #     )
-
-
-
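
The streaming test drives the endpoint through the openai client library, which only reaches the local server if its base URL is redirected away from api.openai.com; that setup lives in the collapsed top of the test module. A hypothetical sketch of the idea, with both values assumed rather than taken from this diff:

import openai

# Assumed values; the actual configuration sits in the collapsed test setup.
openai.api_base = "http://localhost:4891/v1"  # local gpt4all-api instead of api.openai.com
openai.api_key = "not-needed-for-local"       # a local server ignores the key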
2 changes: 1 addition & 1 deletion gpt4all-api/gpt4all_api/requirements.txt
@@ -5,6 +5,6 @@ requests>=2.24.0
 ujson>=2.0.2
 fastapi>=0.95.0
 Jinja2>=3.0
-gpt4all==0.2.3
+gpt4all==1.0.1
 pytest
 openai
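
The gpt4all pin moves from 0.2.3 to 1.0.1, presumably because the streaming path depends on GPT4All.generate(..., streaming=...) yielding tokens as an iterator rather than returning one string. A quick sanity-check sketch of that behavior in isolation, using only keyword arguments that appear in this diff (the model name is an assumption):

from gpt4all import GPT4All

# Model name is an assumption; any locally available model works.
model = GPT4All(model_name="ggml-gpt4all-j-v1.3-groovy")

# With streaming=True, generate() yields tokens one at a time
# instead of returning a single completed string.
for token in model.generate(prompt="Who is Michael Jordan?",
                            n_predict=50,
                            streaming=True):
    print(token, end="", flush=True)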