Skip to content

Commit 3580849

Browse files
committed
Mistral AI: test package restructuring. Fix failing tests. Add AOT tests
1 parent a81a99f commit 3580849

File tree

5 files changed

+51
-6
lines changed

5 files changed

+51
-6
lines changed
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.mistralai.chat;
17+
package org.springframework.ai.mistralai;
1818

1919
import java.util.Arrays;
2020
import java.util.List;
@@ -36,7 +36,6 @@
3636
import org.springframework.ai.chat.prompt.Prompt;
3737
import org.springframework.ai.chat.prompt.PromptTemplate;
3838
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
39-
import org.springframework.ai.mistralai.MistralAiTestConfiguration;
4039
import org.springframework.ai.parser.BeanOutputParser;
4140
import org.springframework.ai.parser.ListOutputParser;
4241
import org.springframework.ai.parser.MapOutputParser;
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.mistralai.chat;
17+
package org.springframework.ai.mistralai;
1818

1919
import org.junit.jupiter.api.Test;
2020
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -47,14 +47,15 @@ void chatCompletionDefaultRequestTest() {
4747
assertThat(request.temperature()).isEqualTo(0.7f);
4848
assertThat(request.safePrompt()).isFalse();
4949
assertThat(request.maxTokens()).isNull();
50+
assertThat(request.stream()).isFalse();
5051
}
5152

5253
@Test
5354
void chatCompletionRequestWithOptionsTest() {
5455

5556
var options = MistralAiChatOptions.builder().withTemperature(0.5f).withTopP(0.8f).build();
5657

57-
var request = chatClient.createRequest(new Prompt("test content", options), false);
58+
var request = chatClient.createRequest(new Prompt("test content", options), true);
5859

5960
assertThat(request.messages().size()).isEqualTo(1);
6061
assertThat(request.topP()).isEqualTo(0.8f);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
package org.springframework.ai.mistralai.embedding;
16+
package org.springframework.ai.mistralai;
1717

1818
import org.junit.jupiter.api.Test;
1919
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mistralai.aot;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import org.springframework.ai.mistralai.api.MistralAiApi;
22+
import org.springframework.aot.hint.RuntimeHints;
23+
import org.springframework.aot.hint.TypeReference;
24+
25+
import java.util.Set;
26+
27+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
28+
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
29+
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
30+
31+
class MistralAiRuntimeHintsTests {
32+
33+
@Test
34+
void registerHints() {
35+
RuntimeHints runtimeHints = new RuntimeHints();
36+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
37+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
38+
39+
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(MistralAiApi.class);
40+
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
41+
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
42+
}
43+
}
44+
45+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.mistralai.chat.api;
17+
package org.springframework.ai.mistralai.api;
1818

1919
import java.util.List;
2020

0 commit comments

Comments
 (0)