Skip to content

Commit 5f9ecdd

Browse files
ricken07
authored and tzolov committed
Fixing Log probability information
1 parent 7e03a15 commit 5f9ecdd

File tree

4 files changed

+62
-8
lines changed

4 files changed

+62
-8
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
177177
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
178178
List<Choice> choices = chunk.choices()
179179
.stream()
180-
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason()))
180+
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
181181
.toList();
182182

183183
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,17 +575,65 @@ public record ChatCompletion(
575575
* @param index The index of the choice in the list of choices.
576576
* @param message A chat completion message generated by the model.
577577
* @param finishReason The reason the model stopped generating tokens.
578+
* @param logprobs Log probability information for the choice.
578579
*/
579580
@JsonInclude(Include.NON_NULL)
580581
public record Choice(
581582
// @formatter:off
582583
@JsonProperty("index") Integer index,
583584
@JsonProperty("message") ChatCompletionMessage message,
584-
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
585+
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
586+
@JsonProperty("logprobs") LogProbs logprobs) {
585587
// @formatter:on
586588
}
587589
}
588590

591+
/**
592+
*
593+
* Log probability information for the choice, added in anticipation of future changes.
594+
*
595+
* @param content A list of message content tokens with log probability information.
596+
*/
597+
@JsonInclude(Include.NON_NULL)
598+
public record LogProbs(@JsonProperty("content") List<Content> content) {
599+
600+
/**
601+
* Message content tokens with log probability information.
602+
*
603+
* @param token The token.
604+
* @param logprob The log probability of the token.
605+
* @param probBytes A list of integers representing the UTF-8 bytes representation
606+
* of the token. Useful in instances where characters are represented by multiple
607+
* tokens and their byte representations must be combined to generate the correct
608+
* text representation. Can be null if there is no bytes representation for the
609+
* token.
610+
* @param topLogprobs List of the most likely tokens and their log probability, at
611+
* this token position. In rare cases, there may be fewer than the number of
612+
* requested top_logprobs returned.
613+
*/
614+
@JsonInclude(Include.NON_NULL)
615+
public record Content(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
616+
@JsonProperty("bytes") List<Integer> probBytes,
617+
@JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) {
618+
619+
/**
620+
* The most likely tokens and their log probability, at this token position.
621+
*
622+
* @param token The token.
623+
* @param logprob The log probability of the token.
624+
* @param probBytes A list of integers representing the UTF-8 bytes
625+
* representation of the token. Useful in instances where characters are
626+
* represented by multiple tokens and their byte representations must be
627+
* combined to generate the correct text representation. Can be null if there
628+
* is no bytes representation for the token.
629+
*/
630+
@JsonInclude(Include.NON_NULL)
631+
public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
632+
@JsonProperty("bytes") List<Integer> probBytes) {
633+
}
634+
}
635+
}
636+
589637
/**
590638
* Represents a streamed chunk of a chat completion response returned by model, based
591639
* on the provided input.
@@ -614,13 +662,15 @@ public record ChatCompletionChunk(
614662
* @param index The index of the choice in the list of choices.
615663
* @param delta A chat completion delta generated by streamed model responses.
616664
* @param finishReason The reason the model stopped generating tokens.
665+
* @param logprobs Log probability information for the choice.
617666
*/
618667
@JsonInclude(Include.NON_NULL)
619668
public record ChunkChoice(
620669
// @formatter:off
621670
@JsonProperty("index") Integer index,
622671
@JsonProperty("delta") ChatCompletionMessage delta,
623-
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
672+
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
673+
@JsonProperty("logprobs") LogProbs logprobs) {
624674
// @formatter:on
625675
}
626676
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
2828
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;
2929
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
30+
import org.springframework.ai.mistralai.api.MistralAiApi.LogProbs;
3031
import org.springframework.util.CollectionUtils;
3132

3233
/**
@@ -83,8 +84,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
8384
.toList();
8485

8586
var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
86-
current = new ChunkChoice(current.index(), new ChatCompletionMessage(current.delta().content(),
87-
role, current.delta().name(), toolCallsWithID), current.finishReason());
87+
current = new ChunkChoice(
88+
current.index(), new ChatCompletionMessage(current.delta().content(), role,
89+
current.delta().name(), toolCallsWithID),
90+
current.finishReason(), current.logprobs());
8891
}
8992
}
9093
return current;
@@ -95,8 +98,9 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
9598
Integer index = (current.index() != null ? current.index() : previous.index());
9699

97100
ChatCompletionMessage message = merge(previous.delta(), current.delta());
101+
LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());
98102

99-
return new ChunkChoice(index, message, finishReason);
103+
return new ChunkChoice(index, message, finishReason, logprobs);
100104
}
101105

102106
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public void beforeEach() {
109109
public void mistralAiChatTransientError() {
110110

111111
var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
112-
ChatCompletionFinishReason.STOP);
112+
ChatCompletionFinishReason.STOP, null);
113113
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model",
114114
List.of(choice), new MistralAiApi.Usage(10, 10, 10));
115115

@@ -137,7 +137,7 @@ public void mistralAiChatNonTransientError() {
137137
public void mistralAiChatStreamTransientError() {
138138

139139
var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
140-
ChatCompletionFinishReason.STOP);
140+
ChatCompletionFinishReason.STOP, null);
141141
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l,
142142
"model", List.of(choice));
143143

0 commit comments

Comments
 (0)