Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Streaming support (#195)
Browse files Browse the repository at this point in the history
Utilize retrofit2.http.Streaming and retrofit2.Call<ResponseBody>
in additional OpenAIApi methods to enable a streamable ResponseBody.

Utilize retrofit2.Callback to get the streamable ResponseBody,
parse Server Sent Events (SSE) and emit them using
io.reactivex.FlowableEmitter.

Enable:

- Streaming of raw bytes
- Streaming of Java objects
- Shutdown of OkHttp ExecutorService

Fixes: #51, #83, #182, #184
  • Loading branch information
n3bul4 committed Mar 28, 2023
1 parent a44d79b commit 7dc5b5b
Show file tree
Hide file tree
Showing 9 changed files with 403 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.theokanning.openai.completion;

import lombok.Data;
import java.util.List;

/**
 * Object containing a response chunk from the completions streaming api.
 * Chunks are deserialized from the "data:" payload of each server-sent event.
 *
 * https://beta.openai.com/docs/api-reference/completions/create
 */
@Data
public class CompletionChunk {
/**
 * A unique id assigned to this completion.
 */
String id;

/**
 * The type of object returned, should be "text_completion".
 * See https://beta.openai.com/docs/api-reference/completions/create
 */
String object;

/**
 * The creation time in epoch seconds.
 */
long created;

/**
 * The GPT-3 model used.
 */
String model;

/**
 * A list of generated completions.
 */
List<CompletionChoice> choices;
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package com.theokanning.openai.completion.chat;
import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

Expand All @@ -14,8 +15,9 @@ public class ChatCompletionChoice {
Integer index;

/**
* The {@link ChatMessageRole#assistant} message which was generated.
* The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated
*/
@JsonAlias("delta")
ChatMessage message;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.theokanning.openai.completion.chat;
import lombok.Data;

import java.util.List;

/**
 * Object containing a response chunk from the chat completions streaming api.
 * Each choice carries an incremental message delta (ChatCompletionChoice maps
 * the "delta" JSON field onto its message property via @JsonAlias).
 */
@Data
public class ChatCompletionChunk {
/**
 * Unique id assigned to this chat completion.
 */
String id;

/**
 * The type of object returned, should be "chat.completion.chunk"
 */
String object;

/**
 * The creation time in epoch seconds.
 */
long created;

/**
 * The GPT-3.5 model used.
 */
String model;

/**
 * A list of all generated completions.
 */
List<ChatCompletionChoice> choices;
}
10 changes: 10 additions & 0 deletions client/src/main/java/com/theokanning/openai/OpenAiApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.reactivex.Single;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.http.*;

public interface OpenAiApi {
Expand All @@ -34,10 +36,18 @@ public interface OpenAiApi {

@POST("/v1/completions")
Single<CompletionResult> createCompletion(@Body CompletionRequest request);

/**
 * Streaming variant of createCompletion: @Streaming prevents Retrofit from
 * buffering the body, so the raw ResponseBody can be consumed incrementally
 * as server-sent events.
 */
@Streaming
@POST("/v1/completions")
Call<ResponseBody> createCompletionStream(@Body CompletionRequest request);

@POST("/v1/chat/completions")
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);

/**
 * Streaming variant of createChatCompletion: @Streaming prevents Retrofit from
 * buffering the body, so the raw ResponseBody can be consumed incrementally
 * as server-sent events.
 */
@Streaming
@POST("/v1/chat/completions")
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request);

@Deprecated
@POST("/v1/engines/{engine_id}/completions")
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
Expand Down
79 changes: 79 additions & 0 deletions example/src/main/java/example/OpenAiApiStreamExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package example;

import com.theokanning.openai.service.OpenAiService;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;

/**
 * Example that demonstrates the streaming api:
 * raw bytes, CompletionChunks and ChatCompletionChunks,
 * followed by an immediate shutdown of the OkHttp ExecutorService.
 */
public class OpenAiApiStreamExample {
    public static void main(String... args) {
        String token = System.getenv("OPENAI_TOKEN");
        OpenAiService service = new OpenAiService(token);

        System.out.println("\nCreating completion...");
        CompletionRequest completionRequest = CompletionRequest.builder()
                .model("ada")
                .prompt("Somebody once told me the world is gonna roll me")
                .echo(true)
                .user("testing")
                .n(3)
                .build();

        /*
        Note: when using blockingForEach the calling Thread waits until the loop finishes.
        Use forEach instead of blockingForEach if you don't want the calling Thread to wait.
        */

        // stream raw bytes; decode explicitly as UTF-8 instead of the platform default charset
        service
                .streamCompletionBytes(completionRequest)
                .doOnError(e -> {
                    e.printStackTrace();
                })
                .blockingForEach(bytes -> {
                    System.out.print(new String(bytes, StandardCharsets.UTF_8));
                });

        // stream CompletionChunks
        service
                .streamCompletion(completionRequest)
                .doOnError(e -> {
                    e.printStackTrace();
                })
                .blockingForEach(System.out::println);


        final List<ChatMessage> messages = new ArrayList<>();
        final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such.");
        messages.add(systemMessage);

        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                .builder()
                .model("gpt-3.5-turbo")
                .messages(messages)
                .n(5)
                .maxTokens(50)
                .logitBias(new HashMap<>())
                .build();

        // stream ChatCompletionChunks
        service
                .streamChatCompletion(chatCompletionRequest)
                .doOnError(e -> {
                    e.printStackTrace();
                })
                .blockingForEach(System.out::println);

        /*
         * shutdown the OkHttp ExecutorService to
         * exit immediately after the loops have finished
         */
        service.shutdownExecutor();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionChunk;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.edit.EditRequest;
Expand All @@ -27,17 +29,22 @@
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;

import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.Single;
import okhttp3.*;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.Call;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

public class OpenAiService {
Expand All @@ -47,6 +54,7 @@ public class OpenAiService {
private static final ObjectMapper errorMapper = defaultObjectMapper();

private final OpenAiApi api;
private final ExecutorService executorService;

/**
* Creates a new OpenAiService that wraps OpenAiApi
Expand All @@ -64,17 +72,29 @@ public OpenAiService(final String token) {
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(buildApi(token, timeout));
this(defaultClient(token, timeout));
}

/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param client OkHttpClient to be used for api calls
*/
public OpenAiService(OkHttpClient client){
// reuse the client's dispatcher executor so shutdownExecutor() can stop the same pool
this(buildApi(client), client.dispatcher().executorService());
}

/**
 * Creates a new OpenAiService that wraps OpenAiApi.
 * The ExecutorService must be the one you get from the client you created the api with,
 * otherwise shutdownExecutor() won't work. Use this if you need more customization.
 *
 * @param api OpenAiApi instance to use for all methods
 * @param executorService the ExecutorService from client.dispatcher().executorService()
 */
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
this.api = api;
this.executorService = executorService;
}

public List<Model> listModels() {
Expand All @@ -88,11 +108,39 @@ public Model getModel(String modelId) {
/**
 * Synchronously creates a completion and returns the full result.
 *
 * @param request the completion request to send
 */
public CompletionResult createCompletion(CompletionRequest request) {
    Single<CompletionResult> apiCall = api.createCompletion(request);
    return execute(apiCall);
}

/**
 * Streams a completion request and emits the raw bytes of each server-sent
 * event, including the terminal [DONE] event.
 *
 * @param request the completion request; mutated here to set stream=true
 * @return a Flowable of the raw SSE bytes
 */
public Flowable<byte[]> streamCompletionBytes(CompletionRequest request) {
request.setStream(true);

// method reference instead of a redundant block lambda
return stream(api.createCompletionStream(request), true).map(SSE::toBytes);
}

/**
 * Streams a completion request, emitting each parsed CompletionChunk.
 *
 * @param request the completion request; mutated here to set stream=true
 */
public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
    request.setStream(true);
    Call<ResponseBody> apiCall = api.createCompletionStream(request);
    return stream(apiCall, CompletionChunk.class);
}

/**
 * Synchronously creates a chat completion and returns the full result.
 *
 * @param request the chat completion request to send
 */
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
    Single<ChatCompletionResult> apiCall = api.createChatCompletion(request);
    return execute(apiCall);
}

/**
 * Streams a chat completion request and emits the raw bytes of each
 * server-sent event, including the terminal [DONE] event.
 *
 * @param request the chat completion request; mutated here to set stream=true
 * @return a Flowable of the raw SSE bytes
 */
public Flowable<byte[]> streamChatCompletionBytes(ChatCompletionRequest request) {
request.setStream(true);

// method reference instead of a redundant block lambda
return stream(api.createChatCompletionStream(request), true).map(SSE::toBytes);
}

/**
 * Streams a chat completion request, emitting each parsed ChatCompletionChunk.
 *
 * @param request the chat completion request; mutated here to set stream=true
 */
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
    request.setStream(true);
    Call<ResponseBody> apiCall = api.createChatCompletionStream(request);
    return stream(apiCall, ChatCompletionChunk.class);
}

/**
 * Synchronously creates an edit and returns the full result.
 *
 * @param request the edit request to send
 */
public EditResult createEdit(EditRequest request) {
    Single<EditResult> apiCall = api.createEdit(request);
    return execute(apiCall);
}
Expand Down Expand Up @@ -232,12 +280,55 @@ public static <T> T execute(Single<T> apiCall) {
}
}

public static OpenAiApi buildApi(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");
/**
 * Calls the Open AI api and returns a Flowable of SSE for streaming,
 * omitting the terminal [DONE] message.
 *
 * @param apiCall The api call
 * @return a Flowable of parsed server-sent events
 */
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
return stream(apiCall, false);
}

/**
 * Calls the Open AI api and returns a Flowable of SSE for streaming.
 *
 * @param apiCall  the api call
 * @param emitDone if true the last message ([DONE]) is emitted as well
 */
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
    return Flowable.create(
            emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)),
            BackpressureStrategy.BUFFER);
}

/**
 * Calls the Open AI api and returns a Flowable of type T for streaming,
 * omitting the last message.
 *
 * @param apiCall the api call
 * @param cl      class of type T each event's data payload is deserialized into
 */
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
    return stream(apiCall).map(sse -> errorMapper.readValue(sse.getData(), cl));
}

/**
 * Shuts down the OkHttp ExecutorService.
 * The default behaviour of OkHttp's ExecutorService (ConnectionPool)
 * is to shutdown after an idle timeout of 60s.
 * Call this method to shutdown the ExecutorService immediately.
 * Only effective when the service was constructed with the ExecutorService
 * obtained from client.dispatcher().executorService().
 */
public void shutdownExecutor(){
this.executorService.shutdown();
}

/**
 * Builds an OpenAiApi Retrofit proxy backed by the given client.
 *
 * @param client the OkHttpClient used for all api calls
 */
public static OpenAiApi buildApi(OkHttpClient client) {
ObjectMapper mapper = defaultObjectMapper();
Retrofit retrofit = defaultRetrofit(client, mapper);

return retrofit.create(OpenAiApi.class);
}

Expand All @@ -250,6 +341,8 @@ public static ObjectMapper defaultObjectMapper() {
}

public static OkHttpClient defaultClient(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");

return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
Expand Down
Loading

0 comments on commit 7dc5b5b

Please sign in to comment.