Skip to content

Commit a157c4e

Browse files
committed
Merge branch 'master' into return-vector
2 parents e350968 + e566d4c commit a157c4e

13 files changed

+118
-66
lines changed

.github/workflows/build.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@ jobs:
2424
make
2525
sudo make install
2626
- run: pytest
27+
28+
- run: pip install "SQLAlchemy<2" -U
29+
- run: pytest tests/test_sqlalchemy.py

pgvector/bit.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@ def __init__(self, value):
77
if isinstance(value, str):
88
self._value = self.from_text(value)._value
99
else:
10-
# TODO change in 0.4.0
11-
# TODO raise if dtype not bool or uint8
12-
# if isinstance(value, np.ndarray) and value.dtype == np.uint8:
13-
# value = np.unpackbits(value)
14-
# else:
15-
# value = np.asarray(value, dtype=bool)
16-
17-
value = np.asarray(value, dtype=bool)
10+
if isinstance(value, np.ndarray):
11+
if value.dtype == np.uint8:
12+
value = np.unpackbits(value).astype(bool)
13+
elif value.dtype != np.bool:
14+
raise ValueError('expected dtype to be bool or uint8')
15+
else:
16+
value = np.asarray(value, dtype=bool)
1817

1918
if value.ndim != 1:
2019
raise ValueError('expected ndim to be 1')

tests/test_asyncpg.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncpg
22
import numpy as np
3-
from pgvector import Vector, HalfVector, SparseVector
3+
from pgvector import HalfVector, SparseVector, Vector
44
from pgvector.asyncpg import register_vector
55
import pytest
66

@@ -16,11 +16,13 @@ async def test_vector(self):
1616
await register_vector(conn)
1717

1818
embedding = Vector([1.5, 2, 3])
19-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
19+
embedding2 = np.array([4.5, 5, 6])
20+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
2021

2122
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
2223
assert res[0]['embedding'] == embedding
23-
assert res[1]['embedding'] is None
24+
assert res[1]['embedding'] == Vector(embedding2)
25+
assert res[2]['embedding'] is None
2426

2527
# ensures binary format is correct
2628
text_res = await conn.fetch("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1")
@@ -38,11 +40,13 @@ async def test_halfvec(self):
3840
await register_vector(conn)
3941

4042
embedding = HalfVector([1.5, 2, 3])
41-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
43+
embedding2 = [4.5, 5, 6]
44+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
4245

4346
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
4447
assert res[0]['embedding'] == embedding
45-
assert res[1]['embedding'] is None
48+
assert res[1]['embedding'] == HalfVector(embedding2)
49+
assert res[2]['embedding'] is None
4650

4751
# ensures binary format is correct
4852
text_res = await conn.fetch("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1")
@@ -105,11 +109,14 @@ async def test_vector_array(self):
105109
await register_vector(conn)
106110

107111
embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])]
108-
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings[0], embeddings[1])
112+
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES ($1)", embeddings)
113+
114+
embeddings2 = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
115+
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings2[0], embeddings2[1])
109116

110117
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
111-
assert res[0]['embeddings'][0] == embeddings[0]
112-
assert res[0]['embeddings'][1] == embeddings[1]
118+
assert res[0]['embeddings'] == embeddings
119+
assert res[1]['embeddings'] == [Vector(e) for e in embeddings2]
113120

114121
await conn.close()
115122

@@ -126,8 +133,10 @@ async def init(conn):
126133
await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))')
127134

128135
embedding = Vector([1.5, 2, 3])
129-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
136+
embedding2 = np.array([1.5, 2, 3])
137+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
130138

131139
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
132140
assert res[0]['embedding'] == embedding
133-
assert res[1]['embedding'] is None
141+
assert res[1]['embedding'] == Vector(embedding2)
142+
assert res[2]['embedding'] is None

tests/test_bit.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ def test_str(self):
1515

1616
def test_ndarray_uint8(self):
1717
arr = np.array([254, 7, 0], dtype=np.uint8)
18-
# TODO change in 0.4.0
19-
# assert Bit(arr).to_text() == '111111100000011100000000'
20-
assert Bit(arr).to_text() == '110'
18+
assert Bit(arr).to_text() == '111111100000011100000000'
19+
20+
def test_ndarray_uint16(self):
21+
arr = np.array([254, 7, 0], dtype=np.uint16)
22+
with pytest.raises(ValueError) as error:
23+
Bit(arr)
24+
assert str(error.value) == 'expected dtype to be bool or uint8'
2125

2226
def test_ndarray_same_object(self):
2327
arr = np.array([True, False, True])

tests/test_django.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_vector_l1_distance(self):
198198
def test_halfvec(self):
199199
Item(id=1, half_embedding=[1, 2, 3]).save()
200200
item = Item.objects.get(pk=1)
201-
assert item.half_embedding.to_list() == [1, 2, 3]
201+
assert item.half_embedding == HalfVector([1, 2, 3])
202202

203203
def test_halfvec_l2_distance(self):
204204
create_items()
@@ -250,7 +250,7 @@ def test_bit_jaccard_distance(self):
250250
def test_sparsevec(self):
251251
Item(id=1, sparse_embedding=SparseVector([1, 2, 3])).save()
252252
item = Item.objects.get(pk=1)
253-
assert item.sparse_embedding.to_list() == [1, 2, 3]
253+
assert item.sparse_embedding == SparseVector([1, 2, 3])
254254

255255
def test_sparsevec_l2_distance(self):
256256
create_items()
@@ -346,7 +346,7 @@ def test_vector_form_save(self):
346346
assert form.has_changed()
347347
assert form.is_valid()
348348
assert form.save()
349-
assert [4, 5, 6] == Item.objects.get(pk=1).embedding.to_list()
349+
assert Item.objects.get(pk=1).embedding == Vector([4, 5, 6])
350350

351351
def test_vector_form_save_missing(self):
352352
Item(id=1).save()
@@ -374,7 +374,7 @@ def test_halfvec_form_save(self):
374374
assert form.has_changed()
375375
assert form.is_valid()
376376
assert form.save()
377-
assert [4, 5, 6] == Item.objects.get(pk=1).half_embedding.to_list()
377+
assert Item.objects.get(pk=1).half_embedding == HalfVector([4, 5, 6])
378378

379379
def test_halfvec_form_save_missing(self):
380380
Item(id=1).save()
@@ -431,7 +431,7 @@ def test_sparsevec_form_save(self):
431431
assert form.has_changed()
432432
assert form.is_valid()
433433
assert form.save()
434-
assert [4, 5, 6] == Item.objects.get(pk=1).sparse_embedding.to_list()
434+
assert Item.objects.get(pk=1).sparse_embedding == SparseVector([4, 5, 6])
435435

436436
def test_sparesevec_form_save_missing(self):
437437
Item(id=1).save()
@@ -464,8 +464,7 @@ def test_vector_array(self):
464464

465465
# this fails if the driver does not cast arrays
466466
item = Item.objects.get(pk=1)
467-
assert item.embeddings[0].to_list() == [1, 2, 3]
468-
assert item.embeddings[1].to_list() == [4, 5, 6]
467+
assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])]
469468

470469
def test_double_array(self):
471470
Item(id=1, double_embedding=[1, 1, 1]).save()

tests/test_half_vector.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from pgvector import HalfVector
33
import pytest
4+
from struct import pack
45

56

67
class TestHalfVector:
@@ -44,3 +45,15 @@ def test_equality(self):
4445

4546
def test_dimensions(self):
4647
assert HalfVector([1, 2, 3]).dimensions() == 3
48+
49+
def test_from_text(self):
50+
vec = HalfVector.from_text('[1.5,2,3]')
51+
assert vec.to_list() == [1.5, 2, 3]
52+
assert np.array_equal(vec.to_numpy(), [1.5, 2, 3])
53+
54+
def test_from_binary(self):
55+
data = pack('>HH3e', 3, 0, 1.5, 2, 3)
56+
vec = HalfVector.from_binary(data)
57+
assert vec.to_list() == [1.5, 2, 3]
58+
assert np.array_equal(vec.to_numpy(), [1.5, 2, 3])
59+
assert vec.to_binary() == data

tests/test_peewee.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from math import sqrt
22
import numpy as np
33
from peewee import Model, PostgresqlDatabase, fn
4-
from pgvector import Vector, HalfVector, SparseVector
4+
from pgvector import HalfVector, SparseVector, Vector
55
from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField
66

77
db = PostgresqlDatabase('pgvector_python_test')
@@ -76,7 +76,7 @@ def test_vector_l1_distance(self):
7676
def test_halfvec(self):
7777
Item.create(id=1, half_embedding=[1, 2, 3])
7878
item = Item.get_by_id(1)
79-
assert item.half_embedding.to_list() == [1, 2, 3]
79+
assert item.half_embedding == HalfVector([1, 2, 3])
8080

8181
def test_halfvec_l2_distance(self):
8282
create_items()
@@ -128,7 +128,7 @@ def test_bit_jaccard_distance(self):
128128
def test_sparsevec(self):
129129
Item.create(id=1, sparse_embedding=[1, 2, 3])
130130
item = Item.get_by_id(1)
131-
assert item.sparse_embedding.to_list() == [1, 2, 3]
131+
assert item.sparse_embedding == SparseVector([1, 2, 3])
132132

133133
def test_sparsevec_l2_distance(self):
134134
create_items()
@@ -219,5 +219,5 @@ class Meta:
219219
# fails with column "embeddings" is of type vector[] but expression is of type text[]
220220
# ExtItem.create(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])
221221
# item = ExtItem.get_by_id(1)
222-
# assert np.array_equal(item.embeddings[0], np.array([1, 2, 3]))
223-
# assert np.array_equal(item.embeddings[1], np.array([4, 5, 6]))
222+
# assert np.array_equal(item.embeddings[0], [1, 2, 3])
223+
# assert np.array_equal(item.embeddings[1], [4, 5, 6])

tests/test_psycopg.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,17 @@ def test_halfvec(self):
6868
conn.execute('INSERT INTO psycopg_items (half_embedding) VALUES (%s)', (embedding,))
6969

7070
res = conn.execute('SELECT half_embedding FROM psycopg_items ORDER BY id').fetchone()[0]
71-
assert res.to_list() == [1.5, 2, 3]
71+
assert res == HalfVector([1.5, 2, 3])
7272

7373
def test_halfvec_binary_format(self):
7474
embedding = HalfVector([1.5, 2, 3])
7575
res = conn.execute('SELECT %b::halfvec', (embedding,), binary=True).fetchone()[0]
76-
assert res.to_list() == [1.5, 2, 3]
77-
assert np.array_equal(res.to_numpy(), np.array([1.5, 2, 3]))
76+
assert res == HalfVector([1.5, 2, 3])
7877

7978
def test_halfvec_text_format(self):
8079
embedding = HalfVector([1.5, 2, 3])
8180
res = conn.execute('SELECT %t::halfvec', (embedding,)).fetchone()[0]
82-
assert res.to_list() == [1.5, 2, 3]
83-
assert np.array_equal(res.to_numpy(), np.array([1.5, 2, 3]))
81+
assert res == HalfVector([1.5, 2, 3])
8482

8583
def test_bit(self):
8684
embedding = Bit([True, False, True])
@@ -105,25 +103,17 @@ def test_sparsevec(self):
105103
conn.execute('INSERT INTO psycopg_items (sparse_embedding) VALUES (%s)', (embedding,))
106104

107105
res = conn.execute('SELECT sparse_embedding FROM psycopg_items ORDER BY id').fetchone()[0]
108-
assert res.to_list() == [1.5, 2, 3]
106+
assert res == SparseVector([1.5, 2, 3])
109107

110108
def test_sparsevec_binary_format(self):
111109
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
112110
res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0]
113-
assert res.dimensions() == 6
114-
assert res.indices() == [0, 2, 4]
115-
assert res.values() == [1.5, 2, 3]
116-
assert res.to_list() == [1.5, 0, 2, 0, 3, 0]
117-
assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0]))
111+
assert res == embedding
118112

119113
def test_sparsevec_text_format(self):
120114
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
121115
res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0]
122-
assert res.dimensions() == 6
123-
assert res.indices() == [0, 2, 4]
124-
assert res.values() == [1.5, 2, 3]
125-
assert res.to_list() == [1.5, 0, 2, 0, 3, 0]
126-
assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0]))
116+
assert res == embedding
127117

128118
def test_text_copy_from(self):
129119
embedding = np.array([1.5, 2, 3])
@@ -161,8 +151,8 @@ def test_binary_copy_to(self):
161151
cur = conn.cursor()
162152
with cur.copy("COPY psycopg_items (embedding, half_embedding) TO STDOUT WITH (FORMAT BINARY)") as copy:
163153
for row in copy.rows():
164-
assert Vector.from_binary(row[0]).to_list() == [1.5, 2, 3]
165-
assert HalfVector.from_binary(row[1]).to_list() == [1.5, 2, 3]
154+
assert Vector.from_binary(row[0]) == embedding
155+
assert HalfVector.from_binary(row[1]) == half_embedding
166156

167157
def test_binary_copy_to_set_types(self):
168158
embedding = Vector([1.5, 2, 3])

tests/test_psycopg2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,17 @@ def test_halfvec_array(self):
8080

8181
cur.execute('SELECT half_embeddings FROM psycopg2_items ORDER BY id')
8282
res = cur.fetchone()
83-
assert res[0][0].to_list() == [1.5, 2, 3]
84-
assert res[0][1].to_list() == [4.5, 5, 6]
83+
assert res[0][0] == HalfVector([1.5, 2, 3])
84+
assert res[0][1] == HalfVector([4.5, 5, 6])
8585

8686
def test_sparsevec_array(self):
8787
embeddings = [SparseVector([1.5, 2, 3]), SparseVector([4.5, 5, 6])]
8888
cur.execute('INSERT INTO psycopg2_items (sparse_embeddings) VALUES (%s::sparsevec[])', (embeddings,))
8989

9090
cur.execute('SELECT sparse_embeddings FROM psycopg2_items ORDER BY id')
9191
res = cur.fetchone()
92-
assert res[0][0].to_list() == [1.5, 2, 3]
93-
assert res[0][1].to_list() == [4.5, 5, 6]
92+
assert res[0][0] == SparseVector([1.5, 2, 3])
93+
assert res[0][1] == SparseVector([4.5, 5, 6])
9494

9595
def test_cursor_factory(self):
9696
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:

tests/test_sparse_vector.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from pgvector import SparseVector
33
import pytest
44
from scipy.sparse import coo_array
5+
from struct import pack
56

67

78
class TestSparseVector:
89
def test_list(self):
910
vec = SparseVector([1, 0, 2, 0, 3, 0])
1011
assert vec.to_list() == [1, 0, 2, 0, 3, 0]
11-
assert vec.to_numpy().tolist() == [1, 0, 2, 0, 3, 0]
12+
assert np.array_equal(vec.to_numpy(), [1, 0, 2, 0, 3, 0])
1213
assert vec.indices() == [0, 2, 4]
1314

1415
def test_list_dimensions(self):
@@ -56,6 +57,7 @@ def test_equality(self):
5657
assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector([1, 0, 2, 0, 3, 0])
5758
assert SparseVector([1, 0, 2, 0, 3, 0]) != SparseVector([1, 0, 2, 0, 3, 1])
5859
assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector({2: 2, 4: 3, 0: 1, 3: 0}, 6)
60+
assert SparseVector({}, 1) != SparseVector({}, 2)
5961

6062
def test_dimensions(self):
6163
assert SparseVector([1, 0, 2, 0, 3, 0]).dimensions() == 6
@@ -67,8 +69,26 @@ def test_values(self):
6769
assert SparseVector([1, 0, 2, 0, 3, 0]).values() == [1, 2, 3]
6870

6971
def test_to_coo(self):
70-
assert SparseVector([1, 0, 2, 0, 3, 0]).to_coo().toarray().tolist() == [[1, 0, 2, 0, 3, 0]]
72+
assert np.array_equal(SparseVector([1, 0, 2, 0, 3, 0]).to_coo().toarray(), [[1, 0, 2, 0, 3, 0]])
7173

7274
def test_zero_vector_text(self):
7375
vec = SparseVector({}, 3)
7476
assert vec.to_list() == SparseVector.from_text(vec.to_text()).to_list()
77+
78+
def test_from_text(self):
79+
vec = SparseVector.from_text('{1:1.5,3:2,5:3}/6')
80+
assert vec.dimensions() == 6
81+
assert vec.indices() == [0, 2, 4]
82+
assert vec.values() == [1.5, 2, 3]
83+
assert vec.to_list() == [1.5, 0, 2, 0, 3, 0]
84+
assert np.array_equal(vec.to_numpy(), [1.5, 0, 2, 0, 3, 0])
85+
86+
def test_from_binary(self):
87+
data = pack('>iii3i3f', 6, 3, 0, 0, 2, 4, 1.5, 2, 3)
88+
vec = SparseVector.from_binary(data)
89+
assert vec.dimensions() == 6
90+
assert vec.indices() == [0, 2, 4]
91+
assert vec.values() == [1.5, 2, 3]
92+
assert vec.to_list() == [1.5, 0, 2, 0, 3, 0]
93+
assert np.array_equal(vec.to_numpy(), [1.5, 0, 2, 0, 3, 0])
94+
assert vec.to_binary() == data

tests/test_sqlalchemy.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncpg
22
import numpy as np
33
import os
4-
from pgvector import Vector, HalfVector, SparseVector
4+
from pgvector import HalfVector, SparseVector, Vector
55
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum
66
import pytest
77
from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
@@ -539,8 +539,7 @@ def test_vector_array(self, engine):
539539

540540
# this fails if the driver does not cast arrays
541541
item = session.get(Item, 1)
542-
assert item.embeddings[0] == Vector([1, 2, 3])
543-
assert item.embeddings[1] == Vector([4, 5, 6])
542+
assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])]
544543

545544
def test_halfvec_array(self, engine):
546545
with Session(engine) as session:
@@ -637,7 +636,10 @@ async def test_vector_array(self, engine):
637636
async with session.begin():
638637
session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])]))
639638
item = await session.get(Item, 1)
640-
assert item.embeddings[0] == Vector([1, 2, 3])
641-
assert item.embeddings[1] == Vector([4, 5, 6])
639+
assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])]
640+
641+
session.add(Item(id=2, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
642+
item = await session.get(Item, 2)
643+
assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])]
642644

643645
await engine.dispose()

0 commit comments

Comments
 (0)