Skip to content

Commit cbb6458

Browse files
committed
Added support for sparsevec type
1 parent 7ef2332 commit cbb6458

File tree

8 files changed

+203
-7
lines changed

8 files changed

+203
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.2.2 (unreleased)
22

3-
- Added support for `halfvec` type
3+
- Added support for `halfvec` and `sparsevec` types
44
- Added `Pgvector.extensions/0` function
55
- Added `l1_distance` function for Ecto
66

lib/pgvector.ex

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ defmodule Pgvector do
5656
vector.data
5757
end
5858

59+
def to_binary(vector) when is_struct(vector, Pgvector.SparseVector) do
60+
vector.data
61+
end
62+
5963
@doc """
6064
Converts the vector to a list
6165
"""
@@ -69,6 +73,16 @@ defmodule Pgvector do
6973
for <<v::float-16 <- bin>>, do: v
7074
end
7175

76+
def to_list(vector) when is_struct(vector, Pgvector.SparseVector) do
77+
<<dim::signed-32, nnz::signed-32, 0::signed-32, indices::binary-size(nnz)-unit(32),
78+
values::binary-size(nnz)-unit(32)>> = vector.data
79+
80+
indices = for <<v::signed-32 <- indices>>, do: v
81+
values = for <<v::float-32 <- values>>, do: v
82+
list = List.duplicate(0.0, dim)
83+
Enum.zip_reduce(indices, values, list, fn x, y, acc -> List.replace_at(acc, x, y) end)
84+
end
85+
7286
if Code.ensure_loaded?(Nx) do
7387
@doc """
7488
Converts the vector to a tensor
@@ -83,6 +97,11 @@ defmodule Pgvector do
8397
bin |> f16_big_to_native() |> Nx.from_binary(:f16)
8498
end
8599

100+
def to_tensor(vector) when is_struct(vector, Pgvector.SparseVector) do
101+
# TODO improve
102+
vector |> to_list() |> Nx.tensor(type: :f32)
103+
end
104+
86105
defp f32_big_to_native(binary) do
87106
if System.endianness() == :big do
88107
binary
@@ -106,7 +125,8 @@ defmodule Pgvector do
106125
def extensions do
107126
[
108127
Pgvector.Extensions.Vector,
109-
Pgvector.Extensions.Halfvec
128+
Pgvector.Extensions.Halfvec,
129+
Pgvector.Extensions.Sparsevec
110130
]
111131
end
112132

@@ -116,6 +136,10 @@ defmodule Pgvector do
116136
vector
117137
end
118138

139+
def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do
140+
vector
141+
end
142+
119143
def to_sql(vector) do
120144
vector |> Pgvector.new()
121145
end

lib/pgvector/ecto/sparse_vector.ex

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
if Code.ensure_loaded?(Ecto) do
2+
defmodule Pgvector.Ecto.SparseVector do
3+
use Ecto.Type
4+
5+
def type, do: :sparsevec
6+
7+
def cast(value) do
8+
{:ok, value |> Pgvector.SparseVector.new()}
9+
end
10+
11+
def load(data) do
12+
{:ok, data}
13+
end
14+
15+
def dump(value) do
16+
{:ok, value}
17+
end
18+
end
19+
end

lib/pgvector/extensions/sparsevec.ex

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
defmodule Pgvector.Extensions.Sparsevec do
2+
import Postgrex.BinaryUtils, warn: false
3+
4+
def init(opts), do: Keyword.get(opts, :decode_binary, :copy)
5+
6+
def matching(_), do: [type: "sparsevec"]
7+
8+
def format(_), do: :binary
9+
10+
def encode(_) do
11+
quote do
12+
vec ->
13+
data = vec |> Pgvector.SparseVector.new() |> Pgvector.to_binary()
14+
[<<IO.iodata_length(data)::int32()>> | data]
15+
end
16+
end
17+
18+
def decode(:copy) do
19+
quote do
20+
<<len::int32(), bin::binary-size(len)>> ->
21+
bin |> :binary.copy() |> Pgvector.SparseVector.from_binary()
22+
end
23+
end
24+
25+
def decode(_) do
26+
quote do
27+
<<len::int32(), bin::binary-size(len)>> ->
28+
bin |> Pgvector.SparseVector.from_binary()
29+
end
30+
end
31+
end

lib/pgvector/sparse_vector.ex

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
defmodule Pgvector.SparseVector do
2+
@moduledoc """
3+
A sparse vector struct for pgvector
4+
"""
5+
6+
defstruct [:data]
7+
8+
@doc """
9+
Creates a new sparse vector from a list, tensor or sparse vector
10+
"""
11+
def new(list) when is_list(list) do
12+
indices =
13+
list
14+
|> Enum.with_index()
15+
|> Enum.filter(fn {v, _} -> v != 0 end)
16+
|> Enum.map(fn {_, i} -> i end)
17+
18+
values = list |> Enum.filter(fn v -> v != 0 end)
19+
dim = list |> length()
20+
nnz = indices |> length()
21+
indices = for v <- indices, into: "", do: <<v::signed-32>>
22+
values = for v <- values, into: "", do: <<v::float-32>>
23+
from_binary(<<dim::signed-32, nnz::signed-32, 0::signed-32, indices::binary, values::binary>>)
24+
end
25+
26+
def new(%Pgvector.SparseVector{} = vector) do
27+
vector
28+
end
29+
30+
if Code.ensure_loaded?(Nx) do
31+
def new(tensor) when is_struct(tensor, Nx.Tensor) do
32+
if Nx.rank(tensor) != 1 do
33+
raise ArgumentError, "expected rank to be 1"
34+
end
35+
36+
# TODO improve
37+
new(tensor |> Nx.to_list())
38+
end
39+
end
40+
41+
@doc """
42+
Creates a new sparse vector from its binary representation
43+
"""
44+
def from_binary(binary) when is_binary(binary) do
45+
%Pgvector.SparseVector{data: binary}
46+
end
47+
end
48+
49+
defimpl Inspect, for: Pgvector.SparseVector do
50+
import Inspect.Algebra
51+
52+
def inspect(vec, opts) do
53+
# TODO improve
54+
concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"])
55+
end
56+
end

test/ecto_test.exs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defmodule Item do
44
schema "ecto_items" do
55
field :embedding, Pgvector.Ecto.Vector
66
field :half_embedding, Pgvector.Ecto.HalfVector
7+
field :sparse_embedding, Pgvector.Ecto.SparseVector
78
end
89
end
910

@@ -16,15 +17,15 @@ defmodule EctoTest do
1617
setup_all do
1718
Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
1819
Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
19-
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3))", [])
20+
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))", [])
2021
create_items()
2122
:ok
2223
end
2324

2425
defp create_items do
25-
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1])})
26-
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3]})
27-
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16)})
26+
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1]), sparse_embedding: Pgvector.SparseVector.new([1, 1, 1])})
27+
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3], sparse_embedding: [2, 2, 3]})
28+
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16), sparse_embedding: Nx.tensor([1, 1, 2], type: :f32)})
2829
end
2930

3031
test "vector l2 distance" do
@@ -79,6 +80,32 @@ defmodule EctoTest do
7980
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
8081
end
8182

83+
test "sparsevec l2 distance" do
84+
items = Repo.all(from i in Item, order_by: l2_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
85+
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
86+
assert Enum.map(items, fn v -> v.sparse_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
87+
end
88+
89+
test "sparsevec max inner product" do
90+
items = Repo.all(from i in Item, order_by: max_inner_product(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
91+
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
92+
end
93+
94+
test "sparsevec cosine distance" do
95+
items = Repo.all(from i in Item, order_by: cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
96+
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
97+
end
98+
99+
test "sparsevec cosine similarity" do
100+
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1]))), limit: 5)
101+
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
102+
end
103+
104+
test "sparsevec l1 distance" do
105+
items = Repo.all(from i in Item, order_by: l1_distance(i.sparse_embedding, Pgvector.SparseVector.new([1, 1, 1])), limit: 5)
106+
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
107+
end
108+
82109
test "cast" do
83110
embedding = [1, 1, 1]
84111
items = Repo.all(from i in Item, where: i.embedding == ^embedding)

test/postgrex_test.exs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ defmodule PostgrexTest do
1010
{:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes)
1111
Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", [])
1212
Postgrex.query!(pid, "DROP TABLE IF EXISTS postgrex_items", [])
13-
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3))", [])
13+
Postgrex.query!(pid, "CREATE TABLE postgrex_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))", [])
1414
{:ok, pid: pid}
1515
end
1616

@@ -41,6 +41,17 @@ defmodule PostgrexTest do
4141
assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
4242
end
4343

44+
test "sparsevec l2 distance", %{pid: pid} = _context do
45+
embeddings = [Pgvector.SparseVector.new([1, 1, 1]), [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)]
46+
Postgrex.query!(pid, "INSERT INTO postgrex_items (sparse_embedding) VALUES ($1), ($2), ($3)", embeddings)
47+
48+
result = Postgrex.query!(pid, "SELECT id, sparse_embedding FROM postgrex_items ORDER BY sparse_embedding <-> $1 LIMIT 5", [[1, 1, 1]])
49+
50+
assert ["id", "sparse_embedding"] == result.columns
51+
assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2]
52+
assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]]
53+
end
54+
4455
test "create index", %{pid: pid} = _context do
4556
Postgrex.query!(pid, "CREATE INDEX my_index ON postgrex_items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1)", [])
4657
end

test/sparse_vector_test.exs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
defmodule SparseVectorTest do
2+
use ExUnit.Case
3+
4+
test "sparse vector" do
5+
vector = Pgvector.SparseVector.new([1, 2, 3])
6+
assert vector == vector |> Pgvector.SparseVector.new()
7+
end
8+
9+
test "list" do
10+
list = [1.0, 2.0, 3.0]
11+
assert list == list |> Pgvector.SparseVector.new() |> Pgvector.to_list()
12+
end
13+
14+
test "tensor" do
15+
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32)
16+
assert tensor == tensor |> Pgvector.SparseVector.new() |> Pgvector.to_tensor()
17+
end
18+
19+
test "inspect" do
20+
vector = Pgvector.SparseVector.new([1, 2, 3])
21+
assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector)
22+
end
23+
24+
test "equals" do
25+
assert Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 3])
26+
refute Pgvector.SparseVector.new([1, 2, 3]) == Pgvector.SparseVector.new([1, 2, 4])
27+
end
28+
end

0 commit comments

Comments
 (0)