Skip to content

Commit

Permalink
Add streaming tests (TheoKanning#225)
Browse files Browse the repository at this point in the history
Added tests, formatted code, and restored public method signatures.
  • Loading branch information
TheoKanning authored Apr 1, 2023
1 parent 7dc5b5b commit 4d5e496
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 225 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ OpenAiApi api = retrofit.create(OpenAiApi.class);
OpenAiService service = new OpenAiService(api);
```


### Streaming thread shutdown
If you want to shut down your process immediately after streaming responses, call `OpenAiService.shutdown()`.
This is not necessary for non-streaming calls.

## Running the example project
All the [example](example/src/main/java/example/OpenAiApiExample.java) project requires is your OpenAI api token
Expand Down
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ apply plugin: "com.vanniktech.maven.publish"
dependencies {
api project(":api")
api 'com.squareup.retrofit2:retrofit:2.9.0'
implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
api 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'

testImplementation(platform('org.junit:junit-bom:5.8.2'))
Expand Down
26 changes: 26 additions & 0 deletions example/src/main/java/example/OpenAiApiExample.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package example;

import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.image.CreateImageRequest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

class OpenAiApiExample {
public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
Expand All @@ -26,5 +33,24 @@ public static void main(String... args) {

System.out.println("\nImage is located at:");
System.out.println(service.createImage(request).getData().get(0).getUrl());

System.out.println("Streaming chat completion...");
final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such.");
messages.add(systemMessage);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo")
.messages(messages)
.n(1)
.maxTokens(50)
.logitBias(new HashMap<>())
.build();

service.streamChatCompletion(chatCompletionRequest)
.doOnError(Throwable::printStackTrace)
.blockingForEach(System.out::println);

service.shutdownExecutor();
}
}
79 changes: 0 additions & 79 deletions example/src/main/java/example/OpenAiApiStreamExample.java

This file was deleted.

3 changes: 2 additions & 1 deletion service/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ dependencies {
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'

testImplementation(platform('org.junit:junit-bom:5.8.2'))
testImplementation('org.junit.jupiter:junit-jupiter')
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'com.squareup.retrofit2:retrofit-mock:2.9.0'
}

compileJava {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import okhttp3.Response;

import java.io.IOException;
import java.util.Objects;

/**
* OkHttp Interceptor that adds an authorization token header
Expand All @@ -14,6 +15,7 @@ public class AuthenticationInterceptor implements Interceptor {
private final String token;

AuthenticationInterceptor(String token) {
Objects.requireNonNull(token, "OpenAI token required");
this.token = token;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class OpenAiService {

private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
private static final ObjectMapper errorMapper = defaultObjectMapper();
private static final ObjectMapper mapper = defaultObjectMapper();

private final OpenAiApi api;
private final ExecutorService executorService;
Expand All @@ -72,24 +72,34 @@ public OpenAiService(final String token) {
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(defaultClient(token, timeout));
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);

this.api = retrofit.create(OpenAiApi.class);
this.executorService = client.dispatcher().executorService();
}

/**
* Creates a new OpenAiService that wraps OpenAiApi
* Creates a new OpenAiService that wraps OpenAiApi.
* Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and
* want to shut down instantly
*
* @param client OkHttpClient to be used for api calls
* @param api OpenAiApi instance to use for all methods
*/
public OpenAiService(OkHttpClient client){
this(buildApi(client), client.dispatcher().executorService());
public OpenAiService(final OpenAiApi api) {
this.api = api;
this.executorService = null;
}

/**
* Creates a new OpenAiService that wraps OpenAiApi.
* The ExecutoryService must be the one you get from the client you created the api with
* otherwise shutdownExecutor() won't work. Use this if you need more customization.
* The ExecutorService must be the one you get from the client you created the api with
* otherwise shutdownExecutor() won't work.
* <p>
* Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
* @param api OpenAiApi instance to use for all methods
* @param executorService the ExecutorService from client.dispatcher().executorService()
*/
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
Expand All @@ -109,37 +119,21 @@ public CompletionResult createCompletion(CompletionRequest request) {
return execute(api.createCompletion(request));
}

public Flowable<byte[]> streamCompletionBytes(CompletionRequest request) {
request.setStream(true);

return stream(api.createCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
request.setStream(true);
return stream(api.createCompletionStream(request), CompletionChunk.class);
}
request.setStream(true);

return stream(api.createCompletionStream(request), CompletionChunk.class);
}

public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}

public Flowable<byte[]> streamChatCompletionBytes(ChatCompletionRequest request) {
request.setStream(true);
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}

public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request));
Expand Down Expand Up @@ -271,7 +265,7 @@ public static <T> T execute(Single<T> apiCall) {
}
String errorBody = e.response().errorBody().string();

OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
OpenAiError error = mapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) {
// couldn't parse OpenAI error
Expand All @@ -283,52 +277,50 @@ public static <T> T execute(Single<T> apiCall) {
/**
* Calls the Open AI api and returns a Flowable of SSE for streaming
* omitting the last message.
*
*
* @param apiCall The api call
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
return stream(apiCall, false);
}
return stream(apiCall, false);
}

/**
* Calls the Open AI api and returns a Flowable of SSE for streaming.
*
* @param apiCall The api call
*
* @param apiCall The api call
* @param emitDone If true the last message ([DONE]) is emitted
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> {
apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone));
}, BackpressureStrategy.BUFFER);
}
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER);
}

/**
* Calls the Open AI api and returns a Flowable of type T for streaming
* omitting the last message.
*
*
* @param apiCall The api call
* @param cl Class of type T to return
* @param cl Class of type T to return
*/
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> {
return errorMapper.readValue(sse.getData(), cl);
});
}
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
}

/**
* Shuts down the OkHttp ExecutorService.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shutdown after an idle timeout of 60s.
* Call this method to shutdown the ExecutorService immediately.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shut down after an idle timeout of 60s.
* Call this method to shut down the ExecutorService immediately.
*/
public void shutdownExecutor(){
public void shutdownExecutor() {
Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down");
this.executorService.shutdown();
}

public static OpenAiApi buildApi(OkHttpClient client) {
public static OpenAiApi buildApi(String token, Duration timeout) {
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);

return retrofit.create(OpenAiApi.class);
}

Expand All @@ -341,8 +333,6 @@ public static ObjectMapper defaultObjectMapper() {
}

public static OkHttpClient defaultClient(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");

return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
Expand Down
Loading

0 comments on commit 4d5e496

Please sign in to comment.