Skip to content

Commit

Permalink
[OpenAI] Rename to getImages() and LRO NPE work around solution (Azur…
Browse files Browse the repository at this point in the history
  • Loading branch information
mssfang authored Jul 14, 2023
1 parent bb947df commit 88b51b9
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@
@ServiceClient(builder = OpenAIClientBuilder.class, isAsync = true)
public final class OpenAIAsyncClient {

@Generated private final OpenAIClientImpl serviceClient;

private final NonAzureOpenAIClientImpl openAIServiceClient;

/**
* Initializes an instance of OpenAIAsyncClient class by using "Azure" OpenAI service implementation. Azure OpenAI
* and Non-Azure OpenAI Service implementations are mutually exclusive. Both service client implementation cannot
* coexist because `OpenAIClient` operates either way in a mutually exclusive way.
*
* @param serviceClient the service client implementation for Azure OpenAI Service client.
*/
OpenAIAsyncClient(OpenAIClientImpl serviceClient) {
this.serviceClient = serviceClient;
openAIServiceClient = null;
}

/**
* Initializes an instance of OpenAIAsyncClient class by using "Non-Azure" OpenAI service implementation. Azure
* OpenAI and Non-Azure OpenAI Service implementations are mutually exclusive. Both service client implementation
* cannot coexist because `OpenAIClient` operates either way in a mutually exclusive way.
*
* @param serviceClient the service client implementation for Non-Azure OpenAI Service client.
*/
OpenAIAsyncClient(NonAzureOpenAIClientImpl serviceClient) {
this.serviceClient = null;
openAIServiceClient = serviceClient;
}

/**
* Return the embeddings for a given prompt.
*
Expand Down Expand Up @@ -438,34 +466,6 @@ public Flux<ChatCompletions> getChatCompletionsStream(
return chatCompletionsStream.getEvents();
}

@Generated private final OpenAIClientImpl serviceClient;

private final NonAzureOpenAIClientImpl openAIServiceClient;

/**
* Initializes an instance of OpenAIAsyncClient class by using "Azure" OpenAI service implementation. Azure OpenAI
* and Non-Azure OpenAI Service implementations are mutually exclusive. Both service client implementation cannot
* coexist because `OpenAIClient` operates either way in a mutually exclusive way.
*
* @param serviceClient the service client implementation for Azure OpenAI Service client.
*/
OpenAIAsyncClient(OpenAIClientImpl serviceClient) {
this.serviceClient = serviceClient;
openAIServiceClient = null;
}

/**
* Initializes an instance of OpenAIAsyncClient class by using "Non-Azure" OpenAI service implementation. Azure
* OpenAI and Non-Azure OpenAI Service implementations are mutually exclusive. Both service client implementation
* cannot coexist because `OpenAIClient` operates either way in a mutually exclusive way.
*
* @param serviceClient the service client implementation for Non-Azure OpenAI Service client.
*/
OpenAIAsyncClient(NonAzureOpenAIClientImpl serviceClient) {
this.serviceClient = null;
openAIServiceClient = serviceClient;
}

/**
* Starts the generation of a batch of images from a text caption.
*
Expand All @@ -479,7 +479,7 @@ public Flux<ChatCompletions> getChatCompletionsStream(
* @return the {@link Mono} with the image generation result
*/
@ServiceMethod(returns = ReturnType.SINGLE)
public Mono<ImageResponse> generateImage(ImageGenerationOptions imageGenerationOptions) {
public Mono<ImageResponse> getImages(ImageGenerationOptions imageGenerationOptions) {
RequestOptions requestOptions = new RequestOptions();
BinaryData imageGenerationOptionsBinaryData = BinaryData.fromObject(imageGenerationOptions);
return openAIServiceClient != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,43 @@
import com.azure.core.http.rest.Response;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import com.azure.core.util.logging.ClientLogger;
import com.azure.core.util.polling.SyncPoller;
import java.nio.ByteBuffer;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;

/** Initializes a new instance of the synchronous OpenAIClient type. */
@ServiceClient(builder = OpenAIClientBuilder.class)
public final class OpenAIClient {
private static final ClientLogger LOGGER = new ClientLogger(OpenAIClient.class);

@Generated private final OpenAIClientImpl serviceClient;

private final NonAzureOpenAIClientImpl openAIServiceClient;

/**
* Initializes an instance of OpenAIClient class.
*
* @param serviceClient the service client implementation.
*/
OpenAIClient(OpenAIClientImpl serviceClient) {
this.serviceClient = serviceClient;
openAIServiceClient = null;
}

/**
* Initializes an instance of OpenAIClient class for NonAzure Implementation.
*
* @param serviceClient the service client implementation.
*/
OpenAIClient(NonAzureOpenAIClientImpl serviceClient) {
this.serviceClient = null;
openAIServiceClient = serviceClient;
}

/**
* Return the embeddings for a given prompt.
Expand Down Expand Up @@ -435,30 +465,6 @@ public IterableStream<ChatCompletions> getChatCompletionsStream(
return new IterableStream<ChatCompletions>(chatCompletionsStream.getEvents());
}

@Generated private final OpenAIClientImpl serviceClient;

private final NonAzureOpenAIClientImpl openAIServiceClient;

/**
* Initializes an instance of OpenAIClient class.
*
* @param serviceClient the service client implementation.
*/
OpenAIClient(OpenAIClientImpl serviceClient) {
this.serviceClient = serviceClient;
openAIServiceClient = null;
}

/**
* Initializes an instance of OpenAIClient class for NonAzure Implementation.
*
* @param serviceClient the service client implementation.
*/
OpenAIClient(NonAzureOpenAIClientImpl serviceClient) {
this.serviceClient = null;
openAIServiceClient = serviceClient;
}

/**
* Starts the generation of a batch of images from a text caption.
*
Expand All @@ -469,22 +475,39 @@ public IterableStream<ChatCompletions> getChatCompletionsStream(
* @throws ResourceNotFoundException thrown if the request is rejected by server on status code 404.
* @throws ResourceModifiedException thrown if the request is rejected by server on status code 409.
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return the {@link ImageOperationResponse} for polling of status details for long running operations.
* @return the {@link ImageResponse} for the image generation result.
*/
@ServiceMethod(returns = ReturnType.SINGLE)
public ImageResponse generateImage(ImageGenerationOptions imageGenerationOptions) {
public ImageResponse getImages(ImageGenerationOptions imageGenerationOptions) {
RequestOptions requestOptions = new RequestOptions();
BinaryData imageGenerationOptionsBinaryData = BinaryData.fromObject(imageGenerationOptions);
return openAIServiceClient != null
? openAIServiceClient
.generateImageWithResponse(imageGenerationOptionsBinaryData, requestOptions)
.getValue()
.toObject(ImageResponse.class)
: beginBeginAzureBatchImageGeneration(imageGenerationOptionsBinaryData, requestOptions)
.waitForCompletion()
.getValue()
.toObject(ImageOperationResponse.class)
.getResult();

if (openAIServiceClient != null) {
return openAIServiceClient
.generateImageWithResponse(imageGenerationOptionsBinaryData, requestOptions)
.getValue()
.toObject(ImageResponse.class);
} else {
// TODO: Currently, we use async client block() to avoid a unknown LRO status "notRunning" which Azure Core will
// fix the issue in August release and we will reuse the method
// "SyncPoller<BinaryData, BinaryData> beginBeginAzureBatchImageGeneration()" after.
try {
return this.serviceClient.beginBeginAzureBatchImageGenerationAsync(imageGenerationOptionsBinaryData,
requestOptions)
.last()
.flatMap(it -> it.getFinalResult())
.map(it -> it.toObject(ImageOperationResponse.class).getResult()).block();
} catch (Exception e) {
Throwable unwrapped = Exceptions.unwrap(e);
if (unwrapped instanceof RuntimeException) {
throw LOGGER.logExceptionAsError((RuntimeException) unwrapped);
} else if (unwrapped instanceof IOException) {
throw LOGGER.logExceptionAsError(new UncheckedIOException((IOException) unwrapped));
} else {
throw LOGGER.logExceptionAsError(new RuntimeException(unwrapped));
}
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,16 @@ public void testGetEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVe
});
}


@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);
getImageGenerationRunner(options ->
StepVerifier.create(client.generateImage(options))
StepVerifier.create(client.getImages(options))
.assertNext(OpenAIClientTestBase::assertImageResponse)
.verifyComplete());
}


@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
public void testChatFunctionAutoPreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void testGetEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVe
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);
getImageGenerationRunner(options -> assertImageResponse(client.generateImage(options)));
getImageGenerationRunner(options -> assertImageResponse(client.getImages(options)));
}

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ public void testGetEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVe
public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);
getImageGenerationRunner(options ->
StepVerifier.create(client.generateImage(options))
StepVerifier.create(client.getImages(options))
.assertNext(OpenAIClientTestBase::assertImageResponse)
.verifyComplete());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
package com.azure.ai.openai;

import com.azure.ai.openai.functions.Parameters;
import com.azure.ai.openai.models.ContentFilterResults;
import com.azure.ai.openai.models.ContentFilterSeverity;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.implementation.models.FunctionDefinition;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.ChatRole;
import com.azure.ai.openai.models.Choice;
import com.azure.ai.openai.models.Completions;
import com.azure.ai.openai.models.ContentFilterResults;
import com.azure.ai.openai.models.ContentFilterSeverity;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.ai.openai.implementation.models.FunctionDefinition;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.ImageGenerationOptions;
import com.azure.ai.openai.models.ImageResponse;
import com.azure.ai.openai.models.NonAzureOpenAIKeyCredential;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public void testGetEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVe
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIClient(httpClient, serviceVersion);
getImageGenerationRunner(options -> assertImageResponse(client.generateImage(options)));
getImageGenerationRunner(options -> assertImageResponse(client.getImages(options)));
}

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
Expand Down

0 comments on commit 88b51b9

Please sign in to comment.