Skip to content

GH-1378: Add parameter warnings and implement penalty options for Ver… #3033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
* @author Claudio Silva Junior
* @author Alexandros Pappas
* @author Jonghoon Park
* @author Soby Chacko
* @since 1.0.0
*/
public class AnthropicChatModel implements ChatModel {
Expand Down Expand Up @@ -424,6 +425,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
// Jackson, used by ModelOptionsUtils.
if (runtimeOptions != null) {
if (runtimeOptions.getFrequencyPenalty() != null) {
logger.warn("Frequency penalty option is ignored by the Anthropic API");
}
if (runtimeOptions.getPresencePenalty() != null) {
logger.warn("Presence penalty option is ignored by the Anthropic API");
}
requestOptions.setHttpHeaders(
mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
requestOptions.setInternalToolExecutionEnabled(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
* @author Wei Jiang
* @author Alexandros Pappas
* @author Jihoon Kim
* @author Soby Chacko
* @since 1.0.0
*/
public class BedrockProxyChatModel implements ChatModel {
Expand Down Expand Up @@ -279,19 +280,23 @@ Prompt buildRequestPrompt(Prompt prompt) {
updatedRuntimeOptions = this.defaultOptions.copy();
}
else {
if (runtimeOptions.getFrequencyPenalty() != null) {
logger.warn("The frequencyPenalty option is not supported by the BedrockProxyChatModel. Ignoring.");
}
if (runtimeOptions.getPresencePenalty() != null) {
logger.warn("The presencePenalty option is not supported by the BedrockProxyChatModel. Ignoring.");
}
if (runtimeOptions.getTopK() != null) {
logger.warn("The topK option is not supported by the BedrockProxyChatModel. Ignoring.");
}
updatedRuntimeOptions = ToolCallingChatOptions.builder()
.model(runtimeOptions.getModel() != null ? runtimeOptions.getModel() : this.defaultOptions.getModel())
.frequencyPenalty(runtimeOptions.getFrequencyPenalty() != null ? runtimeOptions.getFrequencyPenalty()
: this.defaultOptions.getFrequencyPenalty())
.maxTokens(runtimeOptions.getMaxTokens() != null ? runtimeOptions.getMaxTokens()
: this.defaultOptions.getMaxTokens())
.presencePenalty(runtimeOptions.getPresencePenalty() != null ? runtimeOptions.getPresencePenalty()
: this.defaultOptions.getPresencePenalty())
.stopSequences(runtimeOptions.getStopSequences() != null ? runtimeOptions.getStopSequences()
: this.defaultOptions.getStopSequences())
.temperature(runtimeOptions.getTemperature() != null ? runtimeOptions.getTemperature()
: this.defaultOptions.getTemperature())
.topK(runtimeOptions.getTopK() != null ? runtimeOptions.getTopK() : this.defaultOptions.getTopK())
.topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP())

.toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @author Soby Chacko
* @see ChatModel
* @see StreamingChatModel
* @see OpenAiApi
Expand Down Expand Up @@ -507,6 +508,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
// Jackson, used by ModelOptionsUtils.
if (runtimeOptions != null) {
if (runtimeOptions.getTopK() != null) {
logger.warn("topK is not supported for chat models in OpenAI");
}

requestOptions.setHttpHeaders(
mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
requestOptions.setInternalToolExecutionEnabled(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
if (options.getResponseMimeType() != null) {
generationConfigBuilder.setResponseMimeType(options.getResponseMimeType());
}
if (options.getFrequencyPenalty() != null) {
generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue());
}
if (options.getPresencePenalty() != null) {
generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue());
}

return generationConfigBuilder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
* @author Thomas Vitale
* @author Grogdunn
* @author Ilayaperumal Gopinathan
* @author Soby Chacko
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -95,6 +96,16 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("responseMimeType") String responseMimeType;

/**
* Optional. Frequency penalties.
*/
private @JsonProperty("frequencyPenalty") Double frequencyPenalty;

/**
* Optional. Presence penalties.
*/
private @JsonProperty("presencePenalty") Double presencePenalty;

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
Expand Down Expand Up @@ -138,6 +149,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
options.setTemperature(fromOptions.getTemperature());
options.setTopP(fromOptions.getTopP());
options.setTopK(fromOptions.getTopK());
options.setFrequencyPenalty(fromOptions.getFrequencyPenalty());
options.setPresencePenalty(fromOptions.getPresencePenalty());
options.setCandidateCount(fromOptions.getCandidateCount());
options.setMaxOutputTokens(fromOptions.getMaxOutputTokens());
options.setModel(fromOptions.getModel());
Expand Down Expand Up @@ -269,15 +282,21 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
}

@Override
@JsonIgnore
public Double getFrequencyPenalty() {
return null;
return this.frequencyPenalty;
}

@Override
@JsonIgnore
public Double getPresencePenalty() {
return null;
return this.presencePenalty;
}

/**
 * Stores the frequency penalty on these options.
 * @param frequencyPenalty the value to store; it is assigned as-is, with no validation
 */
public void setFrequencyPenalty(Double frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

/**
 * Stores the presence penalty on these options.
 * @param presencePenalty the value to store; it is assigned as-is, with no validation
 */
public void setPresencePenalty(Double presencePenalty) {
this.presencePenalty = presencePenalty;
}

public Boolean getGoogleSearchRetrieval() {
Expand Down Expand Up @@ -319,6 +338,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.stopSequences, that.stopSequences)
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
&& Objects.equals(this.presencePenalty, that.presencePenalty)
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
&& Objects.equals(this.responseMimeType, that.responseMimeType)
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
Expand All @@ -331,14 +352,16 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames,
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
this.internalToolExecutionEnabled, this.toolContext);
}

@Override
public String toString() {
return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature="
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount="
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty="
+ this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount="
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
Expand Down Expand Up @@ -380,6 +403,16 @@ public Builder topK(Integer topK) {
return this;
}

/**
 * Sets the frequency penalty on the options being built.
 * @param frequencyPenalty the value passed to
 * {@code setFrequencyPenalty} on the underlying options
 * @return this builder, for call chaining
 */
public Builder frequencyPenalty(Double frequencyPenalty) {
	this.options.setFrequencyPenalty(frequencyPenalty);
	return this;
}

/**
 * Misspelled variant of {@link #frequencyPenalty(Double)}, retained so that
 * existing callers keep compiling.
 * @param frequencyPenalty the value passed to
 * {@code setFrequencyPenalty} on the underlying options
 * @return this builder, for call chaining
 * @deprecated the method name is a typo; use {@link #frequencyPenalty(Double)}
 * instead
 */
@Deprecated
public Builder frequencePenalty(Double frequencyPenalty) {
	return frequencyPenalty(frequencyPenalty);
}

/**
 * Sets the presence penalty on the options being built.
 * @param presencePenalty the value passed to
 * {@code setPresencePenalty} on the underlying options
 * @return this builder, for call chaining
 */
public Builder presencePenalty(Double presencePenalty) {
this.options.setPresencePenalty(presencePenalty);
return this;
}

public Builder candidateCount(Integer candidateCount) {
this.options.setCandidateCount(candidateCount);
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -45,6 +45,7 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
@ExtendWith(MockitoExtension.class)
public class CreateGeminiRequestTests {
Expand Down Expand Up @@ -79,6 +80,27 @@ public void createRequestWithChatOptions() {
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(99.9f);
}

/**
 * Checks that frequency and presence penalties configured as default options end up
 * in the generation config of the created Gemini request.
 */
@Test
public void createRequestWithFrequencyAndPresencePenalty() {

	VertexAiGeminiChatOptions defaultOptions = VertexAiGeminiChatOptions.builder()
		.model("DEFAULT_MODEL")
		.frequencePenalty(.25)
		.presencePenalty(.75)
		.build();

	var client = VertexAiGeminiChatModel.builder().vertexAI(this.vertexAI).defaultOptions(defaultOptions).build();

	// The prompt is built with an options object on which nothing was set, so the
	// penalties asserted below can only come from the default options.
	Prompt prompt = new Prompt("Test message content", VertexAiGeminiChatOptions.builder().build());
	GeminiRequest request = client.createGeminiRequest(client.buildRequestPrompt(prompt));

	assertThat(request.contents()).hasSize(1);

	var generationConfig = request.model().getGenerationConfig();
	assertThat(generationConfig.getFrequencyPenalty()).isEqualTo(.25F);
	assertThat(generationConfig.getPresencePenalty()).isEqualTo(.75F);
}

@Test
public void createRequestWithSystemMessage() throws MalformedURLException {

Expand Down