|
1 | 1 | import numpy as np
|
| 2 | +import os |
2 | 3 | from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
|
3 | 4 | import pytest
|
4 | 5 | from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
|
|
16 | 17 | sqlalchemy_version = 1
|
17 | 18 |
|
18 | 19 | psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
|
19 |
| -engines = [psycopg2_engine] |
| 20 | +pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ['USER']}@localhost/pgvector_python_test') |
| 21 | +engines = [psycopg2_engine, pg8000_engine] |
20 | 22 |
|
21 | 23 | if sqlalchemy_version > 1:
|
22 | 24 | psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
|
@@ -151,9 +153,9 @@ def test_orm(self, engine):
|
151 | 153 | stmt = select(Item)
|
152 | 154 | with Session(engine) as session:
|
153 | 155 | items = [v[0] for v in session.execute(stmt).all()]
|
154 |
| - assert items[0].id in [1, 4] |
155 |
| - assert items[1].id in [2, 5] |
156 |
| - assert items[2].id in [3, 6] |
| 156 | + assert items[0].id in [1, 4, 7] |
| 157 | + assert items[1].id in [2, 5, 8] |
| 158 | + assert items[2].id in [3, 6, 9] |
157 | 159 | assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
|
158 | 160 | assert items[0].embedding.dtype == np.float32
|
159 | 161 | assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))
|
@@ -290,12 +292,18 @@ def test_bit_hamming_distance_orm(self, engine):
|
290 | 292 | assert [v.id for v in items] == [2, 3, 1]
|
291 | 293 |
|
292 | 294 | def test_bit_jaccard_distance(self, engine):
|
| 295 | + if engine == pg8000_engine: |
| 296 | + return |
| 297 | + |
293 | 298 | create_items()
|
294 | 299 | with Session(engine) as session:
|
295 | 300 | items = session.query(Item).order_by(Item.binary_embedding.jaccard_distance('101')).all()
|
296 | 301 | assert [v.id for v in items] == [2, 3, 1]
|
297 | 302 |
|
298 | 303 | def test_bit_jaccard_distance_orm(self, engine):
|
| 304 | + if engine == pg8000_engine: |
| 305 | + return |
| 306 | + |
299 | 307 | create_items()
|
300 | 308 | with Session(engine) as session:
|
301 | 309 | items = session.scalars(select(Item).order_by(Item.binary_embedding.jaccard_distance('101')))
|
|
0 commit comments