Skip to content

Commit b69e366

Browse files
committed
fix support older engine
1 parent 1113925 commit b69e366

File tree

5 files changed

+250
-210
lines changed

5 files changed

+250
-210
lines changed

.github/workflows/python-publish.yml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,22 @@ jobs:
2424
image:
2525
- intersystemsdc/iris-community:latest
2626
- intersystemsdc/iris-community:preview
27-
- intersystemsdc/iris-community:2024.1-preview
27+
engine:
28+
- old
29+
- new
2830
runs-on: ubuntu-latest
2931
steps:
3032
- uses: actions/checkout@v3
33+
- name: Set up Python
34+
uses: actions/setup-python@v4
35+
with:
36+
python-version: '3.11'
3137
- name: Install requirements
3238
run: |
33-
pip install -r requirements-dev.txt \
34-
-r requirements-iris.txt \
35-
-e .
39+
pip install tox
3640
- name: Run Tests
3741
run: |
38-
pytest --container ${{ matrix.image }}
39-
42+
tox -e py311${{ matrix.engine }} -- --container ${{ matrix.image }}
4043
deploy:
4144
needs: test
4245
if: github.event_name != 'pull_request'

sqlalchemy_iris/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def check_constraints(cls):
9191
from .types import IRISTimeStamp
9292
from .types import IRISDate
9393
from .types import IRISDateTime
94-
from .types import IRISUniqueIdentifier
9594
from .types import IRISListBuild # noqa
9695
from .types import IRISVector # noqa
9796

@@ -819,8 +818,10 @@ def create_cursor(self):
819818
sqltypes.DateTime: IRISDateTime,
820819
sqltypes.TIMESTAMP: IRISTimeStamp,
821820
sqltypes.Time: IRISTime,
822-
sqltypes.UUID: IRISUniqueIdentifier,
823821
}
822+
if sqlalchemy_version.startswith("2."):
823+
from .types import IRISUniqueIdentifier
824+
colspecs[sqltypes.UUID] = IRISUniqueIdentifier
824825

825826

826827
class IRISExact(ReturnTypeFromArgs):

sqlalchemy_iris/types.py

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from decimal import Decimal
33
from sqlalchemy import func, text
44
from sqlalchemy.sql import sqltypes
5-
from sqlalchemy.types import UserDefinedType, Float
5+
from sqlalchemy.types import UserDefinedType
66
from uuid import UUID as _python_UUID
77
from intersystems_iris import IRISList
8+
from sqlalchemy import __version__ as sqlalchemy_version
89

910
HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal()
1011

@@ -134,73 +135,79 @@ def process(value):
134135
return process
135136

136137

137-
class IRISUniqueIdentifier(sqltypes.Uuid):
138-
def literal_processor(self, dialect):
139-
if not self.as_uuid:
138+
if sqlalchemy_version.startswith("2."):
140139

141-
def process(value):
142-
return f"""'{value.replace("'", "''")}'"""
143-
144-
return process
145-
else:
146-
147-
def process(value):
148-
return f"""'{str(value).replace("'", "''")}'"""
149-
150-
return process
151-
152-
def bind_processor(self, dialect):
153-
character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid
154-
155-
if character_based_uuid:
156-
if self.as_uuid:
140+
class IRISUniqueIdentifier(sqltypes.Uuid):
141+
def literal_processor(self, dialect):
142+
if not self.as_uuid:
157143

158144
def process(value):
159-
if value is not None:
160-
value = str(value)
161-
return value
145+
return f"""'{value.replace("'", "''")}'"""
162146

163147
return process
164148
else:
165149

166150
def process(value):
167-
return value
151+
return f"""'{str(value).replace("'", "''")}'"""
168152

169153
return process
170-
else:
171-
return None
172154

173-
def result_processor(self, dialect, coltype):
174-
character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid
155+
def bind_processor(self, dialect):
156+
character_based_uuid = (
157+
not dialect.supports_native_uuid or not self.native_uuid
158+
)
175159

176-
if character_based_uuid:
177-
if self.as_uuid:
160+
if character_based_uuid:
161+
if self.as_uuid:
178162

179-
def process(value):
180-
if value and not isinstance(value, _python_UUID):
181-
value = _python_UUID(value)
182-
return value
163+
def process(value):
164+
if value is not None:
165+
value = str(value)
166+
return value
183167

184-
return process
168+
return process
169+
else:
170+
171+
def process(value):
172+
return value
173+
174+
return process
185175
else:
176+
return None
186177

187-
def process(value):
188-
if value and isinstance(value, _python_UUID):
189-
value = str(value)
190-
return value
178+
def result_processor(self, dialect, coltype):
179+
character_based_uuid = (
180+
not dialect.supports_native_uuid or not self.native_uuid
181+
)
191182

192-
return process
193-
else:
194-
if not self.as_uuid:
183+
if character_based_uuid:
184+
if self.as_uuid:
195185

196-
def process(value):
197-
if value and isinstance(value, _python_UUID):
198-
value = str(value)
199-
return value
186+
def process(value):
187+
if value and not isinstance(value, _python_UUID):
188+
value = _python_UUID(value)
189+
return value
200190

201-
return process
191+
return process
192+
else:
193+
194+
def process(value):
195+
if value and isinstance(value, _python_UUID):
196+
value = str(value)
197+
return value
198+
199+
return process
202200
else:
203-
return None
201+
if not self.as_uuid:
202+
203+
def process(value):
204+
if value and isinstance(value, _python_UUID):
205+
value = str(value)
206+
return value
207+
208+
return process
209+
else:
210+
return None
204211

205212

206213
class IRISListBuild(UserDefinedType):
@@ -267,9 +274,7 @@ def __init__(self, max_items: int = None, item_type: type = float):
267274
item_type_server = (
268275
"decimal"
269276
if self.item_type is float
270-
else "float"
271-
if self.item_type is Decimal
272-
else "int"
277+
else "float" if self.item_type is Decimal else "int"
273278
)
274279
self.item_type_server = item_type_server
275280

@@ -304,19 +309,21 @@ class comparator_factory(UserDefinedType.Comparator):
304309
# return self.func('vector_l2', other)
305310

306311
def max_inner_product(self, other):
307-
return self.func('vector_dot_product', other)
312+
return self.func("vector_dot_product", other)
308313

309314
def cosine_distance(self, other):
310-
return self.func('vector_cosine', other)
315+
return self.func("vector_cosine", other)
311316

312317
def cosine(self, other):
313-
return (1 - self.func('vector_cosine', other))
318+
return 1 - self.func("vector_cosine", other)
314319

315320
def func(self, funcname: str, other):
316321
if not isinstance(other, list) and not isinstance(other, tuple):
317322
raise ValueError("expected list or tuple, got '%s'" % type(other))
318323
othervalue = f"[{','.join([str(v) for v in other])}]"
319-
return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server)))
324+
return getattr(func, funcname)(
325+
self, func.to_vector(othervalue, text(self.type.item_type_server))
326+
)
320327

321328

322329
class BIT(sqltypes.TypeEngine):

0 commit comments

Comments
 (0)