|
2 | 2 | from decimal import Decimal
|
3 | 3 | from sqlalchemy import func, text
|
4 | 4 | from sqlalchemy.sql import sqltypes
|
5 |
| -from sqlalchemy.types import UserDefinedType, Float |
| 5 | +from sqlalchemy.types import UserDefinedType |
6 | 6 | from uuid import UUID as _python_UUID
|
7 | 7 | from intersystems_iris import IRISList
|
| 8 | +from sqlalchemy import __version__ as sqlalchemy_version |
8 | 9 |
|
9 | 10 | HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal()
|
10 | 11 |
|
@@ -134,73 +135,79 @@ def process(value):
|
134 | 135 | return process
|
135 | 136 |
|
136 | 137 |
|
137 |
| -class IRISUniqueIdentifier(sqltypes.Uuid): |
138 |
| - def literal_processor(self, dialect): |
139 |
| - if not self.as_uuid: |
| 138 | +if sqlalchemy_version.startswith("2."): |
140 | 139 |
|
141 |
| - def process(value): |
142 |
| - return f"""'{value.replace("'", "''")}'""" |
143 |
| - |
144 |
| - return process |
145 |
| - else: |
146 |
| - |
147 |
| - def process(value): |
148 |
| - return f"""'{str(value).replace("'", "''")}'""" |
149 |
| - |
150 |
| - return process |
151 |
| - |
152 |
| - def bind_processor(self, dialect): |
153 |
| - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid |
154 |
| - |
155 |
| - if character_based_uuid: |
156 |
| - if self.as_uuid: |
| 140 | + class IRISUniqueIdentifier(sqltypes.Uuid): |
| 141 | + def literal_processor(self, dialect): |
| 142 | + if not self.as_uuid: |
157 | 143 |
|
158 | 144 | def process(value):
|
159 |
| - if value is not None: |
160 |
| - value = str(value) |
161 |
| - return value |
| 145 | + return f"""'{value.replace("'", "''")}'""" |
162 | 146 |
|
163 | 147 | return process
|
164 | 148 | else:
|
165 | 149 |
|
166 | 150 | def process(value):
|
167 |
| - return value |
| 151 | + return f"""'{str(value).replace("'", "''")}'""" |
168 | 152 |
|
169 | 153 | return process
|
170 |
| - else: |
171 |
| - return None |
172 | 154 |
|
173 |
| - def result_processor(self, dialect, coltype): |
174 |
| - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid |
| 155 | + def bind_processor(self, dialect): |
| 156 | + character_based_uuid = ( |
| 157 | + not dialect.supports_native_uuid or not self.native_uuid |
| 158 | + ) |
175 | 159 |
|
176 |
| - if character_based_uuid: |
177 |
| - if self.as_uuid: |
| 160 | + if character_based_uuid: |
| 161 | + if self.as_uuid: |
178 | 162 |
|
179 |
| - def process(value): |
180 |
| - if value and not isinstance(value, _python_UUID): |
181 |
| - value = _python_UUID(value) |
182 |
| - return value |
| 163 | + def process(value): |
| 164 | + if value is not None: |
| 165 | + value = str(value) |
| 166 | + return value |
183 | 167 |
|
184 |
| - return process |
| 168 | + return process |
| 169 | + else: |
| 170 | + |
| 171 | + def process(value): |
| 172 | + return value |
| 173 | + |
| 174 | + return process |
185 | 175 | else:
|
| 176 | + return None |
186 | 177 |
|
187 |
| - def process(value): |
188 |
| - if value and isinstance(value, _python_UUID): |
189 |
| - value = str(value) |
190 |
| - return value |
| 178 | + def result_processor(self, dialect, coltype): |
| 179 | + character_based_uuid = ( |
| 180 | + not dialect.supports_native_uuid or not self.native_uuid |
| 181 | + ) |
191 | 182 |
|
192 |
| - return process |
193 |
| - else: |
194 |
| - if not self.as_uuid: |
| 183 | + if character_based_uuid: |
| 184 | + if self.as_uuid: |
195 | 185 |
|
196 |
| - def process(value): |
197 |
| - if value and isinstance(value, _python_UUID): |
198 |
| - value = str(value) |
199 |
| - return value |
| 186 | + def process(value): |
| 187 | + if value and not isinstance(value, _python_UUID): |
| 188 | + value = _python_UUID(value) |
| 189 | + return value |
200 | 190 |
|
201 |
| - return process |
| 191 | + return process |
| 192 | + else: |
| 193 | + |
| 194 | + def process(value): |
| 195 | + if value and isinstance(value, _python_UUID): |
| 196 | + value = str(value) |
| 197 | + return value |
| 198 | + |
| 199 | + return process |
202 | 200 | else:
|
203 |
| - return None |
| 201 | + if not self.as_uuid: |
| 202 | + |
| 203 | + def process(value): |
| 204 | + if value and isinstance(value, _python_UUID): |
| 205 | + value = str(value) |
| 206 | + return value |
| 207 | + |
| 208 | + return process |
| 209 | + else: |
| 210 | + return None |
204 | 211 |
|
205 | 212 |
|
206 | 213 | class IRISListBuild(UserDefinedType):
|
@@ -267,9 +274,7 @@ def __init__(self, max_items: int = None, item_type: type = float):
|
267 | 274 | item_type_server = (
|
268 | 275 | "decimal"
|
269 | 276 | if self.item_type is float
|
270 |
| - else "float" |
271 |
| - if self.item_type is Decimal |
272 |
| - else "int" |
| 277 | + else "float" if self.item_type is Decimal else "int" |
273 | 278 | )
|
274 | 279 | self.item_type_server = item_type_server
|
275 | 280 |
|
@@ -304,19 +309,21 @@ class comparator_factory(UserDefinedType.Comparator):
|
304 | 309 | # return self.func('vector_l2', other)
|
305 | 310 |
|
306 | 311 | def max_inner_product(self, other):
|
307 |
| - return self.func('vector_dot_product', other) |
| 312 | + return self.func("vector_dot_product", other) |
308 | 313 |
|
309 | 314 | def cosine_distance(self, other):
|
310 |
| - return self.func('vector_cosine', other) |
| 315 | + return self.func("vector_cosine", other) |
311 | 316 |
|
312 | 317 | def cosine(self, other):
|
313 |
| - return (1 - self.func('vector_cosine', other)) |
| 318 | + return 1 - self.func("vector_cosine", other) |
314 | 319 |
|
315 | 320 | def func(self, funcname: str, other):
|
316 | 321 | if not isinstance(other, list) and not isinstance(other, tuple):
|
317 | 322 | raise ValueError("expected list or tuple, got '%s'" % type(other))
|
318 | 323 | othervalue = f"[{','.join([str(v) for v in other])}]"
|
319 |
| - return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server))) |
| 324 | + return getattr(func, funcname)( |
| 325 | + self, func.to_vector(othervalue, text(self.type.item_type_server)) |
| 326 | + ) |
320 | 327 |
|
321 | 328 |
|
322 | 329 | class BIT(sqltypes.TypeEngine):
|
|
0 commit comments