1
1
from collections .abc import Mapping
2
2
from datetime import datetime
3
- from typing import NamedTuple , Optional
3
+ from typing import Any , NamedTuple , Optional
4
4
5
- from sqlalchemy import insert , select
5
+ from sqlalchemy import Row , Select , and_ , select , tuple_
6
6
from sqlalchemy .ext .asyncio import AsyncSession
7
7
from typing_extensions import TypeAlias
8
8
9
9
from phoenix .db import models
10
- from phoenix .db .helpers import num_docs_col
10
+ from phoenix .db .helpers import dedup , num_docs_col
11
11
from phoenix .db .insertion .helpers import as_kv
12
12
from phoenix .db .insertion .types import (
13
13
Insertables ,
24
24
_DocumentPosition : TypeAlias = int
25
25
_AnnoRowId : TypeAlias = int
26
26
_NumDocs : TypeAlias = int
27
+ _Identifier : TypeAlias = str
27
28
28
- _Key : TypeAlias = tuple [_Name , _SpanId , _DocumentPosition ]
29
- _UniqueBy : TypeAlias = tuple [_Name , _SpanRowId , _DocumentPosition ]
29
+
30
+ class _Key (NamedTuple ):
31
+ annotation_name : _Name
32
+ annotation_identifier : _Identifier
33
+ span_id : _SpanId
34
+ document_position : _DocumentPosition
35
+
36
+
37
+ _UniqueBy : TypeAlias = tuple [_Name , _SpanRowId , _DocumentPosition , _Identifier ]
30
38
_Existing : TypeAlias = tuple [
31
39
_SpanRowId ,
32
40
_SpanId ,
33
41
_NumDocs ,
34
42
Optional [_AnnoRowId ],
35
43
Optional [_Name ],
36
44
Optional [_DocumentPosition ],
45
+ Optional [_Identifier ],
37
46
Optional [datetime ],
38
47
]
39
48
@@ -46,15 +55,16 @@ class DocumentAnnotationQueueInserter(
46
55
DocumentAnnotationDmlEvent ,
47
56
],
48
57
table = models .DocumentAnnotation ,
49
- unique_by = (),
58
+ unique_by = ("name" , "span_rowid" , "document_position" , "identifier" ),
59
+ constraint_name = "uq_document_annotations_name_span_rowid_document_pos_identifier" ,
50
60
):
51
61
async def _events (
52
62
self ,
53
63
session : AsyncSession ,
54
64
* insertions : Insertables .DocumentAnnotation ,
55
65
) -> list [DocumentAnnotationDmlEvent ]:
56
66
records = [dict (as_kv (ins .row )) for ins in insertions ]
57
- stmt = insert ( self .table ). values ( records ).returning (self .table .id )
67
+ stmt = self ._insert_on_conflict ( * records ).returning (self .table .id )
58
68
ids = tuple ([_ async for _ in await session .stream_scalars (stmt )])
59
69
return [DocumentAnnotationDmlEvent (ids )]
60
70
@@ -71,19 +81,37 @@ async def _partition(
71
81
to_postpone : list [Postponed [Precursors .DocumentAnnotation ]] = []
72
82
to_discard : list [Received [Precursors .DocumentAnnotation ]] = []
73
83
74
- span_ids = {p .item .span_id for p in parcels }
75
- stmt = select (models .Span .id , models .Span .span_id , num_docs_col (self ._db .dialect )).where (
76
- models .Span .span_id .in_ (span_ids )
77
- )
78
- result = await session .execute (stmt )
79
- spans = result .all ()
84
+ stmt = self ._select_existing (* map (_key , parcels ))
85
+ existing : list [Row [_Existing ]] = [_ async for _ in await session .stream (stmt )]
80
86
existing_spans : Mapping [str , _SpanAttr ] = {
81
- row .span_id : _SpanAttr (row .id , row .num_docs ) for row in spans
87
+ e .span_id : _SpanAttr (e .span_rowid , e .num_docs ) for e in existing
88
+ }
89
+ existing_annos : Mapping [_Key , _AnnoAttr ] = {
90
+ _Key (
91
+ annotation_name = e .name ,
92
+ annotation_identifier = e .identifier ,
93
+ span_id = e .span_id ,
94
+ document_position = e .document_position ,
95
+ ): _AnnoAttr (e .span_rowid , e .id , e .updated_at )
96
+ for e in existing
97
+ if e .id is not None and e .name is not None and e .updated_at is not None
82
98
}
83
99
84
100
for p in parcels :
85
- if p .item .span_id in existing_spans :
86
- span = existing_spans [p .item .span_id ]
101
+ if (anno := existing_annos .get (_key (p ))) is not None :
102
+ if p .received_at <= anno .updated_at :
103
+ to_discard .append (p )
104
+ else :
105
+ to_insert .append (
106
+ Received (
107
+ received_at = p .received_at ,
108
+ item = p .item .as_insertable (
109
+ span_rowid = anno .span_rowid ,
110
+ id_ = anno .id_ ,
111
+ ),
112
+ )
113
+ )
114
+ elif (span := existing_spans .get (p .item .span_id )) is not None :
87
115
if 0 <= p .item .document_position < span .num_docs :
88
116
to_insert .append (
89
117
Received (
@@ -106,9 +134,56 @@ async def _partition(
106
134
to_discard .append (p )
107
135
108
136
assert len (to_insert ) + len (to_postpone ) + len (to_discard ) == len (parcels )
137
+ to_insert = dedup (sorted (to_insert , key = _time , reverse = True ), _unique_by )[::- 1 ]
109
138
return to_insert , to_postpone , to_discard
110
139
140
+ def _select_existing (self , * keys : _Key ) -> Select [_Existing ]:
141
+ anno = self .table
142
+ span = (
143
+ select (models .Span .id , models .Span .span_id , num_docs_col (self ._db .dialect ))
144
+ .where (models .Span .span_id .in_ ({k .span_id for k in keys }))
145
+ .cte ()
146
+ )
147
+ onclause = and_ (
148
+ span .c .id == anno .span_rowid ,
149
+ anno .name .in_ ({k .annotation_name for k in keys }),
150
+ tuple_ (anno .name , anno .identifier , span .c .span_id , anno .document_position ).in_ (keys ),
151
+ )
152
+ return select (
153
+ span .c .id .label ("span_rowid" ),
154
+ span .c .span_id ,
155
+ span .c .num_docs ,
156
+ anno .id ,
157
+ anno .name ,
158
+ anno .document_position ,
159
+ anno .identifier ,
160
+ anno .updated_at ,
161
+ ).outerjoin_from (span , anno , onclause )
162
+
111
163
112
164
class _SpanAttr (NamedTuple ):
113
165
span_rowid : _SpanRowId
114
166
num_docs : _NumDocs
167
+
168
+
169
+ class _AnnoAttr (NamedTuple ):
170
+ span_rowid : _SpanRowId
171
+ id_ : _AnnoRowId
172
+ updated_at : datetime
173
+
174
+
175
+ def _key (p : Received [Precursors .DocumentAnnotation ]) -> _Key :
176
+ return _Key (
177
+ annotation_name = p .item .obj .name ,
178
+ annotation_identifier = p .item .obj .identifier ,
179
+ span_id = p .item .span_id ,
180
+ document_position = p .item .document_position ,
181
+ )
182
+
183
+
184
+ def _unique_by (p : Received [Insertables .DocumentAnnotation ]) -> _UniqueBy :
185
+ return p .item .obj .name , p .item .span_rowid , p .item .document_position , p .item .identifier
186
+
187
+
188
+ def _time (p : Received [Any ]) -> datetime :
189
+ return p .received_at
0 commit comments