Skip to content

Feature: Added Structured Output on Azure Open AI #2431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
Expand Down Expand Up @@ -58,6 +60,8 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -115,6 +119,7 @@
* @author Jihoon Kim
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @author Bart Veenstra
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
* @since 1.0.0
Expand Down Expand Up @@ -278,7 +283,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -334,7 +338,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down Expand Up @@ -940,9 +943,16 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
* @return Azure response format
*/
private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseFormat responseFormat) {
if (responseFormat == AzureOpenAiResponseFormat.JSON) {
if (responseFormat.getType() == Type.JSON_OBJECT) {
return new ChatCompletionsJsonResponseFormat();
}
if (responseFormat.getType() == Type.JSON_SCHEMA) {
JsonSchema jsonSchema = responseFormat.getJsonSchema();
var responseFormatJsonSchema = new ChatCompletionsJsonSchemaResponseFormatJsonSchema(jsonSchema.getName());
String jsonString = ModelOptionsUtils.toJsonString(jsonSchema.getSchema());
responseFormatJsonSchema.setSchema(BinaryData.fromString(jsonString));
return new ChatCompletionsJsonSchemaResponseFormat(responseFormatJsonSchema);
}
return new ChatCompletionsTextResponseFormat();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,255 @@

package org.springframework.ai.azure.openai;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Map;
import java.util.Objects;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.StringUtils;

/**
* Utility enumeration for representing the response format that may be requested from the
* Azure OpenAI model. Please check <a href=
* "https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format">OpenAI
* API documentation</a> for more details.
*/
public enum AzureOpenAiResponseFormat {

// default value used by OpenAI
TEXT,
/*
* From the OpenAI API documentation: Compatability: Compatible with GPT-4 Turbo and
* all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Caveats: This enables JSON
* mode, which guarantees the message the model generates is valid JSON. Important:
* when using JSON mode, you must also instruct the model to produce JSON yourself via
* a system or user message. Without this, the model may generate an unending stream
* of whitespace until the generation reaches the token limit, resulting in a
* long-running and seemingly "stuck" request. Also note that the message content may
* be partially cut off if finish_reason="length", which indicates the generation
* exceeded max_tokens or the conversation exceeded the max context length.
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiResponseFormat {

/**
* Type Must be one of 'text', 'json_object' or 'json_schema'.
*/
@JsonProperty("type")
private Type type;

/**
* JSON schema object that describes the format of the JSON object. Only applicable
* when type is 'json_schema'.
*/
JSON
@JsonProperty("json_schema")
private JsonSchema jsonSchema = null;

private String schema;

public AzureOpenAiResponseFormat() {

}

public Type getType() {
return this.type;
}

public void setType(Type type) {
this.type = type;
}

public JsonSchema getJsonSchema() {
return this.jsonSchema;
}

public void setJsonSchema(JsonSchema jsonSchema) {
this.jsonSchema = jsonSchema;
}

public String getSchema() {
return this.schema;
}

public void setSchema(String schema) {
this.schema = schema;
if (schema != null) {
this.jsonSchema = JsonSchema.builder().schema(schema).strict(true).build();
}
}

private AzureOpenAiResponseFormat(Type type, JsonSchema jsonSchema) {
this.type = type;
this.jsonSchema = jsonSchema;
}

public AzureOpenAiResponseFormat(Type type, String schema) {
this(type, StringUtils.hasText(schema) ? JsonSchema.builder().schema(schema).strict(true).build() : null);
}

public static Builder builder() {
return new Builder();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AzureOpenAiResponseFormat that = (AzureOpenAiResponseFormat) o;
return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema);
}

@Override
public int hashCode() {
return Objects.hash(this.type, this.jsonSchema);
}

@Override
public String toString() {
return "ResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}';
}

public static final class Builder {

private Type type;

private JsonSchema jsonSchema;

private Builder() {
}

public Builder type(Type type) {
this.type = type;
return this;
}

public Builder jsonSchema(JsonSchema jsonSchema) {
this.jsonSchema = jsonSchema;
return this;
}

public Builder jsonSchema(String jsonSchema) {
this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build();
return this;
}

public AzureOpenAiResponseFormat build() {
return new AzureOpenAiResponseFormat(this.type, this.jsonSchema);
}

}

public enum Type {

/**
* Generates a text response. (default)
*/
@JsonProperty("text")
TEXT,

/**
* Enables JSON mode, which guarantees the message the model generates is valid
* JSON.
*/
@JsonProperty("json_object")
JSON_OBJECT,

/**
* Enables Structured Outputs which guarantees the model will match your supplied
* JSON schema.
*/
@JsonProperty("json_schema")
JSON_SCHEMA

}

/**
* JSON schema object that describes the format of the JSON object. Applicable for the
* 'json_schema' type only.
*/
@JsonInclude(Include.NON_NULL)
public static class JsonSchema {

@JsonProperty("name")
private String name;

@JsonProperty("schema")
private Map<String, Object> schema;

@JsonProperty("strict")
private Boolean strict;

public JsonSchema() {

}

public String getName() {
return this.name;
}

public Map<String, Object> getSchema() {
return this.schema;
}

public Boolean getStrict() {
return this.strict;
}

private JsonSchema(String name, Map<String, Object> schema, Boolean strict) {
this.name = name;
this.schema = schema;
this.strict = strict;
}

public static Builder builder() {
return new Builder();
}

@Override
public int hashCode() {
return Objects.hash(this.name, this.schema, this.strict);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
JsonSchema that = (JsonSchema) o;
return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema)
&& Objects.equals(this.strict, that.strict);
}

public static final class Builder {

private String name = "custom_schema";

private Map<String, Object> schema;

private Boolean strict = true;

private Builder() {
}

public Builder name(String name) {
this.name = name;
return this;
}

public Builder schema(Map<String, Object> schema) {
this.schema = schema;
return this;
}

public Builder schema(String schema) {
this.schema = ModelOptionsUtils.jsonToMap(schema);
return this;
}

public Builder strict(Boolean strict) {
this.strict = strict;
return this;
}

public JsonSchema build() {
return new JsonSchema(this.name, this.schema, this.strict);
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
import org.springframework.ai.chat.prompt.Prompt;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -68,7 +69,7 @@ public void createRequestWithChatOptions() {
.logprobs(true)
.topLogprobs(5)
.enhancements(mockAzureChatEnhancementConfiguration)
.responseFormat(AzureOpenAiResponseFormat.TEXT)
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.TEXT).build())
.build();

var client = AzureOpenAiChatModel.builder()
Expand Down Expand Up @@ -114,7 +115,7 @@ public void createRequestWithChatOptions() {
.logprobs(true)
.topLogprobs(4)
.enhancements(anotherMockAzureChatEnhancementConfiguration)
.responseFormat(AzureOpenAiResponseFormat.JSON)
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.JSON_OBJECT).build())
.build();

requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", runtimeOptions));
Expand Down