Skip to content

Commit f82e44f

Browse files
committed
Added tests for SQLAlchemy with pg8000
1 parent 5e38160 commit f82e44f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ asyncpg
22
Django
33
numpy
44
peewee
5+
pg8000
56
psycopg[binary,pool]
67
psycopg2-binary
78
pytest

tests/test_sqlalchemy.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import os
23
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
34
import pytest
45
from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
@@ -16,7 +17,8 @@
1617
sqlalchemy_version = 1
1718

1819
psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
19-
engines = [psycopg2_engine]
20+
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ['USER']}@localhost/pgvector_python_test')
21+
engines = [psycopg2_engine, pg8000_engine]
2022

2123
if sqlalchemy_version > 1:
2224
psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
@@ -151,9 +153,9 @@ def test_orm(self, engine):
151153
stmt = select(Item)
152154
with Session(engine) as session:
153155
items = [v[0] for v in session.execute(stmt).all()]
154-
assert items[0].id in [1, 4]
155-
assert items[1].id in [2, 5]
156-
assert items[2].id in [3, 6]
156+
assert items[0].id in [1, 4, 7]
157+
assert items[1].id in [2, 5, 8]
158+
assert items[2].id in [3, 6, 9]
157159
assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
158160
assert items[0].embedding.dtype == np.float32
159161
assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))
@@ -290,12 +292,18 @@ def test_bit_hamming_distance_orm(self, engine):
290292
assert [v.id for v in items] == [2, 3, 1]
291293

292294
def test_bit_jaccard_distance(self, engine):
295+
if engine == pg8000_engine:
296+
return
297+
293298
create_items()
294299
with Session(engine) as session:
295300
items = session.query(Item).order_by(Item.binary_embedding.jaccard_distance('101')).all()
296301
assert [v.id for v in items] == [2, 3, 1]
297302

298303
def test_bit_jaccard_distance_orm(self, engine):
304+
if engine == pg8000_engine:
305+
return
306+
299307
create_items()
300308
with Session(engine) as session:
301309
items = session.scalars(select(Item).order_by(Item.binary_embedding.jaccard_distance('101')))

0 commit comments

Comments
 (0)