Skip to content

Commit d9b8cee

Browse files
authored
Support vector index for Peewee (#67)
A VectorAdaptor is introduced, similar to what we did for SQLAlchemy. Peewee's native index support could be used if adding a vector index required only one SQL statement. Currently we must execute two SQL statements, so an adaptor is needed.
1 parent 3e99f50 commit d9b8cee

File tree

10 files changed

+322
-125
lines changed

10 files changed

+322
-125
lines changed

README.md

Lines changed: 75 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ Base.metadata.drop_all(engine) # clean data from last run
5252
Base.metadata.create_all(engine)
5353

5454
# Create index for L2 distance
55-
adaptor = VectorAdaptor(engine)
56-
adaptor.create_vector_index(
55+
VectorAdaptor(engine).create_vector_index(
5756
Doc.embedding, tidb_vector.DistanceMetric.L2, skip_existing=True
57+
# For cosine distance, use tidb_vector.DistanceMetric.COSINE
5858
)
5959

6060
# Insert content with vectors
@@ -69,6 +69,7 @@ with Session(engine) as session:
6969
results = session.execute(
7070
select(Doc.id, Doc.content)
7171
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
72+
# For cosine distance, use Doc.embedding.cosine_distance(...)
7273
.limit(1)
7374
).all()
7475
print(results)
@@ -84,6 +85,78 @@ with Session(engine) as session:
8485
print(results)
8586
```
8687

88+
### Peewee
89+
90+
Install:
91+
92+
```bash
93+
pip install tidb-vector peewee pymysql
94+
```
95+
96+
Usage:
97+
98+
```python
99+
import tidb_vector
100+
from peewee import Model, MySQLDatabase, IntegerField, TextField
101+
from tidb_vector.peewee import VectorField, VectorAdaptor
102+
103+
db = MySQLDatabase(
104+
database="test",
105+
user="root",
106+
password="",
107+
host="127.0.0.1",
108+
port=4000,
109+
)
110+
111+
112+
# Define table schema
113+
class Doc(Model):
114+
class Meta:
115+
database = db
116+
table_name = "peewee_test"
117+
118+
id = IntegerField(primary_key=True)
119+
embedding = VectorField(3)
120+
content = TextField()
121+
122+
123+
# Create empty table and index for L2 distance
124+
db.drop_tables([Doc]) # clean data from last run
125+
db.create_tables([Doc])
126+
# For cosine distance, use tidb_vector.DistanceMetric.COSINE
127+
VectorAdaptor(db).create_vector_index(Doc.embedding, tidb_vector.DistanceMetric.L2)
128+
129+
# Insert content with vectors
130+
Doc.insert_many(
131+
[
132+
{"id": 1, "content": "dog", "embedding": [1, 2, 1]},
133+
{"id": 2, "content": "fish", "embedding": [1, 2, 4]},
134+
{"id": 3, "content": "tree", "embedding": [1, 0, 0]},
135+
]
136+
).execute()
137+
138+
# Perform Vector Search for Top K=1
139+
cursor = (
140+
Doc.select(Doc.id, Doc.content)
141+
# For cosine distance, use Doc.embedding.cosine_distance(...)
142+
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
143+
.limit(1)
144+
)
145+
for row in cursor:
146+
print(row.id, row.content)
147+
148+
149+
# Perform filtered Vector Search by adding a Where Clause:
150+
cursor = (
151+
Doc.select(Doc.id, Doc.content)
152+
.where(Doc.content == "dog")
153+
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
154+
.limit(1)
155+
)
156+
for row in cursor:
157+
print(row.id, row.content)
158+
```
159+
87160
### Django
88161

89162
> [!TIP]
@@ -162,83 +235,6 @@ print(queryset)
162235

163236
For more details, see [django-tidb](https://github.com/pingcap/django-tidb?tab=readme-ov-file#vector-beta).
164237

165-
### Peewee
166-
167-
Define peewee table with vector field
168-
169-
```python
170-
from peewee import Model, MySQLDatabase
171-
from tidb_vector.peewee import VectorField
172-
173-
# Using `pymysql` as the driver
174-
connect_kwargs = {
175-
'ssl_verify_cert': True,
176-
'ssl_verify_identity': True,
177-
}
178-
179-
# Using `mysqlclient` as the driver
180-
connect_kwargs = {
181-
'ssl_mode': 'VERIFY_IDENTITY',
182-
'ssl': {
183-
# Root certificate default path
184-
# https://docs.pingcap.com/tidbcloud/secure-connections-to-serverless-clusters/#root-certificate-default-path
185-
'ca': '/etc/ssl/cert.pem' # MacOS
186-
},
187-
}
188-
189-
db = MySQLDatabase(
190-
'peewee_test',
191-
user='xxxxxxxx.root',
192-
password='xxxxxxxx',
193-
host='xxxxxxxx.shared.aws.tidbcloud.com',
194-
port=4000,
195-
**connect_kwargs,
196-
)
197-
198-
class TestModel(Model):
199-
class Meta:
200-
database = db
201-
table_name = 'test'
202-
203-
embedding = VectorField(3)
204-
205-
# or add hnsw index when creating table
206-
class TestModelWithIndex(Model):
207-
class Meta:
208-
database = db
209-
table_name = 'test_with_index'
210-
211-
embedding = VectorField(3, constraints=[SQL("COMMENT 'hnsw(distance=l2)'")])
212-
213-
214-
db.connect()
215-
db.create_tables([TestModel, TestModelWithIndex])
216-
```
217-
218-
Insert vector data
219-
220-
```python
221-
TestModel.create(embedding=[1, 2, 3])
222-
```
223-
224-
Get the nearest neighbors
225-
226-
```python
227-
TestModel.select().order_by(TestModel.embedding.l2_distance([1, 2, 3.1])).limit(5)
228-
```
229-
230-
Get the distance
231-
232-
```python
233-
TestModel.select(TestModel.embedding.cosine_distance([1, 2, 3.1]).alias('distance'))
234-
```
235-
236-
Get within a certain distance
237-
238-
```python
239-
TestModel.select().where(TestModel.embedding.l2_distance([1, 2, 3.1]) < 0.5)
240-
```
241-
242238
### TiDB Vector Client
243239

244240
Within the framework, you can directly use the built-in `TiDBVectorClient`, as demonstrated by integrations like [LangChain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [LlamaIndex](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index), to seamlessly interact with TiDB Vector. This approach abstracts away the need to manage the underlying ORM, simplifying your interaction with the vector store.

tests/peewee/test_peewee.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import pytest
33
from peewee import MySQLDatabase, Model, OperationalError
4-
from tidb_vector.peewee import VectorField
4+
import tidb_vector
5+
from tidb_vector.peewee import VectorField, VectorAdaptor
56
from ..config import TestConfig
67

78

@@ -273,3 +274,67 @@ def test_negative_inner_product(self):
273274
assert items.count() == 1
274275
assert items.get().id == item.id
275276
assert items[0].distance == -14
277+
278+
279+
class TestPeeweeAdaptor:
280+
def setup_method(self):
281+
db.drop_tables([Item1Model, Item2Model])
282+
db.create_tables([Item1Model, Item2Model])
283+
284+
def teardown_method(self):
285+
db.drop_tables([Item1Model, Item2Model])
286+
287+
def test_create_index_on_dyn_vector(self):
288+
adaptor = VectorAdaptor(db)
289+
with pytest.raises(ValueError):
290+
adaptor.create_vector_index(
291+
Item1Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
292+
)
293+
assert adaptor.has_vector_index(Item1Model.embedding) is False
294+
295+
def test_create_index_on_fixed_vector(self):
296+
adaptor = VectorAdaptor(db)
297+
adaptor.create_vector_index(
298+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
299+
)
300+
assert adaptor.has_vector_index(Item2Model.embedding) is True
301+
302+
with pytest.raises(Exception):
303+
adaptor.create_vector_index(
304+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
305+
)
306+
307+
assert adaptor.has_vector_index(Item2Model.embedding) is True
308+
309+
adaptor.create_vector_index(
310+
Item2Model.embedding,
311+
distance_metric=tidb_vector.DistanceMetric.L2,
312+
skip_existing=True,
313+
)
314+
315+
adaptor.create_vector_index(
316+
Item2Model.embedding,
317+
distance_metric=tidb_vector.DistanceMetric.COSINE,
318+
skip_existing=True,
319+
)
320+
321+
def test_index_and_search(self):
322+
adaptor = VectorAdaptor(db)
323+
adaptor.create_vector_index(
324+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
325+
)
326+
assert adaptor.has_vector_index(Item2Model.embedding) is True
327+
328+
Item2Model.insert_many(
329+
[
330+
{"embedding": [1, 2, 3]},
331+
{"embedding": [1, 2, 3.2]},
332+
]
333+
).execute()
334+
335+
distance = Item2Model.embedding.cosine_distance([1, 2, 3])
336+
items = (
337+
Item2Model.select(distance.alias("distance")).order_by(distance).limit(5)
338+
)
339+
assert items.count() == 2
340+
assert items[0].distance == 0.0

tidb_vector/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .constants import MAX_DIM, MIN_DIM, DistanceMetric
1+
from .constants import MAX_DIM, MIN_DIM, DistanceMetric, VectorDataType
22

33
__version__ = "0.0.13"
4-
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric"]
4+
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric", "VectorDataType"]

tidb_vector/constants.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,37 @@
11
import enum
2+
import typing
3+
4+
import numpy
25

36
# TiDB Vector has a limitation on the dimension length
47
MAX_DIM = 16000
58
MIN_DIM = 1
69

710

11+
VectorDataType = typing.Union[numpy.ndarray, typing.List[float]]
12+
13+
814
class DistanceMetric(enum.Enum):
15+
"""
16+
An enumeration representing different types of distance metrics.
17+
18+
- `DistanceMetric.L2`: L2 (Euclidean) distance metric.
19+
- `DistanceMetric.COSINE`: Cosine distance metric.
20+
"""
21+
922
L2 = "L2"
1023
COSINE = "COSINE"
1124

1225
def to_sql_func(self):
26+
"""
27+
Converts the DistanceMetric to its corresponding SQL function name.
28+
29+
Returns:
30+
str: The SQL function name.
31+
32+
Raises:
33+
ValueError: If the DistanceMetric enum member is not supported.
34+
"""
1335
if self == DistanceMetric.L2:
1436
return "VEC_L2_DISTANCE"
1537
elif self == DistanceMetric.COSINE:

tidb_vector/peewee/__init__.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,4 @@
1-
from peewee import Field, fn
1+
from .vector_type import VectorField
2+
from .adaptor import VectorAdaptor
23

3-
from tidb_vector.utils import decode_vector, encode_vector
4-
5-
6-
class VectorField(Field):
7-
field_type = "VECTOR"
8-
9-
def __init__(self, dimensions=None, *args, **kwargs):
10-
self.dimensions = dimensions
11-
super(VectorField, self).__init__(*args, **kwargs)
12-
13-
def get_modifiers(self):
14-
return self.dimensions and [self.dimensions] or None
15-
16-
def db_value(self, value):
17-
return encode_vector(value)
18-
19-
def python_value(self, value):
20-
return decode_vector(value)
21-
22-
def l1_distance(self, vector):
23-
return fn.VEC_L1_DISTANCE(self, self.to_value(vector))
24-
25-
def l2_distance(self, vector):
26-
return fn.VEC_L2_DISTANCE(self, self.to_value(vector))
27-
28-
def cosine_distance(self, vector):
29-
return fn.VEC_COSINE_DISTANCE(self, self.to_value(vector))
30-
31-
def negative_inner_product(self, vector):
32-
return fn.VEC_NEGATIVE_INNER_PRODUCT(self, self.to_value(vector))
4+
__all__ = ["VectorField", "VectorAdaptor"]

0 commit comments

Comments
 (0)