13
13
* See the License for the specific language governing permissions and
14
14
* limitations under the License.
15
15
*/
16
- package org .springframework .ai .bedrock .llama2 ;
16
+ package org .springframework .ai .bedrock .llama ;
17
17
18
18
import java .util .List ;
19
19
20
20
import reactor .core .publisher .Flux ;
21
21
22
22
import org .springframework .ai .bedrock .MessageToPromptConverter ;
23
- import org .springframework .ai .bedrock .llama2 .api .Llama2ChatBedrockApi ;
24
- import org .springframework .ai .bedrock .llama2 .api .Llama2ChatBedrockApi . Llama2ChatRequest ;
25
- import org .springframework .ai .bedrock .llama2 .api .Llama2ChatBedrockApi . Llama2ChatResponse ;
23
+ import org .springframework .ai .bedrock .llama .api .LlamaChatBedrockApi ;
24
+ import org .springframework .ai .bedrock .llama .api .LlamaChatBedrockApi . LlamaChatRequest ;
25
+ import org .springframework .ai .bedrock .llama .api .LlamaChatBedrockApi . LlamaChatResponse ;
26
26
import org .springframework .ai .chat .ChatClient ;
27
27
import org .springframework .ai .chat .prompt .ChatOptions ;
28
28
import org .springframework .ai .chat .ChatResponse ;
35
35
import org .springframework .util .Assert ;
36
36
37
37
/**
38
- * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
38
+ * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
39
39
* generative.
40
40
*
41
41
* @author Christian Tzolov
42
+ * @author Wei Jiang
42
43
* @since 0.8.0
43
44
*/
44
- public class BedrockLlama2ChatClient implements ChatClient , StreamingChatClient {
45
+ public class BedrockLlamaChatClient implements ChatClient , StreamingChatClient {
45
46
46
- private final Llama2ChatBedrockApi chatApi ;
47
+ private final LlamaChatBedrockApi chatApi ;
47
48
48
- private final BedrockLlama2ChatOptions defaultOptions ;
49
+ private final BedrockLlamaChatOptions defaultOptions ;
49
50
50
- public BedrockLlama2ChatClient ( Llama2ChatBedrockApi chatApi ) {
51
+ public BedrockLlamaChatClient ( LlamaChatBedrockApi chatApi ) {
51
52
this (chatApi ,
52
- BedrockLlama2ChatOptions .builder ().withTemperature (0.8f ).withTopP (0.9f ).withMaxGenLen (100 ).build ());
53
+ BedrockLlamaChatOptions .builder ().withTemperature (0.8f ).withTopP (0.9f ).withMaxGenLen (100 ).build ());
53
54
}
54
55
55
- public BedrockLlama2ChatClient ( Llama2ChatBedrockApi chatApi , BedrockLlama2ChatOptions options ) {
56
- Assert .notNull (chatApi , "Llama2ChatBedrockApi must not be null" );
57
- Assert .notNull (options , "BedrockLlama2ChatOptions must not be null" );
56
+ public BedrockLlamaChatClient ( LlamaChatBedrockApi chatApi , BedrockLlamaChatOptions options ) {
57
+ Assert .notNull (chatApi , "LlamaChatBedrockApi must not be null" );
58
+ Assert .notNull (options , "BedrockLlamaChatOptions must not be null" );
58
59
59
60
this .chatApi = chatApi ;
60
61
this .defaultOptions = options ;
@@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) {
65
66
66
67
var request = createRequest (prompt );
67
68
68
- Llama2ChatResponse response = this .chatApi .chatCompletion (request );
69
+ LlamaChatResponse response = this .chatApi .chatCompletion (request );
69
70
70
71
return new ChatResponse (List .of (new Generation (response .generation ()).withGenerationMetadata (
71
72
ChatGenerationMetadata .from (response .stopReason ().name (), extractUsage (response )))));
@@ -76,7 +77,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
76
77
77
78
var request = createRequest (prompt );
78
79
79
- Flux <Llama2ChatResponse > fluxResponse = this .chatApi .chatCompletionStream (request );
80
+ Flux <LlamaChatResponse > fluxResponse = this .chatApi .chatCompletionStream (request );
80
81
81
82
return fluxResponse .map (response -> {
82
83
String stopReason = response .stopReason () != null ? response .stopReason ().name () : null ;
@@ -85,7 +86,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
85
86
});
86
87
}
87
88
88
- private Usage extractUsage (Llama2ChatResponse response ) {
89
+ private Usage extractUsage (LlamaChatResponse response ) {
89
90
return new Usage () {
90
91
91
92
@ Override
@@ -103,22 +104,22 @@ public Long getGenerationTokens() {
103
104
/**
104
105
* Accessible for testing.
105
106
*/
106
- Llama2ChatRequest createRequest (Prompt prompt ) {
107
+ LlamaChatRequest createRequest (Prompt prompt ) {
107
108
108
109
final String promptValue = MessageToPromptConverter .create ().toPrompt (prompt .getInstructions ());
109
110
110
- Llama2ChatRequest request = Llama2ChatRequest .builder (promptValue ).build ();
111
+ LlamaChatRequest request = LlamaChatRequest .builder (promptValue ).build ();
111
112
112
113
if (this .defaultOptions != null ) {
113
- request = ModelOptionsUtils .merge (request , this .defaultOptions , Llama2ChatRequest .class );
114
+ request = ModelOptionsUtils .merge (request , this .defaultOptions , LlamaChatRequest .class );
114
115
}
115
116
116
117
if (prompt .getOptions () != null ) {
117
118
if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
118
- BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
119
- ChatOptions .class , BedrockLlama2ChatOptions .class );
119
+ BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
120
+ ChatOptions .class , BedrockLlamaChatOptions .class );
120
121
121
- request = ModelOptionsUtils .merge (updatedRuntimeOptions , request , Llama2ChatRequest .class );
122
+ request = ModelOptionsUtils .merge (updatedRuntimeOptions , request , LlamaChatRequest .class );
122
123
}
123
124
else {
124
125
throw new IllegalArgumentException ("Prompt options are not of type ChatOptions: "
0 commit comments