Skip to content

Commit 9a8a133

Browse files
ThomasVitaletzolov
authored andcommitted
Fix NPE in OpenAiUsage
Fix gh-1152 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 4cacbe8 commit 9a8a133

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* {@link Usage} implementation for {@literal OpenAI}.
2424
*
2525
* @author John Blum
26+
* @author Thomas Vitale
2627
* @since 0.7.0
2728
* @see <a href=
2829
* "https://platform.openai.com/docs/api-reference/completions/object">Completion
@@ -47,17 +48,25 @@ protected OpenAiApi.Usage getUsage() {
4748

4849
@Override
4950
public Long getPromptTokens() {
50-
return getUsage().promptTokens().longValue();
51+
Integer promptTokens = getUsage().promptTokens();
52+
return promptTokens != null ? promptTokens.longValue() : 0;
5153
}
5254

5355
@Override
5456
public Long getGenerationTokens() {
55-
return getUsage().completionTokens().longValue();
57+
Integer generationTokens = getUsage().completionTokens();
58+
return generationTokens != null ? generationTokens.longValue() : 0;
5659
}
5760

5861
@Override
5962
public Long getTotalTokens() {
60-
return getUsage().totalTokens().longValue();
63+
Integer totalTokens = getUsage().totalTokens();
64+
if (totalTokens != null) {
65+
return totalTokens.longValue();
66+
}
67+
else {
68+
return getPromptTokens() + getGenerationTokens();
69+
}
6170
}
6271

6372
@Override
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package org.springframework.ai.openai.metadata;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.springframework.ai.openai.api.OpenAiApi;
5+
6+
import static org.assertj.core.api.Assertions.assertThat;
7+
8+
/**
9+
* Unit tests for {@link OpenAiUsage}.
10+
*
11+
* @author Thomas Vitale
12+
*/
13+
class OpenAiUsageTests {
14+
15+
@Test
16+
void whenPromptTokensIsPresent() {
17+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300);
18+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
19+
assertThat(usage.getPromptTokens()).isEqualTo(200);
20+
}
21+
22+
@Test
23+
void whenPromptTokensIsNull() {
24+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, null, 100);
25+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
26+
assertThat(usage.getPromptTokens()).isEqualTo(0);
27+
}
28+
29+
@Test
30+
void whenGenerationTokensIsPresent() {
31+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300);
32+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
33+
assertThat(usage.getGenerationTokens()).isEqualTo(100);
34+
}
35+
36+
@Test
37+
void whenGenerationTokensIsNull() {
38+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(null, 200, 200);
39+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
40+
assertThat(usage.getGenerationTokens()).isEqualTo(0);
41+
}
42+
43+
@Test
44+
void whenTotalTokensIsPresent() {
45+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300);
46+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
47+
assertThat(usage.getTotalTokens()).isEqualTo(300);
48+
}
49+
50+
@Test
51+
void whenTotalTokensIsNull() {
52+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, null);
53+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
54+
assertThat(usage.getTotalTokens()).isEqualTo(300);
55+
}
56+
57+
}

0 commit comments

Comments
 (0)