Skip to content

Commit 7acad90

Browse files
committed
Added support for bit type to Ecto
1 parent cbb6458 commit 7acad90

File tree

5 files changed

+60
-5
lines changed

5 files changed

+60
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
## 0.2.2 (unreleased)
22

33
- Added support for `halfvec` and `sparsevec` types
4+
- Added support for `bit` type to Ecto
45
- Added `Pgvector.extensions/0` function
5-
- Added `l1_distance` function for Ecto
6+
- Added `l1_distance`, `hamming_distance`, and `jaccard_distance` functions for Ecto
67

78
## 0.2.1 (2023-09-25)
89

lib/pgvector.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ defmodule Pgvector do
143143
# Hands `vector` to `Pgvector.new/1` and returns whatever that call produces.
def to_sql(vector), do: Pgvector.new(vector)
146+
147+
# TODO move / improve pattern
# Pass-through used when binding bit values as query parameters: the value is
# returned unchanged. Only bitstrings match the guard, so any other argument
# raises a FunctionClauseError.
@doc false
def to_bit_sql(vector) when is_bitstring(vector), do: vector
146152
end
147153

148154
defimpl Inspect, for: Pgvector do

lib/pgvector/ecto/bit.ex

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
if Code.ensure_loaded?(Ecto) do
  defmodule Pgvector.Ecto.Bit do
    @moduledoc """
    Ecto type for `bit` columns.

    Values are plain Elixir bitstrings (for example `<<1::1, 0::1, 1::1>>`) and
    are passed through unchanged in every direction. Anything that is not a
    bitstring is rejected with `:error`, so invalid input is reported by Ecto
    (e.g. as a changeset error on cast) instead of failing later when the
    database driver tries to encode the parameter.
    """

    use Ecto.Type

    # Underlying database type reported to Ecto.
    def type, do: :bit

    # Casts external input; only bitstrings are accepted.
    def cast(value) when is_bitstring(value), do: {:ok, value}
    def cast(_value), do: :error

    # Loads a value coming from the database; expected to already be a bitstring.
    def load(data) when is_bitstring(data), do: {:ok, data}
    def load(_data), do: :error

    # Dumps a value for the database; only bitstrings are accepted.
    def dump(value) when is_bitstring(value), do: {:ok, value}
    def dump(_value), do: :error
  end
end

lib/pgvector/ecto/query.ex

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,23 @@ if Code.ensure_loaded?(Ecto) do
3939
fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
4040
end
4141
end
42+
43+
@doc """
Builds a query fragment that yields the Hamming distance between `column`
and `value`.

`value` is run through `Pgvector.to_bit_sql/1` and bound as a query
parameter; the two sides are compared with the `<~>` operator.
"""
defmacro hamming_distance(column, value) do
  quote(
    do: fragment("(? <~> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
  )
end
51+
52+
@doc """
Builds a query fragment that yields the Jaccard distance between `column`
and `value`.

`value` is run through `Pgvector.to_bit_sql/1` and bound as a query
parameter; the two sides are compared with the `<%>` operator.
"""
defmacro jaccard_distance(column, value) do
  quote(
    do: fragment("(? <%> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
  )
end
4260
end
4361
end

test/ecto_test.exs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defmodule Item do
44
schema "ecto_items" do
55
field :embedding, Pgvector.Ecto.Vector
66
field :half_embedding, Pgvector.Ecto.HalfVector
7+
field :binary_embedding, Pgvector.Ecto.Bit
78
field :sparse_embedding, Pgvector.Ecto.SparseVector
89
end
910
end
@@ -17,15 +18,15 @@ defmodule EctoTest do
1718
setup_all do
1819
Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
1920
Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
20-
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))", [])
21+
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))", [])
2122
create_items()
2223
:ok
2324
end
2425

2526
defp create_items do
26-
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
27-
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], sparse_embedding: [2, 2, 3]})
28-
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
27+
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), binary_embedding: <<0::1, 0::1, 0::1>>, sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
28+
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], binary_embedding: <<1::1, 0::1, 1::1>>, sparse_embedding: [2, 2, 3]})
29+
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), binary_embedding: <<1::1, 1::1, 1::1>>, sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
2930
end
3031

3132
test "vector l2 distance" do
@@ -80,6 +81,16 @@ defmodule EctoTest do
8081
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
8182
end
8283

84+
test "bit hamming distance" do
  # Order the rows by Hamming distance to 101 and compare the resulting ids.
  ids =
    from(i in Item,
      order_by: hamming_distance(i.binary_embedding, <<1::1, 0::1, 1::1>>),
      limit: 5
    )
    |> Repo.all()
    |> Enum.map(& &1.id)

  assert ids == [2, 3, 1]
end
88+
89+
test "bit jaccard distance" do
  # Order the rows by Jaccard distance to 101 and compare the resulting ids.
  ids =
    from(i in Item,
      order_by: jaccard_distance(i.binary_embedding, <<1::1, 0::1, 1::1>>),
      limit: 5
    )
    |> Repo.all()
    |> Enum.map(& &1.id)

  assert ids == [2, 3, 1]
end
93+
8394
test "sparsevec l2 distance" do
8495
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
8596
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]

0 commit comments

Comments
 (0)