Skip to content

Commit 2628d5e

Browse files
committed
Improved distance functions
1 parent b9e2637 commit 2628d5e

File tree

3 files changed: 20 additions and 35 deletions

lib/pgvector/ecto/query.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ if Code.ensure_loaded?(Ecto) do
99
"""
1010
defmacro l2_distance(column, value) do
1111
quote do
12-
fragment("(? <-> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
12+
fragment("(? <-> ?)", unquote(column), unquote(value))
1313
end
1414
end
1515

@@ -18,7 +18,7 @@ if Code.ensure_loaded?(Ecto) do
1818
"""
1919
defmacro max_inner_product(column, value) do
2020
quote do
21-
fragment("(? <#> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
21+
fragment("(? <#> ?)", unquote(column), unquote(value))
2222
end
2323
end
2424

@@ -27,7 +27,7 @@ if Code.ensure_loaded?(Ecto) do
2727
"""
2828
defmacro cosine_distance(column, value) do
2929
quote do
30-
fragment("(? <=> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
30+
fragment("(? <=> ?)", unquote(column), unquote(value))
3131
end
3232
end
3333

@@ -36,7 +36,7 @@ if Code.ensure_loaded?(Ecto) do
3636
"""
3737
defmacro l1_distance(column, value) do
3838
quote do
39-
fragment("(? <+> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
39+
fragment("(? <+> ?)", unquote(column), unquote(value))
4040
end
4141
end
4242

lib/pgvector/ecto/utils.ex

Lines changed: 0 additions & 16 deletions
This file was deleted.

test/ecto_test.exs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,54 +30,55 @@ defmodule EctoTest do
3030
end
3131

3232
test "vector l2 distance" do
33-
items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5)
33+
# TODO: restore support for passing a plain list (e.g. [1, 1, 1]) as the query vector
34+
items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, ^Pgvector.new([1, 1, 1])), limit: 5)
3435
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
3536
assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
3637
end
3738

3839
test "vector max inner product" do
39-
items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5)
40+
items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, ^Pgvector.new([1, 1, 1])), limit: 5)
4041
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
4142
end
4243

4344
test "vector cosine distance" do
44-
items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5)
45+
items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, ^Pgvector.new([1, 1, 1])), limit: 5)
4546
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
4647
end
4748

4849
test "vector cosine similarity" do
49-
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, [1, 1, 1])), limit: 5)
50+
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, ^Pgvector.new([1, 1, 1]))), limit: 5)
5051
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
5152
end
5253

5354
test "vector l1 distance" do
54-
items = Repo.all(from i in Item, order_by: l1_distance(i.embedding, [1, 1, 1]), limit: 5)
55+
items = Repo.all(from i in Item, order_by: l1_distance(i.embedding, ^Pgvector.new([1, 1, 1])), limit: 5)
5556
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
5657
end
5758

5859
test "halfvec l2 distance" do
59-
items = Repo.all(from i in Item, order_by: l2_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
60+
items = Repo.all(from i in Item, order_by: l2_distance(i.half_embedding, ^Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
6061
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
6162
assert Enum.map(items, fn v -> v.half_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
6263
end
6364

6465
test "halfvec max inner product" do
65-
items = Repo.all(from i in Item, order_by: max_inner_product(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
66+
items = Repo.all(from i in Item, order_by: max_inner_product(i.half_embedding, ^Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
6667
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
6768
end
6869

6970
test "halfvec cosine distance" do
70-
items = Repo.all(from i in Item, order_by: cosine_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
71+
items = Repo.all(from i in Item, order_by: cosine_distance(i.half_embedding, ^Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
7172
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
7273
end
7374

7475
test "halfvec cosine similarity" do
75-
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1]))), limit: 5)
76+
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.half_embedding, ^Pgvector.HalfVector.new([1, 1, 1]))), limit: 5)
7677
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
7778
end
7879

7980
test "halfvec l1 distance" do
80-
items = Repo.all(from i in Item, order_by: l1_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
81+
items = Repo.all(from i in Item, order_by: l1_distance(i.half_embedding, ^Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
8182
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
8283
end
8384

@@ -92,28 +93,28 @@ defmodule EctoTest do
9293
end
9394

9495
test "sparsevec l2 distance" do
95-
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
96+
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, ^Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
9697
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
9798
assert Enum.map(items, fn v -> v.sparse_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
9899
end
99100

100101
test "sparsevec max inner product" do
101-
items = Repo.all(from i in Item, order_by: max_inner_product(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
102+
items = Repo.all(from i in Item, order_by: max_inner_product(i.sparse_embedding, ^Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
102103
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
103104
end
104105

105106
test "sparsevec cosine distance" do
106-
items = Repo.all(from i in Item, order_by: cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
107+
items = Repo.all(from i in Item, order_by: cosine_distance(i.sparse_embedding, ^Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
107108
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
108109
end
109110

110111
test "sparsevec cosine similarity" do
111-
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1]))), limit: 5)
112+
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.sparse_embedding, ^Pgvector.SparseVector.new([1, 1, 1]))), limit: 5)
112113
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
113114
end
114115

115116
test "sparsevec l1 distance" do
116-
items = Repo.all(from i in Item, order_by: l1_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
117+
items = Repo.all(from i in Item, order_by: l1_distance(i.sparse_embedding, ^Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
117118
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
118119
end
119120

0 commit comments

Comments
 (0)