|
19 | 19 |
|
20 | 20 | psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
|
21 | 21 | pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')
|
22 |
| -psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
| 22 | +psycopg2_type_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
23 | 23 |
|
24 | 24 |
|
25 |
| -@event.listens_for(psycopg2_array_engine, "connect") |
| 25 | +@event.listens_for(psycopg2_type_engine, "connect") |
26 | 26 | def psycopg2_connect(dbapi_connection, connection_record):
|
27 | 27 | from pgvector.psycopg2 import register_vector
|
28 | 28 | register_vector(dbapi_connection, globally=False, arrays=True)
|
29 | 29 |
|
30 | 30 |
|
31 |
| -engines = [psycopg2_engine, pg8000_engine] |
32 |
| -array_engines = [psycopg2_array_engine] |
| 31 | +engines = [psycopg2_engine, pg8000_engine, psycopg2_type_engine] |
| 32 | +array_engines = [psycopg2_type_engine] |
33 | 33 | async_engines = []
|
34 | 34 |
|
35 | 35 | if sqlalchemy_version > 1:
|
36 | 36 | psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
|
37 | 37 | engines.append(psycopg_engine)
|
38 | 38 |
|
| 39 | + psycopg_type_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test') |
| 40 | + |
| 41 | + @event.listens_for(psycopg_type_engine, "connect") |
| 42 | + def psycopg_connect(dbapi_connection, connection_record): |
| 43 | + from pgvector.psycopg import register_vector |
| 44 | + register_vector(dbapi_connection) |
| 45 | + |
| 46 | + engines.append(psycopg_type_engine) |
| 47 | + array_engines.append(psycopg_type_engine) |
| 48 | + |
39 | 49 | psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
|
40 | 50 | async_engines.append(psycopg_async_engine)
|
41 | 51 |
|
42 | 52 | asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
|
43 | 53 | async_engines.append(asyncpg_engine)
|
44 | 54 |
|
45 |
| - psycopg_array_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test') |
46 |
| - array_engines.append(psycopg_array_engine) |
47 |
| - |
48 |
| - @event.listens_for(psycopg_array_engine, "connect") |
49 |
| - def psycopg_connect(dbapi_connection, connection_record): |
50 |
| - from pgvector.psycopg import register_vector |
51 |
| - register_vector(dbapi_connection) |
52 |
| - |
53 | 55 | setup_engine = engines[0]
|
54 | 56 | with Session(setup_engine) as session:
|
55 | 57 | session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
@@ -169,9 +171,10 @@ def test_orm(self, engine):
|
169 | 171 | stmt = select(Item)
|
170 | 172 | with Session(engine) as session:
|
171 | 173 | items = [v[0] for v in session.execute(stmt).all()]
|
172 |
| - assert items[0].id in [1, 4, 7] |
173 |
| - assert items[1].id in [2, 5, 8] |
174 |
| - assert items[2].id in [3, 6, 9] |
| 174 | + # TODO improve |
| 175 | + assert items[0].id % 3 == 1 |
| 176 | + assert items[1].id % 3 == 2 |
| 177 | + assert items[2].id % 3 == 0 |
175 | 178 | assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
|
176 | 179 | assert items[0].embedding.dtype == np.float32
|
177 | 180 | assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))
|
|
0 commit comments