
Commit fb576fb
Update to gpt4all version 1.0.1. Implement the Streaming version of the completions endpoint. Implemented an openai python client test for the new streaming functionality. (#1129)

Co-authored-by: Brandon <[email protected]>
TheDropZone and bbeiler-ridgeline committed Jul 6, 2023
1 parent affd0af commit fb576fb
Showing 3 changed files with 95 additions and 37 deletions.
108 changes: 75 additions & 33 deletions gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py
@@ -1,6 +1,9 @@
+import json
+
 from fastapi import APIRouter, Depends, Response, Security, status
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Iterable, AsyncIterable
 import logging
 from uuid import uuid4
 from api_v1.settings import settings
@@ -10,6 +13,7 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
 class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
     logprobs: float
     finish_reason: str
 
+
 class CompletionUsage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
 
+
 class CompletionResponse(BaseModel):
     id: str
     object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
     usage: CompletionUsage
 
 
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = 'text_completion'
+    created: int
+    model: str
+    choices: List[CompletionChoice]
+
+
 router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
 
 
+def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
+    """
+    Streams a GPT4All output to the client.
+    Args:
+        output: The output of GPT4All.generate(), which is an iterable of tokens.
+        base_response: The base response object, which is cloned and modified for each token.
+    Returns:
+        A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
+    """
+    for token in output:
+        chunk = base_response.copy()
+        chunk.choices = [dict(CompletionChoice(
+            text=token,
+            index=0,
+            logprobs=-1,
+            finish_reason=''
+        ))]
+        yield f"data: {json.dumps(dict(chunk))}\n\n"
+
+
 @router.post("/", response_model=CompletionResponse)
 async def completions(request: CompletionRequest):
     '''
     Completes a GPT4All model response.
     '''
 
-    # global model
-    if request.stream:
-        raise NotImplementedError("Streaming is not yet implements")
-
     model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
 
     output = model.generate(prompt=request.prompt,
-                            n_predict = request.max_tokens,
-                            top_k = 20,
-                            top_p = request.top_p,
-                            temp=request.temperature,
-                            n_batch = 1024,
-                            repeat_penalty = 1.2,
-                            repeat_last_n = 10,
-                            context_erase = 0)
-
-
-    return CompletionResponse(
-        id=str(uuid4()),
-        created=time.time(),
-        model=request.model,
-        choices=[dict(CompletionChoice(
-            text=output,
-            index=0,
-            logprobs=-1,
-            finish_reason='stop'
-        ))],
-        usage={
-            'prompt_tokens': 0, #TODO how to compute this?
-            'completion_tokens': 0,
-            'total_tokens': 0
-        }
-    )
+                            n_predict=request.max_tokens,
+                            streaming=request.stream,
+                            top_k=20,
+                            top_p=request.top_p,
+                            temp=request.temperature,
+                            n_batch=1024,
+                            repeat_penalty=1.2,
+                            repeat_last_n=10)
+
+    # If streaming, we need to return a StreamingResponse
+    if request.stream:
+        base_chunk = CompletionStreamResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[]
+        )
+        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
+                                 media_type="text/event-stream")
+    else:
+        return CompletionResponse(
+            id=str(uuid4()),
+            created=time.time(),
+            model=request.model,
+            choices=[dict(CompletionChoice(
+                text=output,
+                index=0,
+                logprobs=-1,
+                finish_reason='stop'
+            ))],
+            usage={
+                'prompt_tokens': 0, #TODO how to compute this?
+                'completion_tokens': 0,
+                'total_tokens': 0
+            }
+        )
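
The stream_completion generator above emits Server-Sent Events: each chunk is a serialized CompletionStreamResponse prefixed with "data: " and terminated by a blank line. A client can therefore consume the stream without the openai package by reading the response body line by line. A minimal sketch using requests, assuming the API is reachable at http://localhost:4891/v1 (host, port, and path prefix are assumptions, not part of this commit):

import json
import requests

API_BASE = "http://localhost:4891/v1"  # hypothetical local deployment

with requests.post(f"{API_BASE}/completions",
                   json={"model": "gpt4all-j-v1.3-groovy",
                         "prompt": "Who is Michael Jordan?",
                         "max_tokens": 50,
                         "stream": True},
                   stream=True) as resp:
    for line in resp.iter_lines():
        # Each event arrives as b'data: {...}'; blank keep-alive lines are skipped.
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)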
22 changes: 19 additions & 3 deletions gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
@@ -23,13 +23,29 @@ def test_completion():
     assert len(response['choices'][0]['text']) > len(prompt)
     print(response)
 
+
+def test_streaming_completion():
+    model = "gpt4all-j-v1.3-groovy"
+    prompt = "Who is Michael Jordan?"
+    tokens = []
+    for resp in openai.Completion.create(
+            model=model,
+            prompt=prompt,
+            max_tokens=50,
+            temperature=0.28,
+            top_p=0.95,
+            n=1,
+            echo=True,
+            stream=True):
+        tokens.append(resp.choices[0].text)
+
+    assert (len(tokens) > 0)
+    assert (len("".join(tokens)) > len(prompt))
+
 # def test_chat_completions():
 #     model = "gpt4all-j-v1.3-groovy"
 #     prompt = "Who is Michael Jordan?"
 #     response = openai.ChatCompletion.create(
 #         model=model,
 #         messages=[]
 #     )
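
Since openai.Completion.create(..., stream=True) returns an iterator of partial responses, the test accumulates resp.choices[0].text per chunk and checks the joined text. For the tests to exercise the local FastAPI server rather than api.openai.com, the openai client's base URL has to be pointed at it, presumably in setup code outside this hunk; something like the following (values are an assumption, not shown in this diff):

import openai

openai.api_base = "http://localhost:4891/v1"  # hypothetical local gpt4all-api address
openai.api_key = "not-needed"  # a local server ignores the key, but the client requires one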



2 changes: 1 addition & 1 deletion gpt4all-api/gpt4all_api/requirements.txt
@@ -5,6 +5,6 @@ requests>=2.24.0
 ujson>=2.0.2
 fastapi>=0.95.0
 Jinja2>=3.0
-gpt4all==0.2.3
+gpt4all==1.0.1
 pytest
 openai
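
To pick up the pinned upgrade locally, reinstall the API's dependencies, e.g. pip install -r gpt4all-api/gpt4all_api/requirements.txt (or rebuild the service's Docker image if it runs containerized).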
