From 497d44a2d88d57e545dbd956262deaa6c6faf434 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Oct 2024 16:17:42 -0700 Subject: [PATCH 1/9] feat(datastore): add vector search samples --- datastore/cloud-client/vector_search.py | 119 +++++++++++++++++++ datastore/cloud-client/vector_search_test.py | 105 ++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 datastore/cloud-client/vector_search.py create mode 100644 datastore/cloud-client/vector_search_test.py diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py new file mode 100644 index 0000000000..398e15a729 --- /dev/null +++ b/datastore/cloud-client/vector_search.py @@ -0,0 +1,119 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +def store_vectors(): + # [START datastore_store_vectors] + from google.cloud import datastore + from google.cloud.datastore.vector import Vector + + client = datastore.Client() + + key = client.key("coffee-beans") + entity = datastore.Entity(key=key) + entity.update( + { + "name": "Kahawa coffee beans", + "description": "Information about the Kahawa coffee beans.", + "embedding_field": Vector([1.0, 2.0, 3.0]), + } + ) + + client.put(entity) + # [END datastore_store_vectors] + return client + + +def vector_search_basic(db): + # [START datastore_vector_search_basic] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector import FindNearest + + vector_query = db.query( + kind="coffee-beans", + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + ) + # [END datastore_vector_search_basic] + return vector_query + + +def vector_search_prefilter(db): + # [START datastore_vector_search_prefilter] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector import FindNearest + from google.cloud.datastore.query import PropertyFilter + + vector_query = db.query( + kind="coffee-beans", + filters=PropertyFilter("color", "=", "red"), + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + ) + # [END datastore_vector_search_prefilter] + return vector_query + + +def vector_search_distance_result_field(db): + # [START datastore_vector_search_distance_result_field] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector 
import FindNearest + + vector_query = db.query( + kind="coffee-beans", + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + distance_result_property="vector_distance", + ) + ) + + for entity in vector_query.fetch(): + print(f"{entity.id}, Distance: {entity['distance']}") + # [END datastore_vector_search_distance_result_field] + return vector_query + + +def vector_search_distance_threshold(db): + # [START datastore_vector_search_distance_threshold] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector import FindNearest + + vector_query = db.query( + kind="coffee-beans", + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=10, + distance_threshold=4.5 + ) + ) + + for entity in vector_query.fetch(): + print(f"{entity.id}") + # [END datastore_vector_search_distance_threshold] + return vector_query \ No newline at end of file diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py new file mode 100644 index 0000000000..349fd661ae --- /dev/null +++ b/datastore/cloud-client/vector_search_test.py @@ -0,0 +1,105 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from google.cloud import datastore +from google.cloud.datastore.vector import Vector + +from vector_search import store_vectors +from vector_search import vector_search_basic +from vector_search import vector_search_distance_result_field +from vector_search import vector_search_distance_threshold +from vector_search import vector_search_prefilter + + +os.environ["GOOGLE_CLOUD_PROJECT"] = os.environ["FIRESTORE_PROJECT"] +PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"] + + +def test_store_vectors(): + client = store_vectors() + + results = client.query("coffee-beans", limit=5).fetch() + + assert len(list(results)) == 1 + + +def add_coffee_beans_data(db): + entity1 = datastore.Entity(db.key("coffee-beans", "Arabica")) + entity1.update({"embedding_field": Vector([10.0, 1.0, 2.0]), "color": "red"}) + entity2 = datastore.Entity(db.key("coffee-beans", "Robusta")) + entity2.update({"embedding_field": Vector([4.0, 1.0, 2.0]), "color": ""}) + entity3 = datastore.Entity(db.key("coffee-beans", "Excelsa")) + entity3.update({"embedding_field": Vector([11.0, 1.0, 2.0]), "color": "red"}) + entity4 = datastore.Entity(db.key("coffee-beans", "Liberica")) + entity4.update({"embedding_field": Vector([3.0, 1.0, 2.0]), "color": "green"}) + + db.put_multi([entity1, entity2, entity3, entity4]) + + +def test_vector_search_basic(): + db = datastore.Client() + add_coffee_beans_data(db) + + vector_query = vector_search_basic(db) + results = list(vector_query.fetch()) + + assert len(results) == 4 + assert results[0].name == "Liberica" + assert results[1].name == "Robusta" + assert results[2].name == "Arabica" + assert results[3].name == "Excelsa" + + +def test_vector_search_prefilter(): + db = datastore.Client() + add_coffee_beans_data(db) + + vector_query = vector_search_prefilter(db) + results = list(vector_query.fetch()) + + assert len(results) == 2 + assert results[0].name == "Arabica" + assert results[1].name == "Excelsa" + + +def test_vector_search_distance_result_field(): 
+ db = datastore.Client() + add_coffee_beans_data(db) + + vector_query = vector_search_distance_result_field(db) + results = list(vector_query.fetch()) + + assert len(results) == 4 + assert results[0].name == "Liberica" + assert results[0]["vector_distance"] == 0.0 + assert results[1].name == "Robusta" + assert results[1]["vector_distance"] == 1.0 + assert results[2].name == "Arabica" + assert results[2]["vector_distance"] == 7.0 + assert results[3].name == "Excelsa" + assert results[3]["vector_distance"] == 8.0 + + +def test_vector_search_distance_threshold(): + db = datastore.Client() + add_coffee_beans_data(db) + + vector_query = vector_search_distance_threshold(db) + results = list(vector_query.fetch()) + + assert len(results) == 2 + assert results[0].name == "Liberica" + assert results[1].name == "Robusta" \ No newline at end of file From c4d5b4d93fa9aca7defebac7fe4d5e23c2366cc6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Nov 2024 16:12:45 -0800 Subject: [PATCH 2/9] added index --- datastore/cloud-client/index.yaml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/datastore/cloud-client/index.yaml b/datastore/cloud-client/index.yaml index 47d57d9841..e132e83d09 100644 --- a/datastore/cloud-client/index.yaml +++ b/datastore/cloud-client/index.yaml @@ -57,3 +57,18 @@ indexes: direction: desc - name: experience direction: desc +- kind: coffee-beans + properties: + - name: __key__ + - name: embedding_field + vectorConfig: + dimension: 3 + flat: {} +- kind: coffee-beans + properties: + - name: color + - name: __key__ + - name: embedding_field + vectorConfig: + dimension: 3 + flat: {} \ No newline at end of file From 9d94b67cfc14ba3b0fd308d32dd330f8e06de08a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Nov 2024 16:13:30 -0800 Subject: [PATCH 3/9] fixed sample issues --- datastore/cloud-client/vector_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datastore/cloud-client/vector_search.py 
b/datastore/cloud-client/vector_search.py index 398e15a729..5256c77cf5 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -61,7 +61,7 @@ def vector_search_prefilter(db): vector_query = db.query( kind="coffee-beans", - filters=PropertyFilter("color", "=", "red"), + filters=[PropertyFilter("color", "=", "red")], find_nearest=FindNearest( vector_property="embedding_field", query_vector=Vector([3.0, 1.0, 2.0]), @@ -91,7 +91,7 @@ def vector_search_distance_result_field(db): ) for entity in vector_query.fetch(): - print(f"{entity.id}, Distance: {entity['distance']}") + print(f"{entity.id}, Distance: {entity['vector_distance']}") # [END datastore_vector_search_distance_result_field] return vector_query From 1833f1e8ad78a5835b9be4269da4bc4a83c469be Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Nov 2024 16:15:34 -0800 Subject: [PATCH 4/9] fixed sample tests --- datastore/cloud-client/vector_search.py | 2 +- datastore/cloud-client/vector_search_test.py | 72 +++++++++++--------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index 5256c77cf5..3947077fd0 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -30,7 +30,7 @@ def store_vectors(): client.put(entity) # [END datastore_store_vectors] - return client + return client, entity def vector_search_basic(db): diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py index 349fd661ae..3aa6174ea0 100644 --- a/datastore/cloud-client/vector_search_test.py +++ b/datastore/cloud-client/vector_search_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +import pytest from google.cloud import datastore from google.cloud.datastore.vector import Vector @@ -28,12 +29,21 @@ PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"] -def test_store_vectors(): - client = store_vectors() - - results = client.query("coffee-beans", limit=5).fetch() +@pytest.fixture(scope="module") +def db(): + client = datastore.Client() + _clear_db(client) + entity_list = add_coffee_beans_data(client) + yield client + for e in entity_list: + client.delete(e) - assert len(list(results)) == 1 +def _clear_db(db): + """remove all entities with kind-coffee-beans, so we have a new database""" + query = db.query(kind="coffee-beans") + query.keys_only() + keys = list(query.fetch()) + db.delete_multi(keys) def add_coffee_beans_data(db): @@ -46,60 +56,54 @@ def add_coffee_beans_data(db): entity4 = datastore.Entity(db.key("coffee-beans", "Liberica")) entity4.update({"embedding_field": Vector([3.0, 1.0, 2.0]), "color": "green"}) - db.put_multi([entity1, entity2, entity3, entity4]) + entity_list = [entity1, entity2, entity3, entity4] + db.put_multi(entity_list) + return entity_list +def test_store_vectors(): + # run and ensure there are no exceptions + client, entity = store_vectors() + client.delete(entity) -def test_vector_search_basic(): - db = datastore.Client() - add_coffee_beans_data(db) - +def test_vector_search_basic(db): vector_query = vector_search_basic(db) results = list(vector_query.fetch()) assert len(results) == 4 - assert results[0].name == "Liberica" - assert results[1].name == "Robusta" - assert results[2].name == "Arabica" - assert results[3].name == "Excelsa" + assert results[0].key.name == "Liberica" + assert results[1].key.name == "Robusta" + assert results[2].key.name == "Arabica" + assert results[3].key.name == "Excelsa" -def test_vector_search_prefilter(): - db = datastore.Client() - add_coffee_beans_data(db) - +def test_vector_search_prefilter(db): vector_query = vector_search_prefilter(db) results = list(vector_query.fetch())
assert len(results) == 2 - assert results[0].name == "Arabica" - assert results[1].name == "Excelsa" - + assert results[0].key.name == "Arabica" + assert results[1].key.name == "Excelsa" -def test_vector_search_distance_result_field(): - db = datastore.Client() - add_coffee_beans_data(db) +def test_vector_search_distance_result_field(db): vector_query = vector_search_distance_result_field(db) results = list(vector_query.fetch()) assert len(results) == 4 - assert results[0].name == "Liberica" + assert results[0].key.name == "Liberica" assert results[0]["vector_distance"] == 0.0 - assert results[1].name == "Robusta" + assert results[1].key.name == "Robusta" assert results[1]["vector_distance"] == 1.0 - assert results[2].name == "Arabica" + assert results[2].key.name == "Arabica" assert results[2]["vector_distance"] == 7.0 - assert results[3].name == "Excelsa" + assert results[3].key.name == "Excelsa" assert results[3]["vector_distance"] == 8.0 -def test_vector_search_distance_threshold(): - db = datastore.Client() - add_coffee_beans_data(db) - +def test_vector_search_distance_threshold(db): vector_query = vector_search_distance_threshold(db) results = list(vector_query.fetch()) assert len(results) == 2 - assert results[0].name == "Liberica" - assert results[1].name == "Robusta" \ No newline at end of file + assert results[0].key.name == "Liberica" + assert results[1].key.name == "Robusta" \ No newline at end of file From 5bcb03089577cbb610c56129e4b88cbb7ef1fcf0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Nov 2024 16:52:13 -0800 Subject: [PATCH 5/9] added sample for getting around 2mb limit --- datastore/cloud-client/vector_search.py | 33 +++++++++++++++++++- datastore/cloud-client/vector_search_test.py | 27 +++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index 3947077fd0..2c18943bbd 100644 --- a/datastore/cloud-client/vector_search.py +++ 
b/datastore/cloud-client/vector_search.py @@ -116,4 +116,35 @@ def vector_search_distance_threshold(db): for entity in vector_query.fetch(): print(f"{entity.id}") # [END datastore_vector_search_distance_threshold] - return vector_query \ No newline at end of file + return vector_query + + +def vector_search_large_query(db): + # [START datastore_vector_search_large_query] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector import FindNearest + + # first, perform a vector search query retrieving just the keys + vector_query = db.query( + kind="coffee-beans", + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=100, + distance_result_property="vector_distance", + ) + ) + vector_query.keys_only() + vector_results = list(vector_query.fetch()) + key_list = [entity.key for entity in vector_results] + # next, perform a second query for the remaining data + full_results = db.get_multi(key_list) + # combine and print results + vector_map = {entity.key: entity for entity in vector_results} + full_map = {entity.key: entity for entity in full_results} + for key in key_list: + print(f"distance: {vector_map[key]['vector_distance']} entity: {full_map[key]}") + # [END datastore_vector_search_large_query] + return key_list, vector_results, full_results \ No newline at end of file diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py index 3aa6174ea0..f36833552b 100644 --- a/datastore/cloud-client/vector_search_test.py +++ b/datastore/cloud-client/vector_search_test.py @@ -106,4 +106,29 @@ def test_vector_search_distance_threshold(db): assert len(results) == 2 assert results[0].key.name == "Liberica" - assert results[1].key.name == "Robusta" \ No newline at end of file + assert results[1].key.name == "Robusta" + +def
test_vector_search_large_query(db): + key_list, vector_results, full_results = vector_search_large_query(db) + assert len(key_list) == 4 + # each list should have same number of elements + assert len(key_list) == len(vector_results) + assert len(key_list) == len(full_results) + # should all have the same keys + vector_map = {entity.key: entity for entity in vector_results} + full_map = {entity.key: entity for entity in full_results} + for key in key_list: + assert key in vector_map.keys() + assert key in full_map.keys() + # vector_results should just contain key and distance + for entity in vector_results: + assert entity.key is not None + assert entity["vector_distance"] is not None + with pytest.raises(KeyError): + entity["embedding_field"] + # full_results should have other fields, but no vector_distance + for entity in full_results: + assert entity.key is not None + assert isinstance(entity["embedding_field"], Vector) + with pytest.raises(KeyError): + entity["vector_distance"] \ No newline at end of file From fb1957868da97416362447d0ecc142e4c1f26ef9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 25 Nov 2024 16:06:51 -0800 Subject: [PATCH 6/9] rename large_query to large_response --- datastore/cloud-client/vector_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index 2c18943bbd..531b137861 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -120,7 +120,7 @@ def vector_search_distance_threshold(db): def vector_search_large_query(db): - # [START datastore_vector_search_large_query] + # [START datastore_vector_search_large_reqponse] from google.cloud.datastore.vector import DistanceMeasure from google.cloud.datastore.vector import Vector from google.cloud.datastore.vector import FindNearest @@ -146,5 +146,5 @@ def vector_search_large_query(db): full_map = {entity.key: entity for entity in 
full_results} for key in key_list: print(f"distance: {vector_map[key]['vector_distance']} entity: {full_map[key]}") - # [END datastore_vector_search_large_query] + # [END datastore_vector_search_large_response] return key_list, vector_results, full_results \ No newline at end of file From 500613e176fa30c0e47cddf81eebe6565185ff9f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 25 Nov 2024 16:13:40 -0800 Subject: [PATCH 7/9] fixed sample names --- datastore/cloud-client/vector_search.py | 10 +++++----- datastore/cloud-client/vector_search_test.py | 11 ++++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index 531b137861..0b75958d9e 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -73,8 +73,8 @@ def vector_search_prefilter(db): return vector_query -def vector_search_distance_result_field(db): - # [START datastore_vector_search_distance_result_field] +def vector_search_distance_result_property(db): + # [START datastore_vector_search_distance_result_property] from google.cloud.datastore.vector import DistanceMeasure from google.cloud.datastore.vector import Vector from google.cloud.datastore.vector import FindNearest @@ -92,7 +92,7 @@ def vector_search_distance_result_field(db): for entity in vector_query.fetch(): print(f"{entity.id}, Distance: {entity['vector_distance']}") - # [END datastore_vector_search_distance_result_field] + # [END datastore_vector_search_distance_result_property] return vector_query @@ -119,8 +119,8 @@ def vector_search_distance_threshold(db): return vector_query -def vector_search_large_query(db): - # [START datastore_vector_search_large_reqponse] +def vector_search_large_response(db): + # [START datastore_vector_search_large_response] from google.cloud.datastore.vector import DistanceMeasure from google.cloud.datastore.vector import Vector from google.cloud.datastore.vector import 
FindNearest diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py index f36833552b..5f1a9ddf52 100644 --- a/datastore/cloud-client/vector_search_test.py +++ b/datastore/cloud-client/vector_search_test.py @@ -20,9 +20,10 @@ from vector_search import store_vectors from vector_search import vector_search_basic -from vector_search import vector_search_distance_result_field +from vector_search import vector_search_distance_result_property from vector_search import vector_search_distance_threshold from vector_search import vector_search_prefilter +from vector_search import vector_search_large_response os.environ["GOOGLE_CLOUD_PROJECT"] = os.environ["FIRESTORE_PROJECT"] @@ -85,8 +86,8 @@ def test_vector_search_prefilter(db): assert results[1].key.name == "Excelsa" -def test_vector_search_distance_result_field(db): - vector_query = vector_search_distance_result_field(db) +def test_vector_search_distance_result_property(db): + vector_query = vector_search_distance_result_property(db) results = list(vector_query.fetch()) assert len(results) == 4 @@ -108,8 +109,8 @@ def test_vector_search_distance_threshold(db): assert results[0].key.name == "Liberica" assert results[1].key.name == "Robusta" -def test_vector_search_large_query(db): - key_list, vector_results, full_results = vector_search_large_query(db) +def test_vector_search_large_response(db): + key_list, vector_results, full_results = vector_search_large_response(db) assert len(key_list) == 4 # each list should have same number of elements assert len(key_list) == len(vector_results) From c9319c9be6b64af357977c4e5d204c5970b84130 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 25 Nov 2024 16:42:19 -0800 Subject: [PATCH 8/9] added projection sample --- datastore/cloud-client/vector_search.py | 23 ++++++++++++++++++++ datastore/cloud-client/vector_search_test.py | 21 ++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git 
a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index 0b75958d9e..d63602284c 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -95,6 +95,29 @@ def vector_search_distance_result_property(db): # [END datastore_vector_search_distance_result_property] return vector_query +def vector_search_distance_result_property_projection(db): + # [START datastore_vector_search_distance_result_property_projection] + from google.cloud.datastore.vector import DistanceMeasure + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.vector import FindNearest + + vector_query = db.query( + kind="coffee-beans", + find_nearest=FindNearest( + vector_property="embedding_field", + query_vector=Vector([3.0, 1.0, 2.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + distance_result_property="vector_distance", + ) + ) + vector_query.projection = ["color"] + + for entity in vector_query.fetch(): + print(f"{entity.id}, Distance: {entity['vector_distance']}") + # [END datastore_vector_search_distance_result_property_projection] + return vector_query + def vector_search_distance_threshold(db): # [START datastore_vector_search_distance_threshold] diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py index 5f1a9ddf52..bb9242f514 100644 --- a/datastore/cloud-client/vector_search_test.py +++ b/datastore/cloud-client/vector_search_test.py @@ -93,12 +93,33 @@ def test_vector_search_distance_result_property(db): assert len(results) == 4 assert results[0].key.name == "Liberica" assert results[0]["vector_distance"] == 0.0 + assert results[0]["embedding_field"] == Vector([3.0, 1.0, 2.0]) assert results[1].key.name == "Robusta" assert results[1]["vector_distance"] == 1.0 + assert results[1]["embedding_field"] == Vector([4.0, 1.0, 2.0]) assert results[2].key.name == "Arabica" assert results[2]["vector_distance"] == 7.0 + assert 
results[2]["embedding_field"] == Vector([10.0, 1.0, 2.0]) assert results[3].key.name == "Excelsa" assert results[3]["vector_distance"] == 8.0 + assert results[3]["embedding_field"] == Vector([11.0, 1.0, 2.0]) + + +def test_vector_search_distance_result_property_projection(db): + vector_query = vector_search_distance_result_property_projection(db) + results = list(vector_query.fetch()) + + assert len(results) == 4 + assert results[0].key.name == "Liberica" + assert results[0]["vector_distance"] == 0.0 + assert results[1].key.name == "Robusta" + assert results[1]["vector_distance"] == 1.0 + assert results[2].key.name == "Arabica" + assert results[2]["vector_distance"] == 7.0 + assert results[3].key.name == "Excelsa" + assert results[3]["vector_distance"] == 8.0 + + assert all("embedding_field" not in d for d in results) def test_vector_search_distance_threshold(db): From a8e80cc7c8e0efe61dc291d7be7ac6843f118d28 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 2 Dec 2024 13:36:36 -0800 Subject: [PATCH 9/9] use more realistic vector values --- datastore/cloud-client/vector_search.py | 16 +++++------ datastore/cloud-client/vector_search_test.py | 29 ++++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/datastore/cloud-client/vector_search.py b/datastore/cloud-client/vector_search.py index d63602284c..00ecaabd77 100644 --- a/datastore/cloud-client/vector_search.py +++ b/datastore/cloud-client/vector_search.py @@ -24,7 +24,7 @@ def store_vectors(): { "name": "Kahawa coffee beans", "description": "Information about the Kahawa coffee beans.", - "embedding_field": Vector([1.0, 2.0, 3.0]), + "embedding_field": Vector([0.18332680, 0.24160706, 0.3416704]), } ) @@ -43,7 +43,7 @@ def vector_search_basic(db): kind="coffee-beans", find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=5, ) 
@@ -64,7 +64,7 @@ def vector_search_prefilter(db): filters=[PropertyFilter("color", "=", "red")], find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=5, ) @@ -83,7 +83,7 @@ def vector_search_distance_result_property(db): kind="coffee-beans", find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=5, distance_result_property="vector_distance", @@ -105,7 +105,7 @@ def vector_search_distance_result_property_projection(db): kind="coffee-beans", find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=5, distance_result_property="vector_distance", @@ -129,10 +129,10 @@ def vector_search_distance_threshold(db): kind="coffee-beans", find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=10, - distance_threshold=4.5 + distance_threshold=0.4 ) ) @@ -153,7 +153,7 @@ def vector_search_large_response(db): kind="coffee-beans", find_nearest=FindNearest( vector_property="embedding_field", - query_vector=Vector([3.0, 1.0, 2.0]), + query_vector=Vector([0.3416704, 0.18332680, 0.24160706]), distance_measure=DistanceMeasure.EUCLIDEAN, limit=100, distance_result_property="vector_distance", diff --git a/datastore/cloud-client/vector_search_test.py b/datastore/cloud-client/vector_search_test.py index bb9242f514..41295ffee8 100644 --- a/datastore/cloud-client/vector_search_test.py +++ b/datastore/cloud-client/vector_search_test.py @@ -21,6 +21,7 @@ from vector_search import 
store_vectors from vector_search import vector_search_basic from vector_search import vector_search_distance_result_property +from vector_search import vector_search_distance_result_property_projection from vector_search import vector_search_distance_threshold from vector_search import vector_search_prefilter from vector_search import vector_search_large_response @@ -49,13 +50,13 @@ def _clear_db(db): def add_coffee_beans_data(db): entity1 = datastore.Entity(db.key("coffee-beans", "Arabica")) - entity1.update({"embedding_field": Vector([10.0, 1.0, 2.0]), "color": "red"}) + entity1.update({"embedding_field": Vector([0.80522226, 0.18332680, 0.24160706]), "color": "red"}) entity2 = datastore.Entity(db.key("coffee-beans", "Robusta")) - entity2.update({"embedding_field": Vector([4.0, 1.0, 2.0]), "color": ""}) + entity2.update({"embedding_field": Vector([0.43979567, 0.18332680, 0.24160706]), "color": ""}) entity3 = datastore.Entity(db.key("coffee-beans", "Excelsa")) - entity3.update({"embedding_field": Vector([11.0, 1.0, 2.0]), "color": "red"}) + entity3.update({"embedding_field": Vector([0.90477061, 0.18332680, 0.24160706]), "color": "red"}) entity4 = datastore.Entity(db.key("coffee-beans", "Liberica")) - entity4.update({"embedding_field": Vector([3.0, 1.0, 2.0]), "color": "green"}) + entity4.update({"embedding_field": Vector([0.3416704, 0.18332680, 0.24160706]), "color": "green"}) entity_list = [entity1, entity2, entity3, entity4] db.put_multi(entity_list) @@ -93,16 +94,16 @@ def test_vector_search_distance_result_property(db): assert len(results) == 4 assert results[0].key.name == "Liberica" assert results[0]["vector_distance"] == 0.0 - assert results[0]["embedding_field"] == Vector([3.0, 1.0, 2.0]) + assert results[0]["embedding_field"] == Vector([0.3416704, 0.18332680, 0.24160706]) assert results[1].key.name == "Robusta" - assert results[1]["vector_distance"] == 1.0 - assert results[1]["embedding_field"] == Vector([4.0, 1.0, 2.0]) + assert 
results[1]["vector_distance"] == pytest.approx(0.09812527) + assert results[1]["embedding_field"] == Vector([0.43979567, 0.18332680, 0.24160706]) assert results[2].key.name == "Arabica" - assert results[2]["vector_distance"] == 7.0 - assert results[2]["embedding_field"] == Vector([10.0, 1.0, 2.0]) + assert results[2]["vector_distance"] == pytest.approx(0.46355186) + assert results[2]["embedding_field"] == Vector([0.80522226, 0.18332680, 0.24160706]) assert results[3].key.name == "Excelsa" - assert results[3]["vector_distance"] == 8.0 - assert results[3]["embedding_field"] == Vector([11.0, 1.0, 2.0]) + assert results[3]["vector_distance"] == pytest.approx(0.56310021) + assert results[3]["embedding_field"] == Vector([0.90477061, 0.18332680, 0.24160706]) def test_vector_search_distance_result_property_projection(db): @@ -113,11 +114,11 @@ def test_vector_search_distance_result_property_projection(db): assert results[0].key.name == "Liberica" assert results[0]["vector_distance"] == 0.0 assert results[1].key.name == "Robusta" - assert results[1]["vector_distance"] == 1.0 + assert results[1]["vector_distance"] == pytest.approx(0.09812527) assert results[2].key.name == "Arabica" - assert results[2]["vector_distance"] == 7.0 + assert results[2]["vector_distance"] == pytest.approx(0.46355186) assert results[3].key.name == "Excelsa" - assert results[3]["vector_distance"] == 8.0 + assert results[3]["vector_distance"] == pytest.approx(0.56310021) assert all("embedding_field" not in d for d in results)