Skip to content

Commit 295810f

Browse files
committed
Added test for halfvec with Postgrex
1 parent 963ba57 commit 295810f

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

test/postgrex_test.exs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Postgrex.Types.define(PostgrexApp.PostgrexTypes, [Pgvector.Extensions.Vector], [])
1+
Postgrex.Types.define(PostgrexApp.PostgrexTypes, Pgvector.extensions(), [])
22

33
# needed if postgrex is optional
44
# Application.ensure_all_started(:postgrex)
@@ -10,7 +10,7 @@ defmodule PostgrexTest do
1010
{:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes)
1111
Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", [])
1212
Postgrex.query!(pid, "DROP TABLE IF EXISTS postgrex_items", [])
13-
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), binary_embedding bit(3))", [])
13+
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3))", [])
1414
{:ok, pid: pid}
1515
end
1616

@@ -19,7 +19,7 @@ defmodule PostgrexTest do
1919
context
2020
end
2121

22-
test "l2 distance", %{pid: pid} = _context do
22+
test "vector l2 distance", %{pid: pid} = _context do
2323
embeddings = [Pgvector.new([1, 1, 1]), [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)]
2424
Postgrex.query!(pid, "INSERT INTO postgrex_items (embedding) VALUES ($1), ($2), ($3)", embeddings)
2525

@@ -30,6 +30,17 @@ defmodule PostgrexTest do
3030
assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
3131
end
3232

33+
test "halfvec l2 distance", %{pid: pid} = _context do
34+
embeddings = [Pgvector.HalfVector.new([1, 1, 1]), [2, 2, 2], Nx.tensor([1, 1, 2], type: :f16)]
35+
Postgrex.query!(pid, "INSERT INTO postgrex_items (half_embedding) VALUES ($1), ($2), ($3)", embeddings)
36+
37+
result = Postgrex.query!(pid, "SELECT id, half_embedding FROM postgrex_items ORDER BY half_embedding <-> $1 LIMIT 5", [[1, 1, 1]])
38+
39+
assert ["id", "half_embedding"] == result.columns
40+
assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2]
41+
assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
42+
end
43+
3344
test "create index", %{pid: pid} = _context do
3445
Postgrex.query!(pid, "CREATE INDEX my_index ON postgrex_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1)", [])
3546
end

0 commit comments

Comments
 (0)