22
22
23
23
import org .junit .jupiter .api .Test ;
24
24
25
+ import org .springframework .ai .chat .model .ToolContext ;
25
26
import org .springframework .ai .tool .definition .DefaultToolDefinition ;
26
27
import org .springframework .ai .tool .definition .ToolDefinition ;
27
28
28
29
import static org .assertj .core .api .Assertions .assertThat ;
30
+ import static org .assertj .core .api .Assertions .assertThatThrownBy ;
29
31
30
32
/**
31
33
* Tests for {@link MethodToolCallback} with generic types.
@@ -137,6 +139,76 @@ void testNestedGenericType() throws Exception {
137
139
assertThat (result ).isEqualTo ("2 maps processed: [{a=1, b=2}, {c=3, d=4}]" );
138
140
}
139
141
142
+ @ Test
143
+ void testToolContextType () throws Exception {
144
+ // Create a test object with a method that takes a List<Map<String, Integer>>
145
+ TestGenericClass testObject = new TestGenericClass ();
146
+ Method method = TestGenericClass .class .getMethod ("processStringListInToolContext" , ToolContext .class );
147
+
148
+ // Create a tool definition
149
+ ToolDefinition toolDefinition = DefaultToolDefinition .builder ()
150
+ .name ("processToolContext" )
151
+ .description ("Process tool context" )
152
+ .inputSchema ("{}" )
153
+ .build ();
154
+
155
+ // Create a MethodToolCallback
156
+ MethodToolCallback callback = MethodToolCallback .builder ()
157
+ .toolDefinition (toolDefinition )
158
+ .toolMethod (method )
159
+ .toolObject (testObject )
160
+ .build ();
161
+
162
+ // Create an empty JSON input
163
+ String toolInput = """
164
+ {}
165
+ """ ;
166
+
167
+ // Create a toolContext
168
+ ToolContext toolContext = new ToolContext (Map .of ("foo" , "bar" ));
169
+
170
+ // Call the tool
171
+ String result = callback .call (toolInput , toolContext );
172
+
173
+ // Verify the result
174
+ assertThat (result ).isEqualTo ("1 entries processed {foo=bar}" );
175
+ }
176
+
177
+ @ Test
178
+ void testToolContextTypeWithNonToolContextArgs () throws Exception {
179
+ // Create a test object with a method that takes a List<String>
180
+ TestGenericClass testObject = new TestGenericClass ();
181
+ Method method = TestGenericClass .class .getMethod ("processStringList" , List .class );
182
+
183
+ // Create a tool definition
184
+ ToolDefinition toolDefinition = DefaultToolDefinition .builder ()
185
+ .name ("processStringList" )
186
+ .description ("Process a list of strings" )
187
+ .inputSchema ("{}" )
188
+ .build ();
189
+
190
+ // Create a MethodToolCallback
191
+ MethodToolCallback callback = MethodToolCallback .builder ()
192
+ .toolDefinition (toolDefinition )
193
+ .toolMethod (method )
194
+ .toolObject (testObject )
195
+ .build ();
196
+
197
+ // Create a JSON input with a list of strings
198
+ String toolInput = """
199
+ {
200
+ "strings": ["one", "two", "three"]
201
+ }
202
+ """ ;
203
+
204
+ // Create a toolContext
205
+ ToolContext toolContext = new ToolContext (Map .of ("foo" , "bar" ));
206
+
207
+ // Call the tool and verify
208
+ assertThatThrownBy (() -> callback .call (toolInput , toolContext )).isInstanceOf (IllegalArgumentException .class )
209
+ .hasMessageContaining ("ToolContext is required by the method as an argument" );
210
+ }
211
+
140
212
/**
141
213
* Test class with methods that use generic types.
142
214
*/
@@ -154,6 +226,11 @@ public String processListOfMaps(List<Map<String, Integer>> listOfMaps) {
154
226
return listOfMaps .size () + " maps processed: " + listOfMaps ;
155
227
}
156
228
229
+ public String processStringListInToolContext (ToolContext toolContext ) {
230
+ Map <String , Object > context = toolContext .getContext ();
231
+ return context .size () + " entries processed " + context ;
232
+ }
233
+
157
234
}
158
235
159
236
}
0 commit comments