Skip to content

Commit

Permalink
Streamline ChatOptions
Browse files Browse the repository at this point in the history
* Surface more configuration APIs to ChatOptions
* Use abstraction in Observations directly instead of dedicated implementation
* Simplify metadata config in observations for defined models
* Improve merging of runtime and default options in OpenAI
* Fix missing option in Mistral AI

Relates to spring-projectsgh-1148

Signed-off-by: Thomas Vitale <[email protected]>
  • Loading branch information
ThomasVitale committed Aug 10, 2024
1 parent af25430 commit bf84d59
Show file tree
Hide file tree
Showing 59 changed files with 1,031 additions and 694 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
* The options to be used when sending a chat request to the Anthropic API.
*
* @author Christian Tzolov
* @author Thomas Vitale
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -149,6 +150,7 @@ public AnthropicChatOptions build() {

}

@Override
public String getModel() {
return model;
}
Expand All @@ -157,6 +159,7 @@ public void setModel(String model) {
this.model = model;
}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}
Expand All @@ -173,6 +176,7 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
this.metadata = metadata;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}
Expand All @@ -199,6 +203,7 @@ public void setTopP(Float topP) {
this.topP = topP;
}

@Override
public Integer getTopK() {
return this.topK;
}
Expand Down Expand Up @@ -229,6 +234,18 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

@Override
@JsonIgnore
public Float getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
public Float getPresencePenalty() {
return null;
}

@Override
public AnthropicChatOptions copy() {
return fromOptions(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
* prompt data.
*
* @author Christian Tzolov
* @author Thomas Vitale
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
Expand Down Expand Up @@ -108,7 +109,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
* output new topics.
*/
@JsonProperty(value = "presence_penalty")
private Double presencePenalty;
private Float presencePenalty;

/**
* A value that influences the probability of generated tokens appearing based on
Expand All @@ -117,7 +118,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
* model repeating the same statements verbatim.
*/
@JsonProperty(value = "frequency_penalty")
private Double frequencyPenalty;
private Float frequencyPenalty;

/**
* The deployment name as defined in Azure Open AI Studio when creating a deployment
Expand Down Expand Up @@ -182,9 +183,7 @@ public Builder withDeploymentName(String deploymentName) {
}

public Builder withFrequencyPenalty(Float frequencyPenalty) {
if (frequencyPenalty != null) {
this.options.frequencyPenalty = frequencyPenalty.doubleValue();
}
this.options.frequencyPenalty = frequencyPenalty;
return this;
}

Expand All @@ -204,9 +203,7 @@ public Builder withN(Integer n) {
}

public Builder withPresencePenalty(Float presencePenalty) {
if (presencePenalty != null) {
this.options.presencePenalty = presencePenalty.doubleValue();
}
this.options.presencePenalty = presencePenalty;
return this;
}

Expand Down Expand Up @@ -259,6 +256,7 @@ public AzureOpenAiChatOptions build() {

}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}
Expand Down Expand Up @@ -291,6 +289,17 @@ public void setN(Integer n) {
this.n = n;
}

@Override
@JsonIgnore
public List<String> getStopSequences() {
return getStop();
}

@JsonIgnore
public void setStopSequences(List<String> stopSequences) {
setStop(stopSequences);
}

public List<String> getStop() {
return this.stop;
}
Expand All @@ -299,22 +308,35 @@ public void setStop(List<String> stop) {
this.stop = stop;
}

public Double getPresencePenalty() {
@Override
public Float getPresencePenalty() {
return this.presencePenalty;
}

public void setPresencePenalty(Double presencePenalty) {
public void setPresencePenalty(Float presencePenalty) {
this.presencePenalty = presencePenalty;
}

public Double getFrequencyPenalty() {
@Override
public Float getFrequencyPenalty() {
return this.frequencyPenalty;
}

public void setFrequencyPenalty(Double frequencyPenalty) {
public void setFrequencyPenalty(Float frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

@Override
@JsonIgnore
public String getModel() {
return getDeploymentName();
}

@JsonIgnore
public void setModel(String model) {
setDeploymentName(model);
}

public String getDeploymentName() {
return this.deploymentName;
}
Expand All @@ -341,17 +363,6 @@ public void setTopP(Float topP) {
this.topP = topP;
}

@Override
@JsonIgnore
public Integer getTopK() {
throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
}

@JsonIgnore
public void setTopK(Integer topK) {
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
Expand All @@ -378,20 +389,24 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

@Override
@JsonIgnore
public Integer getTopK() {
return null;
}

@Override
public AzureOpenAiChatOptions copy() {
return fromOptions(this);
}

public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) {
return builder().withDeploymentName(fromOptions.getDeploymentName())
.withFrequencyPenalty(
fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty().floatValue() : null)
.withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null)
.withLogitBias(fromOptions.getLogitBias())
.withMaxTokens(fromOptions.getMaxTokens())
.withN(fromOptions.getN())
.withPresencePenalty(
fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty().floatValue() : null)
.withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null)
.withStop(fromOptions.getStop())
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import org.springframework.ai.embedding.EmbeddingOptions;

/**
Expand Down Expand Up @@ -125,10 +126,16 @@ public AzureOpenAiEmbeddingOptions build() {
}

@Override
@JsonIgnore
public String getModel() {
return getDeploymentName();
}

@JsonIgnore
public void setModel(String model) {
setDeploymentName(model);
}

public String getUser() {
return this.user;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;

Expand All @@ -26,6 +27,7 @@

/**
* @author Christian Tzolov
* @author Thomas Vitale
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements ChatOptions {
Expand Down Expand Up @@ -122,6 +124,17 @@ public void setTemperature(Float temperature) {
this.temperature = temperature;
}

@Override
@JsonIgnore
public Integer getMaxTokens() {
return getMaxTokensToSample();
}

@JsonIgnore
public void setMaxTokens(Integer maxTokens) {
setMaxTokensToSample(maxTokens);
}

public Integer getMaxTokensToSample() {
return this.maxTokensToSample;
}
Expand All @@ -148,6 +161,7 @@ public void setTopP(Float topP) {
this.topP = topP;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}
Expand All @@ -164,6 +178,24 @@ public void setAnthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
}

@Override
@JsonIgnore
public String getModel() {
return null;
}

@Override
@JsonIgnore
public Float getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
public Float getPresencePenalty() {
return null;
}

@Override
public AnthropicChatOptions copy() {
return fromOptions(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.bedrock.anthropic3;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -24,6 +25,7 @@

/**
* @author Ben Middleton
* @author Thomas Vitale
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -121,6 +123,7 @@ public void setTemperature(Float temperature) {
this.temperature = temperature;
}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}
Expand All @@ -147,6 +150,7 @@ public void setTopP(Float topP) {
this.topP = topP;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}
Expand All @@ -163,6 +167,24 @@ public void setAnthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
}

@Override
@JsonIgnore
public String getModel() {
return null;
}

@Override
@JsonIgnore
public Float getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
public Float getPresencePenalty() {
return null;
}

@Override
public Anthropic3ChatOptions copy() {
return fromOptions(this);
Expand Down
Loading

0 comments on commit bf84d59

Please sign in to comment.