This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Streaming support #195

Merged · 1 commit · Mar 28, 2023
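
This PR adds server-sent-event (SSE) streaming for the completions and chat completions endpoints: new @Streaming Retrofit calls, CompletionChunk/ChatCompletionChunk response types, and Flowable-based helpers on OpenAiService. A minimal usage sketch of the new surface (not part of this diff; model and prompt are placeholders):

    import com.theokanning.openai.completion.CompletionRequest;
    import com.theokanning.openai.service.OpenAiService;

    OpenAiService service = new OpenAiService(System.getenv("OPENAI_TOKEN"));

    CompletionRequest request = CompletionRequest.builder()
            .model("ada")                    // placeholder model
            .prompt("Somebody once told me") // placeholder prompt
            .build();

    // streamCompletion sets stream=true and emits one CompletionChunk per SSE event
    service.streamCompletion(request)
            .blockingForEach(chunk -> System.out.print(chunk.getChoices().get(0).getText()));

    service.shutdownExecutor(); // shuts down OkHttp's executor so the JVM can exit promptly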
com/theokanning/openai/completion/CompletionChunk.java (new file)
@@ -0,0 +1,37 @@
package com.theokanning.openai.completion;

import lombok.Data;
import java.util.List;

/**
* Object containing a response chunk from the completions streaming api.
*
* https://beta.openai.com/docs/api-reference/completions/create
*/
@Data
public class CompletionChunk {
/**
* A unique id assigned to this completion.
*/
String id;

/**
* The type of object returned, should be "text_completion"
*/
String object;

/**
* The creation time in epoch seconds.
*/
long created;

/**
* The GPT-3 model used.
*/
String model;

/**
* A list of generated completions.
*/
List<CompletionChoice> choices;
}
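
Each SSE event carries a JSON payload that Jackson maps onto this class. A small parsing sketch, not part of this diff (the payload is illustrative; a plain ObjectMapper suffices for these field names):

    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.theokanning.openai.completion.CompletionChunk;

    public class ChunkParseSketch {
        public static void main(String[] args) throws Exception {
            // Illustrative payload; real chunks arrive one per "data:" SSE line.
            String data = "{\"id\":\"cmpl-123\",\"object\":\"text_completion\","
                    + "\"created\":1679000000,\"model\":\"ada\","
                    + "\"choices\":[{\"text\":\"Hello\",\"index\":0}]}";
            CompletionChunk chunk = new ObjectMapper().readValue(data, CompletionChunk.class);
            System.out.println(chunk.getChoices().get(0).getText()); // prints: Hello
        }
    }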
com/theokanning/openai/completion/chat/ChatCompletionChoice.java
@@ -1,4 +1,5 @@
package com.theokanning.openai.completion.chat;
import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

@@ -14,8 +15,9 @@ public class ChatCompletionChoice {
Integer index;

/**
* The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated.
*/
@JsonAlias("delta")
ChatMessage message;

/**
com/theokanning/openai/completion/chat/ChatCompletionChunk.java (new file)
@@ -0,0 +1,35 @@
package com.theokanning.openai.completion.chat;
import lombok.Data;

import java.util.List;

/**
* Object containing a response chunk from the chat completions streaming api.
*/
@Data
public class ChatCompletionChunk {
/**
* Unique id assigned to this chat completion.
*/
String id;

/**
* The type of object returned, should be "chat.completion.chunk"
*/
String object;

/**
* The creation time in epoch seconds.
*/
long created;

/**
* The GPT-3.5 model used.
*/
String model;

/**
* A list of all generated completions.
*/
List<ChatCompletionChoice> choices;
}
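
Because ChatCompletionChoice maps the streaming "delta" field onto message via @JsonAlias (see the change above), a complete reply can be assembled by concatenating the content of successive chunks. A sketch, not part of this diff, assuming service and chatCompletionRequest as in the example file further down; note the first delta typically carries only the role, with null content:

    StringBuilder reply = new StringBuilder();
    service.streamChatCompletion(chatCompletionRequest)
            .blockingForEach(chunk -> {
                String delta = chunk.getChoices().get(0).getMessage().getContent();
                if (delta != null) {    // the role-only first delta has null content
                    reply.append(delta);
                }
            });
    System.out.println(reply);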
10 changes: 10 additions & 0 deletions client/src/main/java/com/theokanning/openai/OpenAiApi.java
@@ -22,6 +22,8 @@
import io.reactivex.Single;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.http.*;

public interface OpenAiApi {
@@ -34,10 +36,18 @@ public interface OpenAiApi {

@POST("/v1/completions")
Single<CompletionResult> createCompletion(@Body CompletionRequest request);

@Streaming
@POST("/v1/completions")
Call<ResponseBody> createCompletionStream(@Body CompletionRequest request);

@POST("/v1/chat/completions")
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);

@Streaming
@POST("/v1/chat/completions")
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request);

@Deprecated
@POST("/v1/engines/{engine_id}/completions")
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
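
The @Streaming annotation tells Retrofit to hand back the raw ResponseBody without buffering it, so bytes can be read as the server emits server-sent events. The SSE and ResponseBodyCallback classes that OpenAiService relies on below are part of this PR but not shown in this view; the following is only a sketch of the kind of line-based parsing such a callback might perform (readSse is a hypothetical name, not the PR's actual code):

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import okhttp3.ResponseBody;

    // Hypothetical sketch: reads "data: {...}" lines from a streaming body
    // until the OpenAI terminator "[DONE]" arrives.
    static void readSse(ResponseBody body) throws IOException {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(body.byteStream()))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.startsWith("data:")) continue; // skip blank/keep-alive lines
                String data = line.substring(5).trim();
                if (data.equals("[DONE]")) break;        // end-of-stream marker
                System.out.println(data);                // one JSON chunk per event
            }
        }
    }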
79 changes: 79 additions & 0 deletions example/src/main/java/example/OpenAiApiStreamExample.java
@@ -0,0 +1,79 @@
package example;

import com.theokanning.openai.service.OpenAiService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;

public class OpenAiApiStreamExample {
public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);

System.out.println("\nCreating completion...");
CompletionRequest completionRequest = CompletionRequest.builder()
.model("ada")
.prompt("Somebody once told me the world is gonna roll me")
.echo(true)
.user("testing")
.n(3)
.build();

/*
Note: when using blockingForEach the calling thread waits until the loop finishes.
Use forEach instead of blockingForEach if you don't want the calling thread to wait.
*/

// stream raw bytes
service
.streamCompletionBytes(completionRequest)
.doOnError( e -> {
e.printStackTrace();
})
.blockingForEach( bytes -> {
System.out.print(new String(bytes));
});

// stream CompletionChunks
service
.streamCompletion(completionRequest)
.doOnError( e -> {
e.printStackTrace();
})
.blockingForEach(System.out::println);


final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such.");
messages.add(systemMessage);

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo")
.messages(messages)
.n(5)
.maxTokens(50)
.logitBias(new HashMap<>())
.build();

// stream ChatCompletionChunks
service
.streamChatCompletion(chatCompletionRequest)
.doOnError( e -> {
e.printStackTrace();
})
.blockingForEach(System.out::println);

/*
* shutdown the OkHttp ExecutorService to
* exit immediately after the loops have finished
*/
service.shutdownExecutor();
}
}
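
As the comment in the example notes, forEach is the non-blocking alternative to blockingForEach: it returns a Disposable and the calling thread continues immediately. A sketch, not part of this diff, reusing the chat request above:

    import io.reactivex.disposables.Disposable;

    Disposable disposable = service.streamChatCompletion(chatCompletionRequest)
            .forEach(System.out::println); // returns immediately; chunks print asynchronously

    // ... later, to cancel the stream early:
    disposable.dispose();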
com/theokanning/openai/service/OpenAiService.java
@@ -8,8 +8,10 @@
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionChunk;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.edit.EditRequest;
@@ -27,17 +29,22 @@
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;

import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.Single;
import okhttp3.*;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.Call;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

public class OpenAiService {
@@ -47,6 +54,7 @@ public class OpenAiService {
private static final ObjectMapper errorMapper = defaultObjectMapper();

private final OpenAiApi api;
private final ExecutorService executorService;

/**
* Creates a new OpenAiService that wraps OpenAiApi
@@ -64,17 +72,29 @@ public OpenAiService(final String token) {
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(defaultClient(token, timeout));
}

/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param client OkHttpClient to be used for api calls
*/
public OpenAiService(OkHttpClient client){
this(buildApi(client), client.dispatcher().executorService());
}

/**
* Creates a new OpenAiService that wraps OpenAiApi.
* The ExecutorService must be the one you get from the client you created the api with,
* otherwise shutdownExecutor() won't work. Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
* @param executorService the ExecutorService from client.dispatcher().executorService()
*/
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
this.api = api;
this.executorService = executorService;
}

public List<Model> listModels() {
@@ -88,11 +108,39 @@ public Model getModel(String modelId) {
public CompletionResult createCompletion(CompletionRequest request) {
return execute(api.createCompletion(request));
}

public Flowable<byte[]> streamCompletionBytes(CompletionRequest request) {
request.setStream(true);

return stream(api.createCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
request.setStream(true);

return stream(api.createCompletionStream(request), CompletionChunk.class);
}

public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}

public Flowable<byte[]> streamChatCompletionBytes(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}

public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request));
}
@@ -232,12 +280,55 @@ public static <T> T execute(Single<T> apiCall) {
}
}

/**
* Calls the OpenAI API and returns a Flowable of SSE for streaming,
* omitting the last message ([DONE]).
*
* @param apiCall The api call
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
return stream(apiCall, false);
}

/**
* Calls the OpenAI API and returns a Flowable of SSE for streaming.
*
* @param apiCall The api call
* @param emitDone If true the last message ([DONE]) is emitted
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> {
apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone));
}, BackpressureStrategy.BUFFER);
}

/**
* Calls the OpenAI API and returns a Flowable of type T for streaming,
* omitting the last message ([DONE]).
*
* @param apiCall The api call
* @param cl Class of type T to return
*/
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> {
return errorMapper.readValue(sse.getData(), cl);
});
}

/**
* Shuts down the OkHttp ExecutorService.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shut down after an idle timeout of 60s.
* Call this method to shut down the ExecutorService immediately.
*/
public void shutdownExecutor(){
this.executorService.shutdown();
}

public static OpenAiApi buildApi(OkHttpClient client) {
ObjectMapper mapper = defaultObjectMapper();
Retrofit retrofit = defaultRetrofit(client, mapper);

return retrofit.create(OpenAiApi.class);
}

@@ -250,6 +341,8 @@ public static ObjectMapper defaultObjectMapper() {
}

public static OkHttpClient defaultClient(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");

return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
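
Streams can stay open much longer than a unary request, so callers may want to tune timeouts on their own client. A sketch, not part of this diff, using the new OkHttpClient-based constructor; the five-minute read timeout is an arbitrary illustration:

    import com.theokanning.openai.service.OpenAiService;
    import okhttp3.OkHttpClient;
    import java.time.Duration;
    import java.util.concurrent.TimeUnit;

    OkHttpClient client = OpenAiService
            .defaultClient(System.getenv("OPENAI_TOKEN"), Duration.ZERO) // ZERO = no read timeout
            .newBuilder()
            .readTimeout(5, TimeUnit.MINUTES) // illustrative override
            .build();
    OpenAiService service = new OpenAiService(client); // reuses the client's dispatcher ExecutorService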