Skip to content

Commit cd238bc

Browse files
breezewishwd0517
andauthored
Support vector index for SQLAlchemy (#65)
As an alternative to #63 After discussion with @JaySon-Huang, inspired by some other vector databases, I think we could simply introduce some side functions for creating indexes. The usage is as simple as using native SQLAlchemy interfaces, without the need to introduce another dialect like sqlalchemy-tidb. Maintaining a new SQLAlchemy dialect is costy, as we need to support both SQLAlchemy 1.4 and SQLAlchemy 2.0, and support other TiDB features like AUTO_RANDOM. --------- Signed-off-by: Wish <breezewish@outlook.com> Co-authored-by: WD <me@wangdi.ink>
1 parent 5569bb8 commit cd238bc

File tree

10 files changed

+372
-163
lines changed

10 files changed

+372
-163
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@ jobs:
2323
2424
- name: Run lint
2525
run: |
26-
tox -e lint
26+
tox -e lint
2727
2828
tests:
2929
strategy:
3030
fail-fast: false
3131
matrix:
3232
python-version:
33-
- '3.12'
33+
- "3.12"
3434
name: py${{ matrix.python-version }}_test
3535
runs-on: ubuntu-latest
3636
services:
3737
tidb:
38-
image: wangdi4zm/tind:v7.5.3-vector-index
38+
image: wangdi4zm/tind:v8.4.0-vector-index
3939
ports:
4040
- 4000:4000
4141
steps:

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,5 @@ cython_debug/
141141
django_tests_dir
142142

143143
*.swp
144+
145+
.vscode/

README.md

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# tidb-vector-python
22

3-
This is a Python client for TiDB Vector.
4-
5-
> Now only TiDB Cloud Serverless cluster support vector data type, see this [docs](https://docs.pingcap.com/tidbcloud/vector-search-overview?utm_source=github&utm_medium=tidb-vector-python) for more information.
3+
Use TiDB Vector Search with Python.
64

75
## Installation
86

@@ -12,74 +10,79 @@ pip install tidb-vector
1210

1311
## Usage
1412

15-
TiDB vector supports below distance functions:
13+
TiDB is a SQL database so that this package introduces Vector Search capability for Python ORMs:
1614

17-
- `L1Distance`
18-
- `L2Distance`
19-
- `CosineDistance`
20-
- `NegativeInnerProduct`
15+
- [#SQLAlchemy](#sqlalchemy)
16+
- [#Django](#django)
17+
- [#Peewee](#peewee)
2118

22-
It also supports using hnsw index with l2 or cosine distance to speed up the search, for more details see [Vector Search Indexes in TiDB](https://docs.pingcap.com/tidbcloud/vector-search-index)
19+
Pick one that you are familiar with to get started. If you are not using any of them, we recommend [#SQLAlchemy](#sqlalchemy).
2320

24-
Supports following orm or framework:
21+
We also provide a Vector Search client for simple usage:
2522

26-
- [SQLAlchemy](#sqlalchemy)
27-
- [Django](#django)
28-
- [Peewee](#peewee)
29-
- [TiDB Vector Client](#tidb-vector-client)
23+
- [#TiDB Vector Client](#tidb-vector-client)
3024

3125
### SQLAlchemy
3226

33-
Learn how to connect to TiDB Serverless in the [TiDB Cloud documentation](https://docs.pingcap.com/tidbcloud/dev-guide-sample-application-python-sqlalchemy).
34-
35-
Define table with vector field
27+
```bash
28+
pip install tidb-vector sqlalchemy pymysql
29+
```
3630

3731
```python
38-
from sqlalchemy import Column, Integer, create_engine
39-
from sqlalchemy.orm import declarative_base
40-
from tidb_vector.sqlalchemy import VectorType
32+
from sqlalchemy import Integer, Text, Column
33+
from sqlalchemy import create_engine, select
34+
from sqlalchemy.orm import Session, declarative_base
4135

42-
engine = create_engine('mysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test')
36+
import tidb_vector
37+
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor
38+
39+
engine = create_engine("mysql+pymysql://root@127.0.0.1:4000/test")
4340
Base = declarative_base()
4441

45-
class Test(Base):
46-
__tablename__ = 'test'
47-
id = Column(Integer, primary_key=True)
48-
embedding = Column(VectorType(3))
4942

50-
# or add hnsw index when creating table
51-
class TestWithIndex(Base):
52-
__tablename__ = 'test_with_index'
43+
# Define table schema
44+
class Doc(Base):
45+
__tablename__ = "doc"
5346
id = Column(Integer, primary_key=True)
54-
embedding = Column(VectorType(3), comment="hnsw(distance=l2)")
55-
56-
Base.metadata.create_all(engine)
57-
```
58-
59-
Insert vector data
60-
61-
```python
62-
test = Test(embedding=[1, 2, 3])
63-
session.add(test)
64-
session.commit()
65-
```
66-
67-
Get the nearest neighbors
47+
embedding = Column(VectorType(3)) # Vector with 3 dimensions
48+
content = Column(Text)
6849

69-
```python
70-
session.scalars(select(Test).order_by(Test.embedding.l2_distance([1, 2, 3.1])).limit(5))
71-
```
72-
73-
Get the distance
7450

75-
```python
76-
session.scalars(select(Test.embedding.l2_distance([1, 2, 3.1])))
77-
```
51+
# Create empty table
52+
Base.metadata.drop_all(engine) # clean data from last run
53+
Base.metadata.create_all(engine)
7854

79-
Get within a certain distance
55+
# Create index using L2 distance
56+
adaptor = VectorAdaptor(engine)
57+
adaptor.create_vector_index(
58+
Doc.embedding, tidb_vector.DistanceMetric.L2, skip_existing=True
59+
)
8060

81-
```python
82-
session.scalars(select(Test).filter(Test.embedding.l2_distance([1, 2, 3.1]) < 0.2))
61+
# Insert content with vectors
62+
with Session(engine) as session:
63+
session.add(Doc(id=1, content="dog", embedding=[1, 2, 1]))
64+
session.add(Doc(id=2, content="fish", embedding=[1, 2, 4]))
65+
session.add(Doc(id=3, content="tree", embedding=[1, 0, 0]))
66+
session.commit()
67+
68+
# Perform Vector Search for Top K=1
69+
with Session(engine) as session:
70+
results = session.execute(
71+
select(Doc.id, Doc.content)
72+
.order_by(Doc.embedding.cosine_distance([1, 2, 3]))
73+
.limit(1)
74+
).all()
75+
print(results)
76+
77+
# Perform filtered Vector Search by adding a Where Clause:
78+
with Session(engine) as session:
79+
results = session.execute(
80+
select(Doc.id, Doc.content)
81+
.where(Doc.id > 2)
82+
.order_by(Doc.embedding.cosine_distance([1, 2, 3]))
83+
.limit(1)
84+
).all()
85+
print(results)
8386
```
8487

8588
### Django
@@ -165,7 +168,7 @@ TestModel.select().where(TestModel.embedding.l2_distance([1, 2, 3.1]) < 0.5)
165168

166169
### TiDB Vector Client
167170

168-
Within the framework, you can directly utilize the built-in `TiDBVectorClient`, as demonstrated by integrations like [Langchain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [Llama index](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index), to seamlessly interact with TiDB Vector. This approach abstracts away the need to manage the underlying ORM, simplifying your interaction with the vector store.
171+
Within the framework, you can directly utilize the built-in `TiDBVectorClient`, as demonstrated by integrations like [Langchain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [Llama index](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index), to seamlessly interact with TiDB Vector. This approach abstracts away the need to manage the underlying ORM, simplifying your interaction with the vector store.
169172

170173
We provide `TiDBVectorClient` which is based on sqlalchemy, you need to use `pip install tidb-vector[client]` to install it.
171174

@@ -252,4 +255,5 @@ There are some examples to show how to use the tidb-vector-python to interact wi
252255
for more examples, see the [examples](./examples) directory.
253256

254257
## Contributing
255-
Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.
258+
259+
Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.

tests/sqlalchemy/test_sqlalchemy.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from sqlalchemy import URL, create_engine, Column, Integer, select
44
from sqlalchemy.orm import declarative_base, sessionmaker
55
from sqlalchemy.exc import OperationalError
6-
from tidb_vector.sqlalchemy import VectorType
6+
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor
7+
import tidb_vector
78
from ..config import TestConfig
89

910

@@ -14,9 +15,11 @@
1415
host=TestConfig.TIDB_HOST,
1516
port=TestConfig.TIDB_PORT,
1617
database="test",
17-
query={"ssl_verify_cert": True, "ssl_verify_identity": True}
18-
if TestConfig.TIDB_SSL
19-
else {},
18+
query=(
19+
{"ssl_verify_cert": True, "ssl_verify_identity": True}
20+
if TestConfig.TIDB_SSL
21+
else {}
22+
),
2023
)
2124

2225
engine = create_engine(db_url)
@@ -58,6 +61,15 @@ def test_insert_get_record(self):
5861
assert np.array_equal(item1.embedding, np.array([1, 2, 3]))
5962
assert item1.embedding.dtype == np.float32
6063

64+
def test_insert_get_record_np(self):
65+
with Session() as session:
66+
item1 = Item1Model(embedding=np.array([1, 2, 3]))
67+
session.add(item1)
68+
session.commit()
69+
item1 = session.query(Item1Model).first()
70+
assert np.array_equal(item1.embedding, np.array([1, 2, 3]))
71+
assert item1.embedding.dtype == np.float32
72+
6173
def test_empty_vector(self):
6274
with Session() as session:
6375
item1 = Item1Model(embedding=[])
@@ -303,3 +315,73 @@ def test_negative_inner_product(self):
303315
)
304316
assert len(items) == 2
305317
assert items[1].distance == -14.0
318+
319+
320+
class TestSQLAlchemyAdaptor:
321+
def setup_method(self):
322+
Item1Model.__table__.drop(bind=engine, checkfirst=True)
323+
Item1Model.__table__.create(bind=engine)
324+
Item2Model.__table__.drop(bind=engine, checkfirst=True)
325+
Item2Model.__table__.create(bind=engine)
326+
327+
def teardown_method(self):
328+
Item1Model.__table__.drop(bind=engine, checkfirst=True)
329+
Item2Model.__table__.drop(bind=engine, checkfirst=True)
330+
331+
def test_create_index_on_dyn_vector(self):
332+
adaptor = VectorAdaptor(engine)
333+
with pytest.raises(ValueError):
334+
adaptor.create_vector_index(
335+
Item1Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
336+
)
337+
assert adaptor.has_vector_index(Item1Model.embedding) is False
338+
339+
def test_create_index_on_fixed_vector(self):
340+
adaptor = VectorAdaptor(engine)
341+
adaptor.create_vector_index(
342+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
343+
)
344+
assert adaptor.has_vector_index(Item2Model.embedding) is True
345+
346+
with pytest.raises(Exception):
347+
adaptor.create_vector_index(
348+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
349+
)
350+
351+
assert adaptor.has_vector_index(Item2Model.embedding) is True
352+
353+
adaptor.create_vector_index(
354+
Item2Model.embedding,
355+
distance_metric=tidb_vector.DistanceMetric.L2,
356+
skip_existing=True,
357+
)
358+
359+
adaptor.create_vector_index(
360+
Item2Model.embedding,
361+
distance_metric=tidb_vector.DistanceMetric.COSINE,
362+
skip_existing=True,
363+
)
364+
365+
def test_index_and_search(self):
366+
adaptor = VectorAdaptor(engine)
367+
adaptor.create_vector_index(
368+
Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
369+
)
370+
assert adaptor.has_vector_index(Item2Model.embedding) is True
371+
372+
with Session() as session:
373+
session.add_all(
374+
[Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])]
375+
)
376+
session.commit()
377+
378+
# l2 distance
379+
distance = Item2Model.embedding.cosine_distance([1, 2, 3])
380+
items = (
381+
session.query(Item2Model.id, distance.label("distance"))
382+
.order_by(distance)
383+
.limit(5)
384+
.all()
385+
)
386+
assert len(items) == 2
387+
assert items[0].distance == 0.0

tidb_vector/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
__version__ = "0.0.12"
1+
from .constants import MAX_DIM, MIN_DIM, DistanceMetric
2+
3+
__version__ = "0.0.13"
4+
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric"]

tidb_vector/constants.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
import enum
2+
13
# TiDB Vector has a limitation on the dimension length
2-
MAX_DIMENSION_LENGTH = 16000
3-
MIN_DIMENSION_LENGTH = 1
4+
MAX_DIM = 16000
5+
MIN_DIM = 1
6+
7+
8+
class DistanceMetric(enum.Enum):
9+
L2 = "L2"
10+
COSINE = "COSINE"
11+
12+
def to_sql_func(self):
13+
if self == DistanceMetric.L2:
14+
return "VEC_L2_DISTANCE"
15+
elif self == DistanceMetric.COSINE:
16+
return "VEC_COSINE_DISTANCE"
17+
else:
18+
raise ValueError("unsupported distance metric")

0 commit comments

Comments
 (0)