Skip to content

Commit 6732a07

Browse files
committed
IRIS ListBuild as a Column type
1 parent f471175 commit 6732a07

File tree

4 files changed

+90
-0
lines changed

4 files changed

+90
-0
lines changed

sqlalchemy_iris/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .base import TINYINT
2525
from .base import VARBINARY
2626
from .base import VARCHAR
27+
from .base import IRISListBuild
2728

2829
base.dialect = dialect = iris.dialect
2930

@@ -45,5 +46,6 @@
4546
"TINYINT",
4647
"VARBINARY",
4748
"VARCHAR",
49+
"IRISListBuild",
4850
"dialect",
4951
]

sqlalchemy_iris/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def check_constraints(cls):
9191
from .types import IRISDate
9292
from .types import IRISDateTime
9393
from .types import IRISUniqueIdentifier
94+
from .types import IRISListBuild
9495

9596

9697
ischema_names = {

sqlalchemy_iris/types.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import datetime
22
from decimal import Decimal
33
from sqlalchemy.sql import sqltypes
4+
from sqlalchemy.types import UserDefinedType
45
from uuid import UUID as _python_UUID
6+
from intersystems_iris import IRISList
57

68
HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal()
79

@@ -194,6 +196,47 @@ def process(value):
194196
return None
195197

196198

199+
class IRISListBuild(UserDefinedType):
200+
cache_ok = True
201+
202+
def __init__(self, max_items: int = None, item_type: type = float):
203+
super(UserDefinedType, self).__init__()
204+
self.max_items = max_items
205+
max_length = None
206+
if type is float or type is int:
207+
max_length = max_items * 10
208+
elif max_items:
209+
max_length = 65535
210+
self.max_length = max_length
211+
212+
def get_col_spec(self, **kw):
213+
if self.max_length is None:
214+
return "VARBINARY(65535)"
215+
return "VARBINARY(%d)" % self.max_length
216+
217+
def bind_processor(self, dialect):
218+
def process(value):
219+
irislist = IRISList()
220+
if not value:
221+
return value
222+
if not isinstance(value, list) and not isinstance(value, tuple):
223+
raise ValueError("expected list or tuple, got '%s'" % type(value))
224+
for item in value:
225+
irislist.add(item)
226+
return irislist.getBuffer()
227+
228+
return process
229+
230+
def result_processor(self, dialect, coltype):
231+
def process(value):
232+
if value:
233+
irislist = IRISList(value)
234+
return irislist._list_data
235+
return value
236+
237+
return process
238+
239+
197240
class BIT(sqltypes.TypeEngine):
198241
__visit_name__ = "BIT"
199242

@@ -212,3 +255,7 @@ class LONGVARCHAR(sqltypes.VARCHAR):
212255

213256
class LONGVARBINARY(sqltypes.VARBINARY):
214257
__visit_name__ = "LONGVARBINARY"
258+
259+
260+
class LISTBUILD(sqltypes.VARBINARY):
261+
__visit_name__ = "VARCHAR"

tests/test_suite.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sqlalchemy.types import VARBINARY
1919
from sqlalchemy.types import BINARY
2020
from sqlalchemy_iris import TINYINT
21+
from sqlalchemy_iris import IRISListBuild
2122
from sqlalchemy.exc import DatabaseError
2223
import pytest
2324

@@ -281,3 +282,42 @@ class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
281282
)
282283
def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname):
283284
super().test_fk_ref(connection, metadata, use_composite, tablename, columnname)
285+
286+
287+
class IRISListBuildTest(fixtures.TablesTest):
288+
__backend__ = True
289+
290+
@classmethod
291+
def define_tables(cls, metadata):
292+
Table(
293+
"data",
294+
metadata,
295+
Column("val", IRISListBuild(10, float)),
296+
)
297+
298+
@classmethod
299+
def fixtures(cls):
300+
return dict(
301+
data=(
302+
("val",),
303+
([1.0] * 50,),
304+
([1.23] * 50,),
305+
([i for i in range(0, 50)],),
306+
(None,),
307+
)
308+
)
309+
310+
def _assert_result(self, select, result):
311+
with config.db.connect() as conn:
312+
eq_(conn.execute(select).fetchall(), result)
313+
314+
def test_listbuild(self):
315+
self._assert_result(
316+
select(self.tables.data),
317+
[
318+
([1.0] * 50,),
319+
([1.23] * 50,),
320+
([i for i in range(0, 50)],),
321+
(None,),
322+
],
323+
)

0 commit comments

Comments
 (0)