Skip to content

Commit 3c40268

Browse files
Shane Witbeckmarkpollack
authored andcommitted
Fix PromptTemplate to handle Arrays/Lists
Fixes #631
1 parent 7252ba1 commit 3c40268

File tree

2 files changed

+83
-28
lines changed

2 files changed

+83
-28
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import org.antlr.runtime.Token;
1919
import org.antlr.runtime.TokenStream;
20-
import org.springframework.ai.parser.OutputParser;
2120
import org.springframework.ai.chat.messages.Media;
2221
import org.springframework.ai.chat.messages.Message;
2322
import org.springframework.ai.chat.messages.UserMessage;
23+
import org.springframework.ai.parser.OutputParser;
2424
import org.springframework.core.io.Resource;
2525
import org.springframework.util.StreamUtils;
2626
import org.stringtemplate.v4.ST;
@@ -31,8 +31,6 @@
3131
import java.nio.charset.Charset;
3232
import java.util.*;
3333
import java.util.Map.Entry;
34-
import java.util.stream.Collectors;
35-
import java.util.stream.IntStream;
3634

3735
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
3836

@@ -161,12 +159,6 @@ private String renderResource(Resource resource) {
161159
catch (IOException e) {
162160
throw new RuntimeException(e);
163161
}
164-
// try (InputStream inputStream = resource.getInputStream()) {
165-
// return StreamUtils.copyToString(inputStream, Charset.defaultCharset());
166-
// }
167-
// catch (IOException ex) {
168-
// throw new RuntimeException(ex);
169-
// }
170162
}
171163

172164
@Override
@@ -196,22 +188,54 @@ public Prompt create(Map<String, Object> model) {
196188

197189
public Set<String> getInputVariables() {
198190
TokenStream tokens = this.st.impl.tokens;
199-
return IntStream.range(0, tokens.range())
200-
.mapToObj(tokens::get)
201-
.filter(token -> token.getType() == STLexer.ID)
202-
.map(Token::getText)
203-
.collect(Collectors.toSet());
191+
Set<String> inputVariables = new HashSet<>();
192+
boolean isInsideList = false;
193+
194+
for (int i = 0; i < tokens.size(); i++) {
195+
Token token = tokens.get(i);
196+
197+
if (token.getType() == STLexer.LDELIM && i + 1 < tokens.size()
198+
&& tokens.get(i + 1).getType() == STLexer.ID) {
199+
if (i + 2 < tokens.size() && tokens.get(i + 2).getType() == STLexer.COLON) {
200+
inputVariables.add(tokens.get(i + 1).getText());
201+
isInsideList = true;
202+
}
203+
}
204+
else if (token.getType() == STLexer.RDELIM) {
205+
isInsideList = false;
206+
}
207+
else if (!isInsideList && token.getType() == STLexer.ID) {
208+
inputVariables.add(token.getText());
209+
}
210+
}
211+
212+
return inputVariables;
204213
}
205214

206-
protected void validate(Map<String, Object> model) {
215+
private Set<String> getModelKeys(Map<String, Object> model) {
207216
Set<String> dynamicVariableNames = new HashSet<>(this.dynamicModel.keySet());
208217
Set<String> modelVariables = new HashSet<>(model.keySet());
209218
modelVariables.addAll(dynamicVariableNames);
210-
Set<String> missingEntries = new HashSet<>(getInputVariables());
211-
missingEntries.removeAll(modelVariables);
212-
if (!missingEntries.isEmpty()) {
219+
return modelVariables;
220+
}
221+
222+
protected void validate(Map<String, Object> model) {
223+
224+
Set<String> templateTokens = getInputVariables();
225+
Set<String> modelKeys = getModelKeys(model);
226+
227+
// Check if model provides all keys required by the template
228+
if (!modelKeys.containsAll(templateTokens)) {
229+
templateTokens.removeAll(modelKeys);
230+
throw new IllegalStateException(
231+
"All template variables were not replaced. Missing variable names are " + templateTokens);
232+
}
233+
234+
// Check if the template references any keys not provided by the model
235+
if (!templateTokens.containsAll(modelKeys)) {
236+
modelKeys.removeAll(templateTokens);
213237
throw new IllegalStateException(
214-
"All template variables were not replaced. Missing variable names are " + missingEntries);
238+
"All model variables were not replaced. Missing variable names are " + modelKeys);
215239
}
216240
}
217241

spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,44 @@
1717

1818
import org.junit.jupiter.api.Disabled;
1919
import org.junit.jupiter.api.Test;
20+
import org.springframework.ai.chat.messages.Message;
2021
import org.springframework.ai.chat.prompt.PromptTemplate;
2122
import org.springframework.core.io.InputStreamResource;
2223
import org.springframework.core.io.Resource;
2324

2425
import java.io.ByteArrayInputStream;
2526
import java.io.InputStream;
2627
import java.nio.charset.Charset;
28+
import java.util.Arrays;
2729
import java.util.HashMap;
30+
import java.util.List;
2831
import java.util.Map;
2932

33+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3034
import static org.junit.jupiter.api.Assertions.assertEquals;
3135
import static org.junit.jupiter.api.Assertions.assertThrows;
3236

3337
public class PromptTemplateTest {
3438

39+
@Test
40+
public void testRenderWithList() {
41+
String templateString = "The items are:\n{items:{item | - {item}\n}}";
42+
List<String> itemList = Arrays.asList("apple", "banana", "cherry");
43+
PromptTemplate promptTemplate = new PromptTemplate(templateString);
44+
Message message = promptTemplate.createMessage(Map.of("items", itemList));
45+
46+
String expected = "The items are:\n" + "- apple\n" + "- banana\n" + "- cherry\n";
47+
48+
assertEquals(expected, message.getContent());
49+
50+
PromptTemplate unfilledPromptTemplate = new PromptTemplate(templateString);
51+
assertThatExceptionOfType(IllegalStateException.class).isThrownBy(unfilledPromptTemplate::render)
52+
.withMessage("All template variables were not replaced. Missing variable names are [items]");
53+
}
54+
3555
@Test
3656
public void testRender() {
37-
// Create a map with string keys and object values to serve as a generative for
38-
// testing
39-
Map<String, Object> model = new HashMap<>();
40-
model.put("key1", "value1");
41-
model.put("key2", true);
57+
Map<String, Object> model = createTestMap();
4258
model.put("key3", 100);
4359

4460
// Create a simple template with placeholders for keys in the generative
@@ -58,14 +74,29 @@ public void testRender() {
5874
assertEquals(expected, result);
5975
}
6076

61-
@Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate")
6277
@Test
63-
public void testRenderResource() throws Exception {
64-
// Create a map with string keys and object values to serve as a generative for
65-
// testing
78+
public void testRenderResource() {
79+
Map<String, Object> model = createTestMap();
80+
InputStream inputStream = new ByteArrayInputStream(
81+
"key1's value is {key1} and key2's value is {key2}".getBytes(Charset.defaultCharset()));
82+
Resource resource = new InputStreamResource(inputStream);
83+
PromptTemplate promptTemplate = new PromptTemplate(resource, model);
84+
String expected = "key1's value is value1 and key2's value is true";
85+
String result = promptTemplate.render();
86+
assertEquals(expected, result);
87+
}
88+
89+
private static Map<String, Object> createTestMap() {
6690
Map<String, Object> model = new HashMap<>();
6791
model.put("key1", "value1");
6892
model.put("key2", true);
93+
return model;
94+
}
95+
96+
@Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate")
97+
@Test
98+
public void testRenderResourceAsValue() throws Exception {
99+
Map<String, Object> model = createTestMap();
69100

70101
// Create an input stream for the resource
71102
InputStream inputStream = new ByteArrayInputStream("it costs 100".getBytes(Charset.defaultCharset()));

0 commit comments

Comments
 (0)