Skip to content

Commit c038526

Browse files
committed
refactor(openai): consolidate token usage details and add audio tokens support
The commit restructures OpenAI token usage tracking by:
- Adding audio_tokens support in PromptTokensDetails
- Deprecating individual token getter methods in favor of consolidated records
- Introducing new PromptTokensDetails and CompletionTokenDetails records
- Updating tests to reflect the new structure

Resolves #1369, #1720
1 parent fb2e752 commit c038526

File tree

3 files changed

+117
-63
lines changed

3 files changed

+117
-63
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1145,11 +1145,11 @@ public record TopLogProbs(// @formatter:off
11451145
*/
11461146
@JsonInclude(Include.NON_NULL)
11471147
public record Usage(// @formatter:off
1148-
@JsonProperty("completion_tokens") Integer completionTokens,
1149-
@JsonProperty("prompt_tokens") Integer promptTokens,
1150-
@JsonProperty("total_tokens") Integer totalTokens,
1151-
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
1152-
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) { // @formatter:on
1148+
@JsonProperty("completion_tokens") Integer completionTokens,
1149+
@JsonProperty("prompt_tokens") Integer promptTokens,
1150+
@JsonProperty("total_tokens") Integer totalTokens,
1151+
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
1152+
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) { // @formatter:on
11531153

11541154
public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
11551155
this(completionTokens, promptTokens, totalTokens, null, null);
@@ -1158,11 +1158,13 @@ public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens
11581158
/**
11591159
* Breakdown of tokens used in the prompt
11601160
*
1161+
* @param audioTokens Audio input tokens present in the prompt.
11611162
* @param cachedTokens Cached tokens present in the prompt.
11621163
*/
11631164
@JsonInclude(Include.NON_NULL)
11641165
public record PromptTokensDetails(// @formatter:off
1165-
@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
1166+
@JsonProperty("audio_tokens") Integer audioTokens,
1167+
@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
11661168
}
11671169

11681170
/**
@@ -1178,10 +1180,10 @@ public record PromptTokensDetails(// @formatter:off
11781180
@JsonInclude(Include.NON_NULL)
11791181
@JsonIgnoreProperties(ignoreUnknown = true)
11801182
public record CompletionTokenDetails(// @formatter:off
1181-
@JsonProperty("reasoning_tokens") Integer reasoningTokens,
1182-
@JsonProperty("accepted_prediction_tokens") Integer acceptedPredictionTokens,
1183-
@JsonProperty("audio_tokens") Integer audioTokens,
1184-
@JsonProperty("rejected_prediction_tokens") Integer rejectedPredictionTokens) { // @formatter:on
1183+
@JsonProperty("reasoning_tokens") Integer reasoningTokens,
1184+
@JsonProperty("accepted_prediction_tokens") Integer acceptedPredictionTokens,
1185+
@JsonProperty("audio_tokens") Integer audioTokens,
1186+
@JsonProperty("rejected_prediction_tokens") Integer rejectedPredictionTokens) { // @formatter:on
11851187
}
11861188

11871189
}
@@ -1205,13 +1207,13 @@ public record CompletionTokenDetails(// @formatter:off
12051207
*/
12061208
@JsonInclude(Include.NON_NULL)
12071209
public record ChatCompletionChunk(// @formatter:off
1208-
@JsonProperty("id") String id,
1209-
@JsonProperty("choices") List<ChunkChoice> choices,
1210-
@JsonProperty("created") Long created,
1211-
@JsonProperty("model") String model,
1212-
@JsonProperty("system_fingerprint") String systemFingerprint,
1213-
@JsonProperty("object") String object,
1214-
@JsonProperty("usage") Usage usage) { // @formatter:on
1210+
@JsonProperty("id") String id,
1211+
@JsonProperty("choices") List<ChunkChoice> choices,
1212+
@JsonProperty("created") Long created,
1213+
@JsonProperty("model") String model,
1214+
@JsonProperty("system_fingerprint") String systemFingerprint,
1215+
@JsonProperty("object") String object,
1216+
@JsonProperty("usage") Usage usage) { // @formatter:on
12151217

12161218
/**
12171219
* Chat completion choice.

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java

Lines changed: 64 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
* @author John Blum
2727
* @author Thomas Vitale
2828
* @author David Frizelle
29+
* @author Christian Tzolov
2930
* @since 0.7.0
3031
* @see <a href=
3132
* "https://platform.openai.com/docs/api-reference/completions/object">Completion
@@ -60,52 +61,106 @@ public Long getGenerationTokens() {
6061
return generationTokens != null ? generationTokens.longValue() : 0;
6162
}
6263

63-
public Long getCachedTokens() {
64+
@Override
65+
public Long getTotalTokens() {
66+
Integer totalTokens = getUsage().totalTokens();
67+
if (totalTokens != null) {
68+
return totalTokens.longValue();
69+
}
70+
else {
71+
return getPromptTokens() + getGenerationTokens();
72+
}
73+
}
74+
75+
/**
76+
* @deprecated Use {@link #getPromptTokensDetails()} instead.
77+
*/
78+
@Deprecated
79+
public Long getPromptTokensDetailsCachedTokens() {
6480
OpenAiApi.Usage.PromptTokensDetails promptTokenDetails = getUsage().promptTokensDetails();
6581
Integer cachedTokens = promptTokenDetails != null ? promptTokenDetails.cachedTokens() : null;
6682
return cachedTokens != null ? cachedTokens.longValue() : 0;
6783
}
6884

85+
public PromptTokensDetails getPromptTokensDetails() {
86+
var details = getUsage().promptTokensDetails();
87+
if (details == null) {
88+
return new PromptTokensDetails(0, 0);
89+
}
90+
return new PromptTokensDetails(valueOrZero(details.audioTokens()), valueOrZero(details.cachedTokens()));
91+
}
92+
93+
/**
94+
* @deprecated Use {@link #getCompletionTokenDetails()} instead.
95+
*/
96+
@Deprecated
6997
public Long getReasoningTokens() {
7098
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
7199
Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null;
72100
return reasoningTokens != null ? reasoningTokens.longValue() : 0;
73101
}
74102

103+
/**
104+
* @deprecated Use {@link #getCompletionTokenDetails()} instead.
105+
*/
106+
@Deprecated
75107
public Long getAcceptedPredictionTokens() {
76108
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
77109
Integer acceptedPredictionTokens = completionTokenDetails != null
78110
? completionTokenDetails.acceptedPredictionTokens() : null;
79111
return acceptedPredictionTokens != null ? acceptedPredictionTokens.longValue() : 0;
80112
}
81113

114+
/**
115+
* @deprecated Use {@link #getCompletionTokenDetails()} instead.
116+
*/
117+
@Deprecated
82118
public Long getAudioTokens() {
83119
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
84120
Integer audioTokens = completionTokenDetails != null ? completionTokenDetails.audioTokens() : null;
85121
return audioTokens != null ? audioTokens.longValue() : 0;
86122
}
87123

124+
/**
125+
* @deprecated Use {@link #getCompletionTokenDetails()} instead.
126+
*/
127+
@Deprecated
88128
public Long getRejectedPredictionTokens() {
89129
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
90130
Integer rejectedPredictionTokens = completionTokenDetails != null
91131
? completionTokenDetails.rejectedPredictionTokens() : null;
92132
return rejectedPredictionTokens != null ? rejectedPredictionTokens.longValue() : 0;
93133
}
94134

95-
@Override
96-
public Long getTotalTokens() {
97-
Integer totalTokens = getUsage().totalTokens();
98-
if (totalTokens != null) {
99-
return totalTokens.longValue();
100-
}
101-
else {
102-
return getPromptTokens() + getGenerationTokens();
135+
public CompletionTokenDetails getCompletionTokenDetails() {
136+
var details = getUsage().completionTokenDetails();
137+
if (details == null) {
138+
return new CompletionTokenDetails(0, 0, 0, 0);
103139
}
140+
return new CompletionTokenDetails(valueOrZero(details.reasoningTokens()),
141+
valueOrZero(details.acceptedPredictionTokens()), valueOrZero(details.audioTokens()),
142+
valueOrZero(details.rejectedPredictionTokens()));
143+
}
144+
145+
public record PromptTokensDetails(// @formatter:off
146+
Integer audioTokens,
147+
Integer cachedTokens) {
148+
}
149+
150+
public record CompletionTokenDetails(
151+
Integer reasoningTokens,
152+
Integer acceptedPredictionTokens,
153+
Integer audioTokens,
154+
Integer rejectedPredictionTokens) { // @formatter:on
104155
}
105156

106157
@Override
107158
public String toString() {
108159
return getUsage().toString();
109160
}
110161

162+
private int valueOrZero(Integer value) {
163+
return value != null ? value : 0;
164+
}
165+
111166
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java

Lines changed: 34 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
* Unit tests for {@link OpenAiUsage}.
2727
*
2828
* @author Thomas Vitale
29+
* @author Christian Tzolov
2930
*/
3031
class OpenAiUsageTests {
3132

@@ -76,88 +77,84 @@ void whenPromptAndCompletionTokensDetailsIsNull() {
7677
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null);
7778
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
7879
assertThat(usage.getTotalTokens()).isEqualTo(300);
79-
assertThat(usage.getCachedTokens()).isEqualTo(0);
80-
assertThat(usage.getReasoningTokens()).isEqualTo(0);
81-
}
82-
83-
@Test
84-
void whenReasoningTokensIsNull() {
85-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
86-
new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null));
87-
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
88-
assertThat(usage.getReasoningTokens()).isEqualTo(0);
80+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0);
81+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0);
82+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0);
83+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0);
8984
}
9085

9186
@Test
9287
void whenCompletionTokenDetailsIsPresent() {
9388
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
9489
new OpenAiApi.Usage.CompletionTokenDetails(50, null, null, null));
9590
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
96-
assertThat(usage.getReasoningTokens()).isEqualTo(50);
97-
}
98-
99-
@Test
100-
void whenAcceptedPredictionTokensIsNull() {
101-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
102-
new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null));
103-
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
104-
assertThat(usage.getAcceptedPredictionTokens()).isEqualTo(0);
91+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(50);
92+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0);
93+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0);
94+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0);
10595
}
10696

10797
@Test
10898
void whenAcceptedPredictionTokensIsPresent() {
10999
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
110100
new OpenAiApi.Usage.CompletionTokenDetails(null, 75, null, null));
111101
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
112-
assertThat(usage.getAcceptedPredictionTokens()).isEqualTo(75);
113-
}
114-
115-
@Test
116-
void whenAudioTokensIsNull() {
117-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
118-
new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null));
119-
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
120-
assertThat(usage.getAudioTokens()).isEqualTo(0);
102+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0);
103+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(75);
104+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0);
105+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0);
121106
}
122107

123108
@Test
124109
void whenAudioTokensIsPresent() {
125110
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
126111
new OpenAiApi.Usage.CompletionTokenDetails(null, null, 125, null));
127112
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
128-
assertThat(usage.getAudioTokens()).isEqualTo(125);
113+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0);
114+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0);
115+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(125);
116+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0);
129117
}
130118

131119
@Test
132120
void whenRejectedPredictionTokensIsNull() {
133121
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
134122
new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null));
135123
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
136-
assertThat(usage.getRejectedPredictionTokens()).isEqualTo(0);
124+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0);
125+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0);
126+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0);
127+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0);
128+
137129
}
138130

139131
@Test
140132
void whenRejectedPredictionTokensIsPresent() {
141133
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
142134
new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, 25));
143135
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
144-
assertThat(usage.getRejectedPredictionTokens()).isEqualTo(25);
136+
assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0);
137+
assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0);
138+
assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0);
139+
assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(25);
145140
}
146141

147142
@Test
148143
void whenCacheTokensIsNull() {
149-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(null),
150-
null);
144+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
145+
new OpenAiApi.Usage.PromptTokensDetails(null, null), null);
151146
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
152-
assertThat(usage.getCachedTokens()).isEqualTo(0);
147+
assertThat(usage.getPromptTokensDetails().audioTokens()).isEqualTo(0);
148+
assertThat(usage.getPromptTokensDetails().cachedTokens()).isEqualTo(0);
153149
}
154150

155151
@Test
156152
void whenCacheTokensIsPresent() {
157-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(15),
158-
null);
153+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
154+
new OpenAiApi.Usage.PromptTokensDetails(99, 15), null);
159155
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
160-
assertThat(usage.getCachedTokens()).isEqualTo(15);
156+
assertThat(usage.getPromptTokensDetails().audioTokens()).isEqualTo(99);
157+
assertThat(usage.getPromptTokensDetails().cachedTokens()).isEqualTo(15);
161158
}
162159

163160
}

0 commit comments

Comments
 (0)