Skip to content

Commit 3a29d43

Browse files
authored
Pass user agent in conversation handler lambda (#2086)
* Pass user agent in conversation handler lambda * Pass user agent in conversation handler lambda
1 parent 6e4a62f commit 3a29d43

14 files changed

+82
-14
lines changed

.changeset/soft-gifts-exist.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': minor
3+
---
4+
5+
Pass user agent in conversation handler lambda

packages/ai-constructs/API.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ type ConversationTurnEvent = {
9292
};
9393
};
9494
request: {
95-
headers: {
96-
authorization: string;
97-
};
95+
headers: Record<string, string>;
9896
};
9997
messages?: Array<ConversationMessage>;
10098
messageHistoryQuery: {

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,4 +840,43 @@ void describe('Bedrock converse adapter', () => {
840840
},
841841
]);
842842
});
843+
844+
void it('adds user agent middleware', async () => {
845+
const event: ConversationTurnEvent = {
846+
...commonEvent,
847+
};
848+
849+
event.request.headers['x-amz-user-agent'] = 'testUserAgent';
850+
851+
const bedrockClient = new BedrockRuntimeClient();
852+
const addMiddlewareMock = mock.method(bedrockClient.middlewareStack, 'add');
853+
854+
new BedrockConverseAdapter(
855+
event,
856+
[],
857+
bedrockClient,
858+
undefined,
859+
messageHistoryRetriever
860+
);
861+
862+
assert.strictEqual(addMiddlewareMock.mock.calls.length, 1);
863+
const middlewareHandler = addMiddlewareMock.mock.calls[0].arguments[0];
864+
const options = addMiddlewareMock.mock.calls[0].arguments[1];
865+
assert.strictEqual(options.name, 'amplify-user-agent-injector');
866+
const args: {
867+
request: {
868+
headers: Record<string, string>;
869+
};
870+
} = {
871+
request: {
872+
headers: {},
873+
},
874+
};
875+
// @ts-expect-error We mock subset of middleware inputs here.
876+
await middlewareHandler(mock.fn(), {})(args);
877+
assert.strictEqual(
878+
args.request.headers['x-amz-user-agent'],
879+
'testUserAgent'
880+
);
881+
});
843882
});

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@ export class BedrockConverseAdapter {
4444
),
4545
private readonly logger = console
4646
) {
47+
if (event.request.headers['x-amz-user-agent']) {
48+
this.bedrockClient.middlewareStack.add(
49+
(next) => (args) => {
50+
// @ts-expect-error Request is typed as unknown.
51+
// But this is recommended way to alter headers per https://github.com/aws/aws-sdk-js-v3/blob/main/README.md.
52+
args.request.headers['x-amz-user-agent'] =
53+
event.request.headers['x-amz-user-agent'];
54+
return next(args);
55+
},
56+
{
57+
step: 'build',
58+
name: 'amplify-user-agent-injector',
59+
}
60+
);
61+
}
4762
this.executableTools = [
4863
...eventToolsProvider.getEventTools(),
4964
...additionalTools,

packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ void describe('Conversation message history retriever', () => {
346346

347347
for (const testCase of testCases) {
348348
void it(testCase.name, async () => {
349-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '');
349+
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
350350
const executeGraphqlMock = mock.method(
351351
graphqlRequestExecutor,
352352
'executeGraphql',

packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ export class ConversationMessageHistoryRetriever {
102102
private readonly event: ConversationTurnEvent,
103103
private readonly graphqlRequestExecutor = new GraphqlRequestExecutor(
104104
event.graphqlApiEndpoint,
105-
event.request.headers.authorization
105+
event.request.headers.authorization,
106+
event.request.headers['x-amz-user-agent']
106107
)
107108
) {}
108109

packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void describe('Conversation turn response sender', () => {
3333
};
3434

3535
void it('sends response back to appsync', async () => {
36-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '');
36+
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
3737
const executeGraphqlMock = mock.method(
3838
graphqlRequestExecutor,
3939
'executeGraphql',
@@ -81,7 +81,7 @@ void describe('Conversation turn response sender', () => {
8181
});
8282

8383
void it('serializes tool use input to JSON', async () => {
84-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '');
84+
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
8585
const executeGraphqlMock = mock.method(
8686
graphqlRequestExecutor,
8787
'executeGraphql',

packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ export class ConversationTurnResponseSender {
2222
private readonly event: ConversationTurnEvent,
2323
private readonly graphqlRequestExecutor = new GraphqlRequestExecutor(
2424
event.graphqlApiEndpoint,
25-
event.request.headers.authorization
25+
event.request.headers.authorization,
26+
event.request.headers['x-amz-user-agent']
2627
),
2728
private readonly logger = console
2829
) {}

packages/ai-constructs/src/conversation/runtime/event-tools-provider/event_tools_provider.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ export class ConversationTurnEventToolsProvider {
2828
inputSchema,
2929
graphqlApiEndpoint,
3030
query,
31-
this.event.request.headers.authorization
31+
this.event.request.headers.authorization,
32+
this.event.request.headers['x-amz-user-agent']
3233
);
3334
});
3435
return tools ?? [];

packages/ai-constructs/src/conversation/runtime/event-tools-provider/graphql_tool.test.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ void describe('GraphQl tool', () => {
2222
graphQlEndpoint,
2323
query,
2424
accessToken,
25+
'',
2526
graphqlRequestExecutor
2627
);
2728
};
@@ -30,7 +31,7 @@ void describe('GraphQl tool', () => {
3031
const testResponse = {
3132
test: 'response',
3233
};
33-
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '');
34+
const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', '');
3435
const executeGraphqlMock = mock.method(
3536
graphqlRequestExecutor,
3637
'executeGraphql',

0 commit comments

Comments
 (0)