Skip to content

Commit 55b7c01

Browse files
committed
Native Vector support
1 parent 122c1b6 commit 55b7c01

File tree

5 files changed

+199
-4
lines changed

5 files changed

+199
-4
lines changed

sqlalchemy_iris/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .base import VARBINARY
2626
from .base import VARCHAR
2727
from .base import IRISListBuild
28+
from .base import IRISVector
2829

2930
base.dialect = dialect = iris.dialect
3031

@@ -47,5 +48,6 @@
4748
"VARBINARY",
4849
"VARCHAR",
4950
"IRISListBuild",
51+
"IRISVector",
5052
"dialect",
5153
]

sqlalchemy_iris/base.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
import intersystems_iris.dbapi._DBAPI as dbapi
3+
import intersystems_iris._IRISNative as IRISNative
34
from . import information_schema as ischema
45
from sqlalchemy import exc
56
from sqlalchemy.orm import aliased
@@ -91,7 +92,8 @@ def check_constraints(cls):
9192
from .types import IRISDate
9293
from .types import IRISDateTime
9394
from .types import IRISUniqueIdentifier
94-
from .types import IRISListBuild
95+
from .types import IRISListBuild # noqa
96+
from .types import IRISVector # noqa
9597

9698

9799
ischema_names = {
@@ -398,7 +400,9 @@ def check_constraints(cls):
398400
class IRISCompiler(sql.compiler.SQLCompiler):
399401
"""IRIS specific idiosyncrasies"""
400402

401-
def visit_exists_unary_operator(self, element, operator, within_columns_clause=False, **kw):
403+
def visit_exists_unary_operator(
404+
self, element, operator, within_columns_clause=False, **kw
405+
):
402406
if within_columns_clause:
403407
return "(SELECT 1 WHERE EXISTS(%s))" % self.process(element.element, **kw)
404408
else:
@@ -853,6 +857,8 @@ class IRISDialect(default.DefaultDialect):
853857
supports_empty_insert = False
854858
supports_is_distinct_from = False
855859

860+
supports_vectors = None
861+
856862
colspecs = colspecs
857863

858864
ischema_names = ischema_names
@@ -870,6 +876,11 @@ class IRISDialect(default.DefaultDialect):
870876
def __init__(self, **kwargs):
871877
default.DefaultDialect.__init__(self, **kwargs)
872878

879+
def _get_server_version_info(self, connection):
880+
server_version = connection.connection._connection_info._server_version
881+
server_version = server_version[server_version.find("Version") + 8:].split(" ")[0].split(".")
882+
return tuple([int(''.join(filter(str.isdigit, v))) for v in server_version])
883+
873884
_isolation_lookup = set(
874885
[
875886
"READ UNCOMMITTED",
@@ -888,6 +899,14 @@ def on_connect(conn):
888899
if super_ is not None:
889900
super_(conn)
890901

902+
iris = IRISNative.createIRIS(conn)
903+
self.supports_vectors = iris.classMethodBoolean("%SYSTEM.License", "GetFeature", 28)
904+
if self.supports_vectors:
905+
with conn.cursor() as cursor:
906+
# Distance or similarity
907+
cursor.execute("select vector_cosine(to_vector('1'), to_vector('1'))")
908+
self.vector_cosine_similarity = cursor.fetchone()[0] == 0
909+
891910
self._dictionary_access = False
892911
with conn.cursor() as cursor:
893912
cursor.execute("%CHECKPRIV SELECT ON %Dictionary.PropertyDefinition")

sqlalchemy_iris/requirements.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from sqlalchemy.testing.requirements import SuiteRequirements
2+
from sqlalchemy.testing.exclusions import against
3+
from sqlalchemy.testing.exclusions import only_on
24

35
try:
46
from alembic.testing.requirements import SuiteRequirements as AlembicRequirements
@@ -257,3 +259,13 @@ def fk_onupdate_restrict(self):
257259
@property
258260
def fk_ondelete_restrict(self):
259261
return exclusions.closed()
262+
263+
def _iris_vector(self, config):
    """Tell whether the configured database can run the vector tests.

    Returns ``False`` unless the target matches ``iris >= 2024.1``; in that
    case the dialect's ``supports_vectors`` value is returned as-is.
    """
    if against(config, "iris >= 2024.1"):
        return config.db.dialect.supports_vectors
    return False
268+
269+
@property
def iris_vector(self):
    """Requirement: run only where ``_iris_vector`` accepts the config."""

    def vectors_available(config):
        return self._iris_vector(config)

    return only_on(vectors_available)

sqlalchemy_iris/types.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import datetime
22
from decimal import Decimal
3-
from sqlalchemy import func
3+
from sqlalchemy import func, text
44
from sqlalchemy.sql import sqltypes
5-
from sqlalchemy.types import UserDefinedType
5+
from sqlalchemy.types import UserDefinedType, Float
66
from uuid import UUID as _python_UUID
77
from intersystems_iris import IRISList
88

@@ -247,6 +247,72 @@ def func(self, funcname: str, other):
247247
return getattr(func, funcname)(self, irislist.getBuffer())
248248

249249

250+
class IRISVector(UserDefinedType):
    """SQLAlchemy column type for the IRIS ``VECTOR`` SQL type.

    Python-side values are lists or tuples. They are bound as a bracketed,
    comma-separated string (``[1,2,3]``) and result strings are split on
    commas and converted element by element with ``item_type``.
    """

    cache_ok = True

    def __init__(self, max_items: int = None, item_type: type = float):
        """Configure the vector column.

        :param max_items: declared number of elements; ``None`` leaves the
            length slot of the generated column spec empty.
        :param item_type: Python element type; ``int``, ``float`` or
            ``Decimal``.
        :raises TypeError: if ``item_type`` is not one of the supported types.
        """
        super(UserDefinedType, self).__init__()
        if item_type not in [float, int, Decimal]:
            # Name the rejected argument itself. The builtin ``type`` used to
            # be interpolated here, so the message always read "got type".
            got = getattr(item_type, "__name__", repr(item_type))
            raise TypeError(
                f"IRISVector expected int, float or Decimal; got {got}; expected: int, float, Decimal"
            )
        self.max_items = max_items
        self.item_type = item_type
        # NOTE(review): float maps to "decimal" and Decimal maps to "float".
        # The pairing looks swapped; confirm against the element types the
        # IRIS VECTOR type accepts before changing it, because this value is
        # emitted in DDL and in TO_VECTOR() calls.
        if self.item_type is float:
            self.item_type_server = "decimal"
        elif self.item_type is Decimal:
            self.item_type_server = "float"
        else:
            self.item_type_server = "int"

    def get_col_spec(self, **kw):
        """Return the DDL type string, e.g. ``VECTOR(decimal, 3)``."""
        if self.max_items is None and self.item_type is None:
            return "VECTOR"
        # NOTE(review): when max_items is None the length slot is left empty
        # ("VECTOR(decimal, )"); verify the server accepts that form.
        length = str(self.max_items or "")
        return f"VECTOR({self.item_type_server}, {length})"

    def bind_processor(self, dialect):
        """Return a converter from a list/tuple to the ``[a,b,c]`` string."""

        def process(value):
            # Falsy values (None, empty sequences) are passed through as-is.
            if not value:
                return value
            if not isinstance(value, (list, tuple)):
                raise ValueError("expected list or tuple, got '%s'" % type(value))
            return f"[{','.join([str(v) for v in value])}]"

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter from the comma-separated result to a list."""

        def process(value):
            # Falsy values (None, empty string) are passed through as-is.
            if not value:
                return value
            return [self.item_type(item) for item in value.split(",")]

        return process

    class comparator_factory(UserDefinedType.Comparator):
        """Vector SQL functions exposed as methods on the column."""

        # def l2_distance(self, other):
        #     return self.func('vector_l2', other)

        def max_inner_product(self, other):
            """Return ``vector_dot_product(column, to_vector(other, type))``."""
            return self.func('vector_dot_product', other)

        def cosine_distance(self, other):
            """Return ``vector_cosine(column, to_vector(other, type))``.

            NOTE(review): ``cosine`` below is ``1 -`` this expression; which
            of the two is a distance and which a similarity depends on what
            the server's ``vector_cosine`` returns — confirm the naming.
            """
            return self.func('vector_cosine', other)

        def cosine(self, other):
            """Return ``1 - vector_cosine(column, to_vector(other, type))``."""
            return (1 - self.func('vector_cosine', other))

        def func(self, funcname: str, other):
            """Build ``<funcname>(column, to_vector('[...]', <server type>))``.

            :param funcname: name of the SQL function to call.
            :param other: list or tuple of elements to compare against.
            :raises ValueError: if ``other`` is not a list or tuple.
            """
            if not isinstance(other, (list, tuple)):
                raise ValueError("expected list or tuple, got '%s'" % type(other))
            othervalue = f"[{','.join([str(v) for v in other])}]"
            # Inside this method ``func`` resolves to the module-level
            # SQLAlchemy ``func`` namespace, not to this method.
            return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server)))
314+
315+
250316
class BIT(sqltypes.TypeEngine):
251317
__visit_name__ = "BIT"
252318

tests/test_suite.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from sqlalchemy.types import VARBINARY
1919
from sqlalchemy.types import BINARY
2020
from sqlalchemy_iris import TINYINT
21+
from sqlalchemy_iris import INTEGER
2122
from sqlalchemy_iris import IRISListBuild
23+
from sqlalchemy_iris import IRISVector
2224
from sqlalchemy.exc import DatabaseError
2325
import pytest
2426

@@ -337,3 +339,97 @@ def test_listbuild(self):
337339
([1.0] * 50, 1),
338340
],
339341
)
342+
343+
344+
class IRISVectorTest(fixtures.TablesTest):
    """Backend tests for the IRISVector column type and its comparators."""

    __backend__ = True

    __requires__ = ("iris_vector",)

    @classmethod
    def define_tables(cls, metadata):
        """Create a table with an integer id and a 3-element vector column."""
        Table(
            "data",
            metadata,
            Column("id", INTEGER),
            Column("emb", IRISVector(3, float)),
        )

    @classmethod
    def fixtures(cls):
        """Rows loaded into ``data``: the header tuple, then (id, emb) rows."""
        return {
            "data": (
                ("id", "emb"),
                (1, [1, 1, 1]),
                (2, [2, 2, 2]),
                (3, [1, 1, 2]),
            )
        }

    def _assert_result(self, select, result):
        """Execute ``select`` and compare all fetched rows with ``result``."""
        with config.db.connect() as connection:
            rows = connection.execute(select).fetchall()
            eq_(rows, result)

    def test_vector(self):
        """Vectors round-trip and can be filtered by equality."""
        data = self.tables.data
        self._assert_result(
            select(data.c.emb),
            [([1, 1, 1],), ([2, 2, 2],), ([1, 1, 2],)],
        )
        self._assert_result(
            select(data.c.id).where(data.c.emb == [2, 2, 2]),
            [(2,)],
        )

    def test_cosine(self):
        """Ordering by ``cosine`` against [1, 1, 1]."""
        data = self.tables.data
        query = select(data.c.id).order_by(data.c.emb.cosine([1, 1, 1]))
        self._assert_result(query, [(1,), (2,), (3,)])

    def test_cosine_distance(self):
        """Ordering by ``1 - cosine_distance`` against [1, 1, 1]."""
        data = self.tables.data
        query = select(data.c.id).order_by(
            1 - data.c.emb.cosine_distance([1, 1, 1])
        )
        self._assert_result(query, [(1,), (2,), (3,)])

    def test_max_inner_product(self):
        """Ordering by ``max_inner_product`` against [1, 1, 1]."""
        data = self.tables.data
        query = select(data.c.id).order_by(
            data.c.emb.max_inner_product([1, 1, 1])
        )
        self._assert_result(query, [(1,), (3,), (2,)])

0 commit comments

Comments
 (0)