Skip to content

Commit a717213

Browse files
committed
Moved utils to separate module
1 parent 7acad90 commit a717213

File tree

3 files changed

+26
-26
lines changed

3 files changed

+26
-26
lines changed

lib/pgvector.ex

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -129,26 +129,6 @@ defmodule Pgvector do
129129
Pgvector.Extensions.Sparsevec
130130
]
131131
end
132-
133-
# TODO move / improve pattern
134-
@doc false
135-
def to_sql(vector) when is_struct(vector, Pgvector.HalfVector) do
136-
vector
137-
end
138-
139-
def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do
140-
vector
141-
end
142-
143-
def to_sql(vector) do
144-
vector |> Pgvector.new()
145-
end
146-
147-
# TODO move / improve pattern
148-
@doc false
149-
def to_bit_sql(vector) when is_bitstring(vector) do
150-
vector
151-
end
152132
end
153133

154134
defimpl Inspect, for: Pgvector do

lib/pgvector/ecto/query.ex

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ if Code.ensure_loaded?(Ecto) do
99
"""
1010
defmacro l2_distance(column, value) do
1111
quote do
12-
fragment("(? <-> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
12+
fragment("(? <-> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
1313
end
1414
end
1515

@@ -18,7 +18,7 @@ if Code.ensure_loaded?(Ecto) do
1818
"""
1919
defmacro max_inner_product(column, value) do
2020
quote do
21-
fragment("(? <#> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
21+
fragment("(? <#> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
2222
end
2323
end
2424

@@ -27,7 +27,7 @@ if Code.ensure_loaded?(Ecto) do
2727
"""
2828
defmacro cosine_distance(column, value) do
2929
quote do
30-
fragment("(? <=> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
30+
fragment("(? <=> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
3131
end
3232
end
3333

@@ -36,7 +36,7 @@ if Code.ensure_loaded?(Ecto) do
3636
"""
3737
defmacro l1_distance(column, value) do
3838
quote do
39-
fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
39+
fragment("(? <+> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
4040
end
4141
end
4242

@@ -45,7 +45,7 @@ if Code.ensure_loaded?(Ecto) do
4545
"""
4646
defmacro hamming_distance(column, value) do
4747
quote do
48-
fragment("(? <~> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
48+
fragment("(? <~> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_bit_sql(unquote(value)))
4949
end
5050
end
5151

@@ -54,7 +54,7 @@ if Code.ensure_loaded?(Ecto) do
5454
"""
5555
defmacro jaccard_distance(column, value) do
5656
quote do
57-
fragment("(? <%> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
57+
fragment("(? <%> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_bit_sql(unquote(value)))
5858
end
5959
end
6060
end

lib/pgvector/ecto/utils.ex

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# TODO improve pattern
2+
defmodule Pgvector.Ecto.Utils do
3+
@moduledoc false
4+
5+
def to_sql(vector) when is_struct(vector, Pgvector.HalfVector) do
6+
vector
7+
end
8+
9+
def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do
10+
vector
11+
end
12+
13+
def to_sql(vector) do
14+
vector |> Pgvector.new()
15+
end
16+
17+
def to_bit_sql(vector) when is_bitstring(vector) do
18+
vector
19+
end
20+
end

0 commit comments

Comments
 (0)