|
17 | 17 |
|
18 | 18 | import org.antlr.runtime.Token;
|
19 | 19 | import org.antlr.runtime.TokenStream;
|
20 |
| -import org.springframework.ai.parser.OutputParser; |
21 | 20 | import org.springframework.ai.chat.messages.Media;
|
22 | 21 | import org.springframework.ai.chat.messages.Message;
|
23 | 22 | import org.springframework.ai.chat.messages.UserMessage;
|
| 23 | +import org.springframework.ai.parser.OutputParser; |
24 | 24 | import org.springframework.core.io.Resource;
|
25 | 25 | import org.springframework.util.StreamUtils;
|
26 | 26 | import org.stringtemplate.v4.ST;
|
|
31 | 31 | import java.nio.charset.Charset;
|
32 | 32 | import java.util.*;
|
33 | 33 | import java.util.Map.Entry;
|
34 |
| -import java.util.stream.Collectors; |
35 |
| -import java.util.stream.IntStream; |
36 | 34 |
|
37 | 35 | public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
|
38 | 36 |
|
@@ -161,12 +159,6 @@ private String renderResource(Resource resource) {
|
161 | 159 | catch (IOException e) {
|
162 | 160 | throw new RuntimeException(e);
|
163 | 161 | }
|
164 |
| - // try (InputStream inputStream = resource.getInputStream()) { |
165 |
| - // return StreamUtils.copyToString(inputStream, Charset.defaultCharset()); |
166 |
| - // } |
167 |
| - // catch (IOException ex) { |
168 |
| - // throw new RuntimeException(ex); |
169 |
| - // } |
170 | 162 | }
|
171 | 163 |
|
172 | 164 | @Override
|
@@ -196,22 +188,54 @@ public Prompt create(Map<String, Object> model) {
|
196 | 188 |
|
197 | 189 | public Set<String> getInputVariables() {
|
198 | 190 | TokenStream tokens = this.st.impl.tokens;
|
199 |
| - return IntStream.range(0, tokens.range()) |
200 |
| - .mapToObj(tokens::get) |
201 |
| - .filter(token -> token.getType() == STLexer.ID) |
202 |
| - .map(Token::getText) |
203 |
| - .collect(Collectors.toSet()); |
| 191 | + Set<String> inputVariables = new HashSet<>(); |
| 192 | + boolean isInsideList = false; |
| 193 | + |
| 194 | + for (int i = 0; i < tokens.size(); i++) { |
| 195 | + Token token = tokens.get(i); |
| 196 | + |
| 197 | + if (token.getType() == STLexer.LDELIM && i + 1 < tokens.size() |
| 198 | + && tokens.get(i + 1).getType() == STLexer.ID) { |
| 199 | + if (i + 2 < tokens.size() && tokens.get(i + 2).getType() == STLexer.COLON) { |
| 200 | + inputVariables.add(tokens.get(i + 1).getText()); |
| 201 | + isInsideList = true; |
| 202 | + } |
| 203 | + } |
| 204 | + else if (token.getType() == STLexer.RDELIM) { |
| 205 | + isInsideList = false; |
| 206 | + } |
| 207 | + else if (!isInsideList && token.getType() == STLexer.ID) { |
| 208 | + inputVariables.add(token.getText()); |
| 209 | + } |
| 210 | + } |
| 211 | + |
| 212 | + return inputVariables; |
204 | 213 | }
|
205 | 214 |
|
206 |
| - protected void validate(Map<String, Object> model) { |
| 215 | + private Set<String> getModelKeys(Map<String, Object> model) { |
207 | 216 | Set<String> dynamicVariableNames = new HashSet<>(this.dynamicModel.keySet());
|
208 | 217 | Set<String> modelVariables = new HashSet<>(model.keySet());
|
209 | 218 | modelVariables.addAll(dynamicVariableNames);
|
210 |
| - Set<String> missingEntries = new HashSet<>(getInputVariables()); |
211 |
| - missingEntries.removeAll(modelVariables); |
212 |
| - if (!missingEntries.isEmpty()) { |
| 219 | + return modelVariables; |
| 220 | + } |
| 221 | + |
| 222 | + protected void validate(Map<String, Object> model) { |
| 223 | + |
| 224 | + Set<String> templateTokens = getInputVariables(); |
| 225 | + Set<String> modelKeys = getModelKeys(model); |
| 226 | + |
| 227 | + // Check if model provides all keys required by the template |
| 228 | + if (!modelKeys.containsAll(templateTokens)) { |
| 229 | + templateTokens.removeAll(modelKeys); |
| 230 | + throw new IllegalStateException( |
| 231 | + "All template variables were not replaced. Missing variable names are " + templateTokens); |
| 232 | + } |
| 233 | + |
| 234 | + // Check if the template references any keys not provided by the model |
| 235 | + if (!templateTokens.containsAll(modelKeys)) { |
| 236 | + modelKeys.removeAll(templateTokens); |
213 | 237 | throw new IllegalStateException(
|
214 |
| - "All template variables were not replaced. Missing variable names are " + missingEntries); |
| 238 | + "All model variables were not replaced. Missing variable names are " + modelKeys); |
215 | 239 | }
|
216 | 240 | }
|
217 | 241 |
|
|
0 commit comments