Skip to content

Commit c792451

Browse files
committed
Test more engine configurations
1 parent a1d8997 commit c792451

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

tests/test_sqlalchemy.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,39 @@
1919

2020
psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
2121
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')
22-
psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
22+
psycopg2_type_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
2323

2424

25-
@event.listens_for(psycopg2_array_engine, "connect")
25+
@event.listens_for(psycopg2_type_engine, "connect")
2626
def psycopg2_connect(dbapi_connection, connection_record):
2727
from pgvector.psycopg2 import register_vector
2828
register_vector(dbapi_connection, globally=False, arrays=True)
2929

3030

31-
engines = [psycopg2_engine, pg8000_engine]
32-
array_engines = [psycopg2_array_engine]
31+
engines = [psycopg2_engine, pg8000_engine, psycopg2_type_engine]
32+
array_engines = [psycopg2_type_engine]
3333
async_engines = []
3434

3535
if sqlalchemy_version > 1:
3636
psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
3737
engines.append(psycopg_engine)
3838

39+
psycopg_type_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
40+
41+
@event.listens_for(psycopg_type_engine, "connect")
42+
def psycopg_connect(dbapi_connection, connection_record):
43+
from pgvector.psycopg import register_vector
44+
register_vector(dbapi_connection)
45+
46+
engines.append(psycopg_type_engine)
47+
array_engines.append(psycopg_type_engine)
48+
3949
psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
4050
async_engines.append(psycopg_async_engine)
4151

4252
asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
4353
async_engines.append(asyncpg_engine)
4454

45-
psycopg_array_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
46-
array_engines.append(psycopg_array_engine)
47-
48-
@event.listens_for(psycopg_array_engine, "connect")
49-
def psycopg_connect(dbapi_connection, connection_record):
50-
from pgvector.psycopg import register_vector
51-
register_vector(dbapi_connection)
52-
5355
setup_engine = engines[0]
5456
with Session(setup_engine) as session:
5557
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
@@ -169,9 +171,10 @@ def test_orm(self, engine):
169171
stmt = select(Item)
170172
with Session(engine) as session:
171173
items = [v[0] for v in session.execute(stmt).all()]
172-
assert items[0].id in [1, 4, 7]
173-
assert items[1].id in [2, 5, 8]
174-
assert items[2].id in [3, 6, 9]
174+
# TODO improve
175+
assert items[0].id % 3 == 1
176+
assert items[1].id % 3 == 2
177+
assert items[2].id % 3 == 0
175178
assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
176179
assert items[0].embedding.dtype == np.float32
177180
assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))

0 commit comments

Comments
 (0)