2
2
3
3
import functools
4
4
import typing
5
- from collections .abc import AsyncIterator , Iterable , Mapping
5
+ from collections .abc import AsyncIterator , Iterable , Iterator , Mapping
6
6
from contextlib import asynccontextmanager
7
7
from dataclasses import dataclass , field
8
8
from datetime import datetime
9
+ from itertools import count
9
10
from typing import TYPE_CHECKING , Any , Generic , Literal , Union , cast , overload
10
11
11
12
import anyio
@@ -369,13 +370,14 @@ async def _map_messages(
369
370
"""Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
370
371
system_prompt : list [SystemContentBlockTypeDef ] = []
371
372
bedrock_messages : list [MessageUnionTypeDef ] = []
373
+ document_count : Iterator [int ] = count (1 )
372
374
for m in messages :
373
375
if isinstance (m , ModelRequest ):
374
376
for part in m .parts :
375
377
if isinstance (part , SystemPromptPart ):
376
378
system_prompt .append ({'text' : part .content })
377
379
elif isinstance (part , UserPromptPart ):
378
- bedrock_messages .extend (await self ._map_user_prompt (part ))
380
+ bedrock_messages .extend (await self ._map_user_prompt (part , document_count ))
379
381
elif isinstance (part , ToolReturnPart ):
380
382
assert part .tool_call_id is not None
381
383
bedrock_messages .append (
@@ -430,20 +432,18 @@ async def _map_messages(
430
432
return system_prompt , bedrock_messages
431
433
432
434
@staticmethod
433
- async def _map_user_prompt (part : UserPromptPart ) -> list [MessageUnionTypeDef ]:
435
+ async def _map_user_prompt (part : UserPromptPart , document_count : Iterator [ int ] ) -> list [MessageUnionTypeDef ]:
434
436
content : list [ContentBlockUnionTypeDef ] = []
435
437
if isinstance (part .content , str ):
436
438
content .append ({'text' : part .content })
437
439
else :
438
- document_count = 0
439
440
for item in part .content :
440
441
if isinstance (item , str ):
441
442
content .append ({'text' : item })
442
443
elif isinstance (item , BinaryContent ):
443
444
format = item .format
444
445
if item .is_document :
445
- document_count += 1
446
- name = f'Document { document_count } '
446
+ name = f'Document { next (document_count )} '
447
447
assert format in ('pdf' , 'txt' , 'csv' , 'doc' , 'docx' , 'xls' , 'xlsx' , 'html' , 'md' )
448
448
content .append ({'document' : {'name' : name , 'format' : format , 'source' : {'bytes' : item .data }}})
449
449
elif item .is_image :
@@ -464,8 +464,7 @@ async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:
464
464
content .append ({'image' : image })
465
465
466
466
elif item .kind == 'document-url' :
467
- document_count += 1
468
- name = f'Document { document_count } '
467
+ name = f'Document { next (document_count )} '
469
468
data = response .content
470
469
content .append ({'document' : {'name' : name , 'format' : item .format , 'source' : {'bytes' : data }}})
471
470
0 commit comments