Skip to content

Commit 1af5060

Browse files
authored
Add metadata to user agent in conversation handler runtime (#2181)
* Add metadata to user agent in conversation handler runtime * tests * remove validation.
1 parent 298f971 commit 1af5060

14 files changed

+330
-40
lines changed

.changeset/thin-candles-perform.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': patch
3+
---
4+
5+
Add metadata to user agent in conversation handler runtime.

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
2626
import { randomBytes, randomUUID } from 'node:crypto';
2727
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';
28+
import { UserAgentProvider } from './user_agent_provider';
2829

2930
void describe('Bedrock converse adapter', () => {
3031
const commonEvent: Readonly<ConversationTurnEvent> = {
@@ -897,17 +898,20 @@ void describe('Bedrock converse adapter', () => {
897898
...commonEvent,
898899
};
899900

900-
event.request.headers['x-amz-user-agent'] = 'testUserAgent';
901-
902901
const bedrockClient = new BedrockRuntimeClient();
903902
const addMiddlewareMock = mock.method(bedrockClient.middlewareStack, 'add');
903+
const userAgentProvider = new UserAgentProvider(
904+
{} as unknown as ConversationTurnEvent
905+
);
906+
mock.method(userAgentProvider, 'getUserAgent', () => 'testUserAgent');
904907

905908
new BedrockConverseAdapter(
906909
event,
907910
[],
908911
bedrockClient,
909912
undefined,
910-
messageHistoryRetriever
913+
messageHistoryRetriever,
914+
userAgentProvider
911915
);
912916

913917
assert.strictEqual(addMiddlewareMock.mock.calls.length, 1);

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import { ConversationTurnEventToolsProvider } from './event-tools-provider';
2222
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';
2323
import * as bedrock from '@aws-sdk/client-bedrock-runtime';
2424
import { ValidationError } from './errors';
25+
import { UserAgentProvider } from './user_agent_provider';
2526

2627
/**
2728
* This class is responsible for interacting with Bedrock Converse API
@@ -48,23 +49,22 @@ export class BedrockConverseAdapter {
4849
private readonly messageHistoryRetriever = new ConversationMessageHistoryRetriever(
4950
event
5051
),
52+
userAgentProvider = new UserAgentProvider(event),
5153
private readonly logger = console
5254
) {
53-
if (event.request.headers['x-amz-user-agent']) {
54-
this.bedrockClient.middlewareStack.add(
55-
(next) => (args) => {
56-
// @ts-expect-error Request is typed as unknown.
57-
// But this is recommended way to alter headers per https://github.com/aws/aws-sdk-js-v3/blob/main/README.md.
58-
args.request.headers['x-amz-user-agent'] =
59-
event.request.headers['x-amz-user-agent'];
60-
return next(args);
61-
},
62-
{
63-
step: 'build',
64-
name: 'amplify-user-agent-injector',
65-
}
66-
);
67-
}
55+
this.bedrockClient.middlewareStack.add(
56+
(next) => (args) => {
57+
// @ts-expect-error Request is typed as unknown.
58+
// But this is recommended way to alter headers per https://github.com/aws/aws-sdk-js-v3/blob/main/README.md.
59+
args.request.headers['x-amz-user-agent'] =
60+
userAgentProvider.getUserAgent();
61+
return next(args);
62+
},
63+
{
64+
step: 'build',
65+
name: 'amplify-user-agent-injector',
66+
}
67+
);
6868
this.executableTools = [
6969
...eventToolsProvider.getEventTools(),
7070
...additionalTools,

packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.test.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
GetQueryOutput,
1313
ListQueryOutput,
1414
} from './conversation_message_history_retriever';
15+
import { UserAgentProvider } from './user_agent_provider';
1516

1617
type TestCase = {
1718
name: string;
@@ -704,7 +705,15 @@ void describe('Conversation message history retriever', () => {
704705

705706
for (const testCase of testCases) {
706707
void it(testCase.name, async () => {
707-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
708+
const userAgentProvider = new UserAgentProvider(
709+
{} as unknown as ConversationTurnEvent
710+
);
711+
mock.method(userAgentProvider, 'getUserAgent', () => '');
712+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
713+
'',
714+
'',
715+
userAgentProvider
716+
);
708717
const executeGraphqlMock = mock.method(
709718
graphqlRequestExecutor,
710719
'executeGraphql',

packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
ConversationTurnEvent,
55
} from './types';
66
import { GraphqlRequestExecutor } from './graphql_request_executor';
7+
import { UserAgentProvider } from './user_agent_provider';
78

89
export type ConversationHistoryMessageItem = ConversationMessage & {
910
id: string;
@@ -107,7 +108,7 @@ export class ConversationMessageHistoryRetriever {
107108
private readonly graphqlRequestExecutor = new GraphqlRequestExecutor(
108109
event.graphqlApiEndpoint,
109110
event.request.headers.authorization,
110-
event.request.headers['x-amz-user-agent']
111+
new UserAgentProvider(event)
111112
)
112113
) {}
113114

packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
GraphqlRequest,
1616
GraphqlRequestExecutor,
1717
} from './graphql_request_executor';
18+
import { UserAgentProvider } from './user_agent_provider';
1819

1920
void describe('Conversation turn response sender', () => {
2021
const event: ConversationTurnEvent = {
@@ -37,7 +38,19 @@ void describe('Conversation turn response sender', () => {
3738
};
3839

3940
void it('sends response back to appsync', async () => {
40-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
41+
const userAgentProvider = new UserAgentProvider(
42+
{} as unknown as ConversationTurnEvent
43+
);
44+
const userAgentProviderMock = mock.method(
45+
userAgentProvider,
46+
'getUserAgent',
47+
() => 'testUserAgent'
48+
);
49+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
50+
'',
51+
'',
52+
userAgentProvider
53+
);
4154
const executeGraphqlMock = mock.method(
4255
graphqlRequestExecutor,
4356
'executeGraphql',
@@ -47,6 +60,7 @@ void describe('Conversation turn response sender', () => {
4760
);
4861
const sender = new ConversationTurnResponseSender(
4962
event,
63+
userAgentProvider,
5064
graphqlRequestExecutor
5165
);
5266
const response: Array<ContentBlock> = [
@@ -57,7 +71,14 @@ void describe('Conversation turn response sender', () => {
5771
];
5872
await sender.sendResponse(response);
5973

74+
assert.strictEqual(userAgentProviderMock.mock.calls.length, 1);
75+
assert.deepStrictEqual(userAgentProviderMock.mock.calls[0].arguments[0], {
76+
'turn-response-type': 'single',
77+
});
6078
assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
79+
assert.deepStrictEqual(executeGraphqlMock.mock.calls[0].arguments[1], {
80+
userAgent: 'testUserAgent',
81+
});
6182
const request = executeGraphqlMock.mock.calls[0]
6283
.arguments[0] as GraphqlRequest<MutationResponseInput>;
6384
assert.deepStrictEqual(request, {
@@ -85,7 +106,15 @@ void describe('Conversation turn response sender', () => {
85106
});
86107

87108
void it('serializes tool use input to JSON', async () => {
88-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
109+
const userAgentProvider = new UserAgentProvider(
110+
{} as unknown as ConversationTurnEvent
111+
);
112+
mock.method(userAgentProvider, 'getUserAgent', () => '');
113+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
114+
'',
115+
'',
116+
userAgentProvider
117+
);
89118
const executeGraphqlMock = mock.method(
90119
graphqlRequestExecutor,
91120
'executeGraphql',
@@ -95,6 +124,7 @@ void describe('Conversation turn response sender', () => {
95124
);
96125
const sender = new ConversationTurnResponseSender(
97126
event,
127+
userAgentProvider,
98128
graphqlRequestExecutor
99129
);
100130
const toolUseBlock: ContentBlock.ToolUseMember = {
@@ -140,7 +170,19 @@ void describe('Conversation turn response sender', () => {
140170
});
141171

142172
void it('sends streaming response chunk back to appsync', async () => {
143-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
173+
const userAgentProvider = new UserAgentProvider(
174+
{} as unknown as ConversationTurnEvent
175+
);
176+
const userAgentProviderMock = mock.method(
177+
userAgentProvider,
178+
'getUserAgent',
179+
() => 'testUserAgent'
180+
);
181+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
182+
'',
183+
'',
184+
userAgentProvider
185+
);
144186
const executeGraphqlMock = mock.method(
145187
graphqlRequestExecutor,
146188
'executeGraphql',
@@ -150,6 +192,7 @@ void describe('Conversation turn response sender', () => {
150192
);
151193
const sender = new ConversationTurnResponseSender(
152194
event,
195+
userAgentProvider,
153196
graphqlRequestExecutor
154197
);
155198
const chunk: StreamingResponseChunk = {
@@ -162,7 +205,14 @@ void describe('Conversation turn response sender', () => {
162205
};
163206
await sender.sendResponseChunk(chunk);
164207

208+
assert.strictEqual(userAgentProviderMock.mock.calls.length, 1);
209+
assert.deepStrictEqual(userAgentProviderMock.mock.calls[0].arguments[0], {
210+
'turn-response-type': 'streaming',
211+
});
165212
assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
213+
assert.deepStrictEqual(executeGraphqlMock.mock.calls[0].arguments[1], {
214+
userAgent: 'testUserAgent',
215+
});
166216
const request = executeGraphqlMock.mock.calls[0]
167217
.arguments[0] as GraphqlRequest<MutationStreamingResponseInput>;
168218
assert.deepStrictEqual(request, {
@@ -181,7 +231,15 @@ void describe('Conversation turn response sender', () => {
181231
});
182232

183233
void it('serializes tool use input to JSON when streaming', async () => {
184-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
234+
const userAgentProvider = new UserAgentProvider(
235+
{} as unknown as ConversationTurnEvent
236+
);
237+
mock.method(userAgentProvider, 'getUserAgent', () => '');
238+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
239+
'',
240+
'',
241+
userAgentProvider
242+
);
185243
const executeGraphqlMock = mock.method(
186244
graphqlRequestExecutor,
187245
'executeGraphql',
@@ -191,6 +249,7 @@ void describe('Conversation turn response sender', () => {
191249
);
192250
const sender = new ConversationTurnResponseSender(
193251
event,
252+
userAgentProvider,
194253
graphqlRequestExecutor
195254
);
196255
const toolUseBlock: ContentBlock.ToolUseMember = {
@@ -242,7 +301,19 @@ void describe('Conversation turn response sender', () => {
242301
});
243302

244303
void it('sends errors response back to appsync', async () => {
245-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
304+
const userAgentProvider = new UserAgentProvider(
305+
{} as unknown as ConversationTurnEvent
306+
);
307+
const userAgentProviderMock = mock.method(
308+
userAgentProvider,
309+
'getUserAgent',
310+
() => 'testUserAgent'
311+
);
312+
const graphqlRequestExecutor = new GraphqlRequestExecutor(
313+
'',
314+
'',
315+
userAgentProvider
316+
);
246317
const executeGraphqlMock = mock.method(
247318
graphqlRequestExecutor,
248319
'executeGraphql',
@@ -252,6 +323,7 @@ void describe('Conversation turn response sender', () => {
252323
);
253324
const sender = new ConversationTurnResponseSender(
254325
event,
326+
userAgentProvider,
255327
graphqlRequestExecutor
256328
);
257329
const errors: Array<ConversationTurnError> = [
@@ -266,6 +338,14 @@ void describe('Conversation turn response sender', () => {
266338
];
267339
await sender.sendErrors(errors);
268340

341+
assert.strictEqual(userAgentProviderMock.mock.calls.length, 1);
342+
assert.deepStrictEqual(userAgentProviderMock.mock.calls[0].arguments[0], {
343+
'turn-response-type': 'error',
344+
});
345+
assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
346+
assert.deepStrictEqual(executeGraphqlMock.mock.calls[0].arguments[1], {
347+
userAgent: 'testUserAgent',
348+
});
269349
assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
270350
const request = executeGraphqlMock.mock.calls[0]
271351
.arguments[0] as GraphqlRequest<MutationResponseInput>;

packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
} from './types.js';
66
import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime';
77
import { GraphqlRequestExecutor } from './graphql_request_executor';
8+
import { UserAgentProvider } from './user_agent_provider';
89

910
export type MutationResponseInput = {
1011
input: {
@@ -36,10 +37,11 @@ export class ConversationTurnResponseSender {
3637
*/
3738
constructor(
3839
private readonly event: ConversationTurnEvent,
40+
private readonly userAgentProvider = new UserAgentProvider(event),
3941
private readonly graphqlRequestExecutor = new GraphqlRequestExecutor(
4042
event.graphqlApiEndpoint,
4143
event.request.headers.authorization,
42-
event.request.headers['x-amz-user-agent']
44+
userAgentProvider
4345
),
4446
private readonly logger = console
4547
) {}
@@ -50,7 +52,11 @@ export class ConversationTurnResponseSender {
5052
await this.graphqlRequestExecutor.executeGraphql<
5153
MutationResponseInput,
5254
void
53-
>(responseMutationRequest);
55+
>(responseMutationRequest, {
56+
userAgent: this.userAgentProvider.getUserAgent({
57+
'turn-response-type': 'single',
58+
}),
59+
});
5460
};
5561

5662
sendResponseChunk = async (chunk: StreamingResponseChunk) => {
@@ -59,7 +65,11 @@ export class ConversationTurnResponseSender {
5965
await this.graphqlRequestExecutor.executeGraphql<
6066
MutationStreamingResponseInput,
6167
void
62-
>(responseMutationRequest);
68+
>(responseMutationRequest, {
69+
userAgent: this.userAgentProvider.getUserAgent({
70+
'turn-response-type': 'streaming',
71+
}),
72+
});
6373
};
6474

6575
sendErrors = async (errors: ConversationTurnError[]) => {
@@ -71,7 +81,11 @@ export class ConversationTurnResponseSender {
7181
await this.graphqlRequestExecutor.executeGraphql<
7282
MutationErrorsResponseInput,
7383
void
74-
>(responseMutationRequest);
84+
>(responseMutationRequest, {
85+
userAgent: this.userAgentProvider.getUserAgent({
86+
'turn-response-type': 'error',
87+
}),
88+
});
7589
};
7690

7791
private createMutationErrorsRequest = (errors: ConversationTurnError[]) => {

packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { ConversationTurnEvent, ExecutableTool } from '../types';
22
import { GraphQlTool } from './graphql_tool';
33
import { GraphQlQueryFactory } from './graphql_query_factory';
4+
import { UserAgentProvider } from '../user_agent_provider';
45

56
/**
67
* Creates executable tools from definitions in conversation turn event.
@@ -29,7 +30,7 @@ export class ConversationTurnEventToolsProvider {
2930
graphqlApiEndpoint,
3031
query,
3132
this.event.request.headers.authorization,
32-
this.event.request.headers['x-amz-user-agent']
33+
new UserAgentProvider(this.event)
3334
);
3435
});
3536
return tools ?? [];

0 commit comments

Comments
 (0)