Skip to content

Support vector index for Peewee #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 75 additions & 79 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ Base.metadata.drop_all(engine) # clean data from last run
Base.metadata.create_all(engine)

# Create index for L2 distance
adaptor = VectorAdaptor(engine)
adaptor.create_vector_index(
VectorAdaptor(engine).create_vector_index(
Doc.embedding, tidb_vector.DistanceMetric.L2, skip_existing=True
# For cosine distance, use tidb_vector.DistanceMetric.COSINE
)

# Insert content with vectors
Expand All @@ -69,6 +69,7 @@ with Session(engine) as session:
results = session.execute(
select(Doc.id, Doc.content)
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
# For cosine distance, use Doc.embedding.cosine_distance(...)
.limit(1)
).all()
print(results)
Expand All @@ -84,6 +85,78 @@ with Session(engine) as session:
print(results)
```

### Peewee

Install:

```bash
pip install tidb-vector peewee pymysql
```

Usage:

```python
import tidb_vector
from peewee import Model, MySQLDatabase, IntegerField, TextField
from tidb_vector.peewee import VectorField, VectorAdaptor

db = MySQLDatabase(
database="test",
user="root",
password="",
host="127.0.0.1",
port=4000,
)


# Define table schema
class Doc(Model):
class Meta:
database = db
table_name = "peewee_test"

id = IntegerField(primary_key=True)
embedding = VectorField(3)
content = TextField()


# Create empty table and index for L2 distance
db.drop_tables([Doc]) # clean data from last run
db.create_tables([Doc])
# For cosine distance, use tidb_vector.DistanceMetric.COSINE
VectorAdaptor(db).create_vector_index(Doc.embedding, tidb_vector.DistanceMetric.L2)

# Insert content with vectors
Doc.insert_many(
[
{"id": 1, "content": "dog", "embedding": [1, 2, 1]},
{"id": 2, "content": "fish", "embedding": [1, 2, 4]},
{"id": 3, "content": "tree", "embedding": [1, 0, 0]},
]
).execute()

# Perform Vector Search for Top K=1
cursor = (
Doc.select(Doc.id, Doc.content)
# For cosine distance, use Doc.embedding.cosine_distance(...)
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
.limit(1)
)
for row in cursor:
print(row.id, row.content)


# Perform filtered Vector Search by adding a Where Clause:
cursor = (
Doc.select(Doc.id, Doc.content)
.where(Doc.content == "dog")
.order_by(Doc.embedding.l2_distance([1, 2, 3]))
.limit(1)
)
for row in cursor:
print(row.id, row.content)
```

### Django

> [!TIP]
Expand Down Expand Up @@ -162,83 +235,6 @@ print(queryset)

For more details, see [django-tidb](https://github.com/pingcap/django-tidb?tab=readme-ov-file#vector-beta).

### Peewee

Define a Peewee table with a vector field:

```python
from peewee import Model, MySQLDatabase
from tidb_vector.peewee import VectorField

# Using `pymysql` as the driver
connect_kwargs = {
'ssl_verify_cert': True,
'ssl_verify_identity': True,
}

# Using `mysqlclient` as the driver
connect_kwargs = {
'ssl_mode': 'VERIFY_IDENTITY',
'ssl': {
# Root certificate default path
# https://docs.pingcap.com/tidbcloud/secure-connections-to-serverless-clusters/#root-certificate-default-path
'ca': '/etc/ssl/cert.pem' # MacOS
},
}

db = MySQLDatabase(
'peewee_test',
user='xxxxxxxx.root',
password='xxxxxxxx',
host='xxxxxxxx.shared.aws.tidbcloud.com',
port=4000,
**connect_kwargs,
)

class TestModel(Model):
class Meta:
database = db
table_name = 'test'

embedding = VectorField(3)

# or add hnsw index when creating table
class TestModelWithIndex(Model):
class Meta:
database = db
table_name = 'test_with_index'

embedding = VectorField(3, constraints=[SQL("COMMENT 'hnsw(distance=l2)'")])


db.connect()
db.create_tables([TestModel, TestModelWithIndex])
```

Insert vector data

```python
TestModel.create(embedding=[1, 2, 3])
```

Get the nearest neighbors

```python
TestModel.select().order_by(TestModel.embedding.l2_distance([1, 2, 3.1])).limit(5)
```

Get the distance

```python
TestModel.select(TestModel.embedding.cosine_distance([1, 2, 3.1]).alias('distance'))
```

Get within a certain distance

```python
TestModel.select().where(TestModel.embedding.l2_distance([1, 2, 3.1]) < 0.5)
```

### TiDB Vector Client

You can also use the built-in `TiDBVectorClient` directly, as integrations such as [LangChain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [LlamaIndex](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index) do, to interact with TiDB Vector. This approach removes the need to manage the underlying ORM yourself, simplifying your interaction with the vector store.
Expand Down
67 changes: 66 additions & 1 deletion tests/peewee/test_peewee.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from peewee import MySQLDatabase, Model, OperationalError
from tidb_vector.peewee import VectorField
import tidb_vector
from tidb_vector.peewee import VectorField, VectorAdaptor
from ..config import TestConfig


Expand Down Expand Up @@ -273,3 +274,67 @@ def test_negative_inner_product(self):
assert items.count() == 1
assert items.get().id == item.id
assert items[0].distance == -14


class TestPeeweeAdaptor:
    """Integration tests for ``VectorAdaptor`` used with Peewee models.

    Every test runs against freshly created tables for ``Item1Model`` and
    ``Item2Model``; the tables are dropped again once the test finishes.
    """

    def setup_method(self):
        """Drop and recreate both item tables before each test."""
        tables = [Item1Model, Item2Model]
        db.drop_tables(tables)
        db.create_tables(tables)

    def teardown_method(self):
        """Drop both item tables after each test."""
        db.drop_tables([Item1Model, Item2Model])

    def test_create_index_on_dyn_vector(self):
        """Indexing ``Item1Model.embedding`` raises ``ValueError`` and adds no index.

        NOTE(review): the test name suggests ``Item1Model.embedding`` is a
        vector column without a fixed dimension -- confirm against the model
        definition, which is outside this view.
        """
        vec_adaptor = VectorAdaptor(db)
        l2_metric = tidb_vector.DistanceMetric.L2

        with pytest.raises(ValueError):
            vec_adaptor.create_vector_index(
                Item1Model.embedding, distance_metric=l2_metric
            )

        # The failed call must not have left an index behind.
        assert vec_adaptor.has_vector_index(Item1Model.embedding) is False

    def test_create_index_on_fixed_vector(self):
        """Index creation on ``Item2Model.embedding`` and the ``skip_existing`` flag."""
        vec_adaptor = VectorAdaptor(db)
        l2_metric = tidb_vector.DistanceMetric.L2
        cosine_metric = tidb_vector.DistanceMetric.COSINE

        # First creation succeeds and the index becomes visible.
        vec_adaptor.create_vector_index(
            Item2Model.embedding, distance_metric=l2_metric
        )
        assert vec_adaptor.has_vector_index(Item2Model.embedding) is True

        # Creating it a second time without skip_existing is expected to fail.
        with pytest.raises(Exception):
            vec_adaptor.create_vector_index(
                Item2Model.embedding, distance_metric=l2_metric
            )

        # The failed attempt leaves the existing index in place.
        assert vec_adaptor.has_vector_index(Item2Model.embedding) is True

        # With skip_existing=True neither call is expected to raise, even when
        # a different distance metric is requested.
        vec_adaptor.create_vector_index(
            Item2Model.embedding,
            distance_metric=l2_metric,
            skip_existing=True,
        )
        vec_adaptor.create_vector_index(
            Item2Model.embedding,
            distance_metric=cosine_metric,
            skip_existing=True,
        )

    def test_index_and_search(self):
        """Rows inserted after index creation can be ranked by cosine distance."""
        vec_adaptor = VectorAdaptor(db)
        vec_adaptor.create_vector_index(
            Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
        )
        assert vec_adaptor.has_vector_index(Item2Model.embedding) is True

        rows = [
            {"embedding": [1, 2, 3]},
            {"embedding": [1, 2, 3.2]},
        ]
        Item2Model.insert_many(rows).execute()

        cosine_to_query = Item2Model.embedding.cosine_distance([1, 2, 3])
        nearest = (
            Item2Model.select(cosine_to_query.alias("distance"))
            .order_by(cosine_to_query)
            .limit(5)
        )
        assert nearest.count() == 2
        # The closest row is the query vector itself.
        assert nearest[0].distance == 0.0
assert items[0].distance == 0.0
4 changes: 2 additions & 2 deletions tidb_vector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .constants import MAX_DIM, MIN_DIM, DistanceMetric
from .constants import MAX_DIM, MIN_DIM, DistanceMetric, VectorDataType

__version__ = "0.0.13"
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric"]
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric", "VectorDataType"]
22 changes: 22 additions & 0 deletions tidb_vector/constants.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
import enum
import typing

import numpy

# TiDB Vector has a limitation on the dimension length
MAX_DIM = 16000
MIN_DIM = 1


# Type alias for a vector value: either a numpy array or a list of floats.
VectorDataType = typing.Union[numpy.ndarray, typing.List[float]]


class DistanceMetric(enum.Enum):
"""
An enumeration representing different types of distance metrics.

- `DistanceMetric.L2`: L2 (Euclidean) distance metric.
- `DistanceMetric.COSINE`: Cosine distance metric.
"""

L2 = "L2"
COSINE = "COSINE"

def to_sql_func(self):
"""
Converts the DistanceMetric to its corresponding SQL function name.

Returns:
str: The SQL function name.

Raises:
ValueError: If the DistanceMetric enum member is not supported.
"""
if self == DistanceMetric.L2:
return "VEC_L2_DISTANCE"
elif self == DistanceMetric.COSINE:
Expand Down
34 changes: 3 additions & 31 deletions tidb_vector/peewee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,4 @@
from peewee import Field, fn
from .vector_type import VectorField
from .adaptor import VectorAdaptor

from tidb_vector.utils import decode_vector, encode_vector


class VectorField(Field):
    """Peewee field for a ``VECTOR`` column.

    Values are passed through ``encode_vector`` on the way to the database
    and ``decode_vector`` on the way back. The ``*_distance`` helpers build
    the matching ``VEC_*`` SQL function expressions for use in queries.
    """

    field_type = "VECTOR"

    def __init__(self, dimensions=None, *args, **kwargs):
        """Store the optional dimension count, then defer to ``Field``.

        Args:
            dimensions: Number of dimensions of the vector column, or
                ``None`` to declare the column without a dimension modifier.
            *args: Forwarded unchanged to ``Field.__init__``.
            **kwargs: Forwarded unchanged to ``Field.__init__``.
        """
        self.dimensions = dimensions
        super().__init__(*args, **kwargs)

    def get_modifiers(self):
        """Return ``[dimensions]`` when a dimension is set, otherwise ``None``."""
        if not self.dimensions:
            return None
        return [self.dimensions]

    def db_value(self, value):
        """Encode a Python-side vector with ``encode_vector`` for storage."""
        return encode_vector(value)

    def python_value(self, value):
        """Decode a database-side value with ``decode_vector``."""
        return decode_vector(value)

    def l1_distance(self, vector):
        """Build a ``VEC_L1_DISTANCE(column, vector)`` expression."""
        operand = self.to_value(vector)
        return fn.VEC_L1_DISTANCE(self, operand)

    def l2_distance(self, vector):
        """Build a ``VEC_L2_DISTANCE(column, vector)`` expression."""
        operand = self.to_value(vector)
        return fn.VEC_L2_DISTANCE(self, operand)

    def cosine_distance(self, vector):
        """Build a ``VEC_COSINE_DISTANCE(column, vector)`` expression."""
        operand = self.to_value(vector)
        return fn.VEC_COSINE_DISTANCE(self, operand)

    def negative_inner_product(self, vector):
        """Build a ``VEC_NEGATIVE_INNER_PRODUCT(column, vector)`` expression."""
        operand = self.to_value(vector)
        return fn.VEC_NEGATIVE_INNER_PRODUCT(self, operand)
__all__ = ["VectorField", "VectorAdaptor"]
Loading
Loading