17
17
package org .springframework .ai .model .function ;
18
18
19
19
import java .util .function .BiFunction ;
20
+ import java .util .function .Consumer ;
20
21
import java .util .function .Function ;
22
+ import java .util .function .Supplier ;
21
23
22
24
import com .fasterxml .jackson .annotation .JsonClassDescription ;
25
+ import kotlin .jvm .functions .Function0 ;
23
26
import kotlin .jvm .functions .Function1 ;
24
27
import kotlin .jvm .functions .Function2 ;
25
28
30
33
import org .springframework .context .annotation .Description ;
31
34
import org .springframework .context .support .GenericApplicationContext ;
32
35
import org .springframework .core .KotlinDetector ;
36
+ import org .springframework .core .ParameterizedTypeReference ;
33
37
import org .springframework .core .ResolvableType ;
34
38
import org .springframework .lang .NonNull ;
35
39
import org .springframework .lang .Nullable ;
38
42
/**
39
43
* A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve
40
44
* a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}.
41
- *
45
+ * <p>
42
46
* The name of the function is determined by the bean name.
43
- *
47
+ * <p>
44
48
* The description of the function is determined by the following rules:
45
49
* <ul>
46
50
* <li>Provided as a default description</li>
@@ -69,24 +73,28 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
69
73
70
74
@ SuppressWarnings ({ "unchecked" })
71
75
public FunctionCallback getFunctionCallback (@ NonNull String beanName , @ Nullable String defaultDescription ) {
72
-
73
76
ResolvableType functionType = TypeResolverHelper .resolveBeanType (this .applicationContext , beanName );
74
- ResolvableType functionInputType = TypeResolverHelper .getFunctionArgumentType (functionType , 0 );
77
+ ResolvableType functionInputType = (ResolvableType .forType (Supplier .class ).isAssignableFrom (functionType ))
78
+ ? ResolvableType .forType (Void .class ) : TypeResolverHelper .getFunctionArgumentType (functionType , 0 );
79
+
80
+ String functionDescription = resolveFunctionDescription (beanName , defaultDescription ,
81
+ functionInputType .toClass ());
82
+ Object bean = this .applicationContext .getBean (beanName );
83
+
84
+ return buildFunctionCallback (beanName , functionType , functionInputType , functionDescription , bean );
85
+ }
75
86
76
- Class <?> functionInputClass = functionInputType . toClass ();
87
+ private String resolveFunctionDescription ( String beanName , String defaultDescription , Class <?> functionInputClass ) {
77
88
String functionDescription = defaultDescription ;
78
89
79
90
if (!StringUtils .hasText (functionDescription )) {
80
- // Look for a Description annotation on the bean
81
91
Description descriptionAnnotation = this .applicationContext .findAnnotationOnBean (beanName ,
82
92
Description .class );
83
-
84
93
if (descriptionAnnotation != null ) {
85
94
functionDescription = descriptionAnnotation .value ();
86
95
}
87
96
88
97
if (!StringUtils .hasText (functionDescription )) {
89
- // Look for a JsonClassDescription annotation on the input class
90
98
JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
91
99
.getAnnotation (JsonClassDescription .class );
92
100
if (jsonClassDescriptionAnnotation != null ) {
@@ -95,51 +103,79 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
95
103
}
96
104
97
105
if (!StringUtils .hasText (functionDescription )) {
98
- throw new IllegalStateException ("Could not determine function description."
106
+ throw new IllegalStateException ("Could not determine function description. "
99
107
+ "Please provide a description either as a default parameter, via @Description annotation on the bean "
100
108
+ "or @JsonClassDescription annotation on the input class." );
101
109
}
102
110
}
103
111
104
- Object bean = this .applicationContext .getBean (beanName );
112
+ return functionDescription ;
113
+ }
114
+
115
+ private FunctionCallback buildFunctionCallback (String beanName , ResolvableType functionType ,
116
+ ResolvableType functionInputType , String functionDescription , Object bean ) {
105
117
106
118
if (KotlinDetector .isKotlinPresent ()) {
107
119
if (KotlinDelegate .isKotlinFunction (functionType .toClass ())) {
108
120
return FunctionCallback .builder ()
109
121
.schemaType (this .schemaType )
110
122
.description (functionDescription )
111
123
.function (beanName , KotlinDelegate .wrapKotlinFunction (bean ))
112
- .inputType (functionInputClass )
124
+ .inputType (ParameterizedTypeReference . forType ( functionInputType . getType ()) )
113
125
.build ();
114
126
}
115
- else if (KotlinDelegate .isKotlinBiFunction (functionType .toClass ())) {
127
+ if (KotlinDelegate .isKotlinBiFunction (functionType .toClass ())) {
116
128
return FunctionCallback .builder ()
117
129
.description (functionDescription )
118
130
.schemaType (this .schemaType )
119
131
.function (beanName , KotlinDelegate .wrapKotlinBiFunction (bean ))
120
- .inputType (functionInputClass )
132
+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
133
+ .build ();
134
+ }
135
+ if (KotlinDelegate .isKotlinSupplier (functionType .toClass ())) {
136
+ return FunctionCallback .builder ()
137
+ .description (functionDescription )
138
+ .schemaType (this .schemaType )
139
+ .function (beanName , KotlinDelegate .wrapKotlinSupplier (bean ))
140
+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
121
141
.build ();
122
142
}
123
143
}
144
+
124
145
if (bean instanceof Function <?, ?> function ) {
125
146
return FunctionCallback .builder ()
126
147
.schemaType (this .schemaType )
127
148
.description (functionDescription )
128
149
.function (beanName , function )
129
- .inputType (functionInputClass )
150
+ .inputType (ParameterizedTypeReference . forType ( functionInputType . getType ()) )
130
151
.build ();
131
152
}
132
- else if (bean instanceof BiFunction <?, ?, ?>) {
153
+ if (bean instanceof BiFunction <?, ?, ?>) {
133
154
return FunctionCallback .builder ()
134
155
.description (functionDescription )
135
156
.schemaType (this .schemaType )
136
157
.function (beanName , (BiFunction <?, ToolContext , ?>) bean )
137
- .inputType (functionInputClass )
158
+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
159
+ .build ();
160
+ }
161
+ if (bean instanceof Supplier <?> supplier ) {
162
+ return FunctionCallback .builder ()
163
+ .description (functionDescription )
164
+ .schemaType (this .schemaType )
165
+ .function (beanName , supplier )
166
+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
138
167
.build ();
139
168
}
140
- else {
141
- throw new IllegalStateException ();
169
+ if (bean instanceof Consumer <?> consumer ) {
170
+ return FunctionCallback .builder ()
171
+ .description (functionDescription )
172
+ .schemaType (this .schemaType )
173
+ .function (beanName , consumer )
174
+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
175
+ .build ();
142
176
}
177
+
178
+ throw new IllegalStateException ("Unsupported function type" );
143
179
}
144
180
145
181
public enum SchemaType {
@@ -148,7 +184,16 @@ public enum SchemaType {
148
184
149
185
}
150
186
151
- private static class KotlinDelegate {
187
+ private static final class KotlinDelegate {
188
+
189
+ public static boolean isKotlinSupplier (Class <?> clazz ) {
190
+ return Function0 .class .isAssignableFrom (clazz );
191
+ }
192
+
193
+ @ SuppressWarnings ("unchecked" )
194
+ public static Supplier <?> wrapKotlinSupplier (Object function ) {
195
+ return () -> ((Function0 <Object >) function ).invoke ();
196
+ }
152
197
153
198
public static boolean isKotlinFunction (Class <?> clazz ) {
154
199
return Function1 .class .isAssignableFrom (clazz );
0 commit comments