Skip to content

Commit 944c4eb

Browse files
committed
Fix Ollama client options NPE and lack for SYS message handling
- Fix the Ollama options merging to pervent NPE. - Fix the Ollama handling for SYS messages. - Fix the BeanOutputParser to support JSON Schema reponses. - All Ollama Parsers tests pass now. Resolves: #258 , #273
1 parent aa8c385 commit 944c4eb

File tree

5 files changed

+27
-15
lines changed

5 files changed

+27
-15
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean strea
132132
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions()
133133
.stream()
134134
.filter(message -> message.getMessageType() == MessageType.USER
135-
|| message.getMessageType() == MessageType.ASSISTANT)
135+
|| message.getMessageType() == MessageType.ASSISTANT
136+
|| message.getMessageType() == MessageType.SYSTEM)
136137
.map(m -> OllamaApi.Message.builder(toRole(m)).withContent(m.getContent()).build())
137138
.toList();
138139

139140
// runtime options
140-
Map<String, Object> promptOptions = objectToMap(prompt.getOptions());
141-
Map<String, Object> clientOptionsToUse = merge(promptOptions, this.clientOptions, HashMap.class);
141+
Map<String, Object> clientOptionsToUse = merge(prompt.getOptions(), this.clientOptions, HashMap.class);
142142

143143
return ChatRequest.builder(model)
144144
.withStream(stream)
@@ -169,6 +169,9 @@ public static <T> T mapToClass(Map<String, Object> source, Class<T> clazz) {
169169
}
170170

171171
public static <T> T merge(Object source, Object target, Class<T> clazz) {
172+
if (source == null) {
173+
source = Map.of();
174+
}
172175
Map<String, Object> sourceMap = objectToMap(source);
173176
Map<String, Object> targetMap = objectToMap(target);
174177

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class OllamaChatClientIT {
4747
private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class);
4848

4949
@Container
50-
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.16").withExposedPorts(11434);
50+
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.21").withExposedPorts(11434);
5151

5252
static String baseUrl;
5353

@@ -86,7 +86,6 @@ void roleTest() {
8686
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
8787
}
8888

89-
@Disabled("TODO: Fix the parser instructions to return the correct format")
9089
@Test
9190
void outputParser() {
9291
DefaultConversionService conversionService = new DefaultConversionService();
@@ -106,7 +105,6 @@ void outputParser() {
106105
assertThat(list).hasSize(5);
107106
}
108107

109-
@Disabled("TODO: Fix the parser instructions to return the correct format")
110108
@Test
111109
void mapOutputParser() {
112110
MapOutputParser outputParser = new MapOutputParser();
@@ -131,7 +129,6 @@ void mapOutputParser() {
131129
record ActorsFilmsRecord(String actor, List<String> movies) {
132130
}
133131

134-
@Disabled("TODO: Fix the parser instructions to return the correct format")
135132
@Test
136133
void beanOutputParserRecords() {
137134

@@ -141,7 +138,6 @@ void beanOutputParserRecords() {
141138
String template = """
142139
Generate the filmography of 5 movies for Tom Hanks.
143140
{format}
144-
Remove Markdown code blocks from the output.
145141
""";
146142
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
147143
Prompt prompt = new Prompt(promptTemplate.createMessage());
@@ -152,7 +148,6 @@ void beanOutputParserRecords() {
152148
assertThat(actorsFilms.movies()).hasSize(5);
153149
}
154150

155-
@Disabled("TODO: Fix the parser instructions to return the correct format")
156151
@Test
157152
void beanStreamOutputParserRecords() {
158153

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class OllamaEmbeddingClientIT {
3030
private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
3131

3232
@Container
33-
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.16").withExposedPorts(11434);
33+
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.21").withExposedPorts(11434);
3434

3535
static String baseUrl;
3636

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class OllamaApiIT {
5151
private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
5252

5353
@Container
54-
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.16").withExposedPorts(11434);
54+
static GenericContainer<?> ollamaContainer = new GenericContainer<>("ollama/ollama:0.1.21").withExposedPorts(11434);
5555

5656
static OllamaApi ollamaApi;
5757

@@ -87,9 +87,14 @@ public void chat() {
8787

8888
var request = ChatRequest.builder("orca-mini")
8989
.withStream(false)
90-
.withMessages(List.of(Message.builder(Role.USER)
91-
.withContent("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?")
92-
.build()))
90+
.withMessages(List.of(
91+
Message.builder(Role.SYSTEM)
92+
.withContent("You are geography teacher. You are talking to a student.")
93+
.build(),
94+
Message.builder(Role.USER)
95+
.withContent("What is the capital of Bulgaria and what is the size? "
96+
+ "What it the national anthem?")
97+
.build()))
9398
.withOptions(OllamaOptions.create().withTemperature(0.9f))
9499
.build();
95100

@@ -127,7 +132,7 @@ public void streamingChat() {
127132
.collect(Collectors.joining("\n"))).contains("Sofia");
128133

129134
ChatResponse lastResponse = responses.get(responses.size() - 1);
130-
assertThat(lastResponse.message()).isNull();
135+
assertThat(lastResponse.message().content()).isEmpty();
131136
assertThat(lastResponse.done()).isTrue();
132137
}
133138

spring-ai-core/src/main/java/org/springframework/ai/parser/BeanOutputParser.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
2626
import com.github.victools.jsonschema.module.jackson.JacksonModule;
2727

28+
import java.util.Map;
2829
import java.util.Objects;
2930

3031
import static com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON;
@@ -95,6 +96,13 @@ private void generateSchema() {
9596
*/
9697
public T parse(String text) {
9798
try {
99+
// If the response is a JSON Schema, extract the properties and use them as
100+
// the
101+
// response.
102+
Map<String, Object> map = this.objectMapper.readValue(text, Map.class);
103+
if (map.containsKey("$schema")) {
104+
text = this.objectMapper.writeValueAsString(map.get("properties"));
105+
}
98106
return (T) this.objectMapper.readValue(text, this.clazz);
99107
}
100108
catch (JsonProcessingException e) {
@@ -122,6 +130,7 @@ public String getFormat() {
122130
String template = """
123131
Your response should be in JSON format.
124132
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
133+
Do not include markdown code blocks in your response.
125134
Here is the JSON Schema instance your output must adhere to:
126135
```%s```
127136
""";

0 commit comments

Comments
 (0)