Skip to content

Commit 963ba57

Browse files
committed
Added support for halfvec type
1 parent 8d2ffa1 commit 963ba57

File tree

8 files changed

+208
-14
lines changed

8 files changed

+208
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.2.2 (unreleased)
22

3+
- Added support for `halfvec` type
34
- Added `Pgvector.extensions/0` function
45
- Added `l1_distance` function for Ecto
56

lib/pgvector.ex

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ defmodule Pgvector do
5252
vector.data
5353
end
5454

55+
# Half vector clause: hands back the binary held in the struct's `data` field.
def to_binary(%{data: data} = vector) when is_struct(vector, Pgvector.HalfVector), do: data
58+
5559
@doc """
5660
Converts the vector to a list
5761
"""
@@ -60,6 +64,11 @@ defmodule Pgvector do
6064
for <<v::float-32 <- bin>>, do: v
6165
end
6266

67+
# Half vector clause: the binary starts with a 16-bit element count and a
# 16-bit zero field, followed by one 16-bit float per element.
def to_list(vector) when is_struct(vector, Pgvector.HalfVector) do
  <<count::unsigned-16, 0::unsigned-16, payload::binary-size(count)-unit(16)>> = vector.data
  for <<element::float-16 <- payload>>, do: element
end
71+
6372
if Code.ensure_loaded?(Nx) do
6473
@doc """
6574
Converts the vector to a tensor
@@ -69,23 +78,47 @@ defmodule Pgvector do
6978
bin |> f32_big_to_native() |> Nx.from_binary(:f32)
7079
end
7180

81+
# Half vector clause: strips the 4-byte header, converts the payload to the
# native byte order, and builds an :f16 tensor from it.
def to_tensor(vector) when is_struct(vector, Pgvector.HalfVector) do
  <<count::unsigned-16, 0::unsigned-16, payload::binary-size(count)-unit(16)>> = vector.data
  native = f16_big_to_native(payload)
  Nx.from_binary(native, :f16)
end
85+
7286
# Re-encodes big-endian 32-bit floats in the machine's byte order.
# On a big-endian machine the input is already native and is returned as-is.
defp f32_big_to_native(binary) do
  case System.endianness() do
    :big -> binary
    _ -> for <<value::float-32-big <- binary>>, into: "", do: <<value::float-32-little>>
  end
end
93+
94+
# Re-encodes big-endian 16-bit floats in the machine's byte order.
# On a big-endian machine the input is already native and is returned as-is.
defp f16_big_to_native(binary) do
  case System.endianness() do
    :big -> binary
    _ -> for <<value::float-16-big <- binary>>, into: "", do: <<value::float-16-little>>
  end
end
79101
end
80102

81103
@doc """
82104
Extensions for Postgrex
83105
"""
84106
def extensions do
85107
[
86-
Pgvector.Extensions.Vector
108+
Pgvector.Extensions.Vector,
109+
Pgvector.Extensions.Halfvec
87110
]
88111
end
112+
113+
# TODO move / improve pattern
# Prepares a query parameter: half vectors pass through untouched, anything
# else is converted with Pgvector.new/1.
@doc false
def to_sql(vector) when is_struct(vector, Pgvector.HalfVector), do: vector
def to_sql(vector), do: Pgvector.new(vector)
89122
end
90123

91124
defimpl Inspect, for: Pgvector do

lib/pgvector/ecto/halfvec.ex

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
if Code.ensure_loaded?(Ecto) do
  defmodule Pgvector.Ecto.HalfVector do
    @moduledoc """
    Ecto type for the `halfvec` column type.
    """

    use Ecto.Type

    @impl true
    def type, do: :halfvec

    @doc """
    Casts a list, tensor, or half vector to a `Pgvector.HalfVector`.

    Returns `{:ok, half_vector}` on success and `:error` when the value
    cannot be converted, so callers get a cast error rather than an exception.
    """
    @impl true
    def cast(value)
        when is_list(value) or is_struct(value, Pgvector.HalfVector) or
               is_struct(value, Nx.Tensor) do
      {:ok, Pgvector.HalfVector.new(value)}
    rescue
      # Pgvector.HalfVector.new/1 raises ArgumentError for input it cannot
      # encode, e.g. non-numeric list elements or a tensor whose rank is not 1.
      ArgumentError -> :error
    end

    # Any other shape is not a vector at all.
    def cast(_value), do: :error

    # Loaded and dumped values are passed through unchanged.
    @impl true
    def load(data), do: {:ok, data}

    @impl true
    def dump(value), do: {:ok, value}
  end
end

lib/pgvector/ecto/query.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ if Code.ensure_loaded?(Ecto) do
99
"""
1010
defmacro l2_distance(column, value) do
1111
quote do
12-
fragment("(? <-> ?::vector)", unquote(column), unquote(value))
12+
fragment("(? <-> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
1313
end
1414
end
1515

@@ -18,7 +18,7 @@ if Code.ensure_loaded?(Ecto) do
1818
"""
1919
defmacro max_inner_product(column, value) do
2020
quote do
21-
fragment("(? <#> ?::vector)", unquote(column), unquote(value))
21+
fragment("(? <#> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
2222
end
2323
end
2424

@@ -27,7 +27,7 @@ if Code.ensure_loaded?(Ecto) do
2727
"""
2828
defmacro cosine_distance(column, value) do
2929
quote do
30-
fragment("(? <=> ?::vector)", unquote(column), unquote(value))
30+
fragment("(? <=> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
3131
end
3232
end
3333

@@ -36,7 +36,7 @@ if Code.ensure_loaded?(Ecto) do
3636
"""
3737
defmacro l1_distance(column, value) do
3838
quote do
39-
fragment("(? <+> ?::vector)", unquote(column), unquote(value))
39+
fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
4040
end
4141
end
4242
end

lib/pgvector/extensions/halfvec.ex

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
defmodule Pgvector.Extensions.Halfvec do
  @moduledoc """
  Postgrex extension that encodes and decodes the `halfvec` type in binary format.
  """

  @behaviour Postgrex.Extension

  import Postgrex.BinaryUtils, warn: false

  # The extension state is the `:decode_binary` option (`:copy` by default);
  # it selects which decode/1 clause is used.
  @impl true
  def init(opts), do: Keyword.get(opts, :decode_binary, :copy)

  @impl true
  def matching(_), do: [type: "halfvec"]

  @impl true
  def format(_), do: :binary

  # Converts the value to a half vector and prefixes its binary with the
  # payload length as an int32.
  @impl true
  def encode(_) do
    quote do
      vec ->
        data = vec |> Pgvector.HalfVector.new() |> Pgvector.to_binary()
        [<<IO.iodata_length(data)::int32()>> | data]
    end
  end

  # In :copy mode the payload is copied before it is wrapped in a struct.
  @impl true
  def decode(:copy) do
    quote do
      <<len::int32(), bin::binary-size(len)>> ->
        bin |> :binary.copy() |> Pgvector.HalfVector.from_binary()
    end
  end

  # Any other mode wraps the matched payload directly, without copying.
  def decode(_) do
    quote do
      <<len::int32(), bin::binary-size(len)>> ->
        bin |> Pgvector.HalfVector.from_binary()
    end
  end
end

lib/pgvector/half_vector.ex

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
defmodule Pgvector.HalfVector do
  @moduledoc """
  A half vector struct for pgvector

  The `data` field holds the binary representation: the number of elements as
  an unsigned 16-bit integer, a 16-bit zero field, and then each element as a
  big-endian 16-bit float.
  """

  # The element count is written as an unsigned 16-bit integer, so a larger
  # count would wrap around and no longer describe the payload.
  @max_dim 65_535

  defstruct [:data]

  @type t :: %__MODULE__{data: binary()}

  @doc """
  Creates a new half vector from a list, tensor, or half vector

  Raises `ArgumentError` if the input has more than 65,535 elements, if a list
  element cannot be encoded as a 16-bit float, or if a tensor's rank is not 1.
  """
  @spec new([number()] | t() | struct()) :: t()
  def new(list) when is_list(list) do
    dim = length(list)
    validate_dim!(dim)
    bin = for v <- list, into: "", do: <<v::float-16>>
    from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)
  end

  def new(%__MODULE__{} = vector) do
    vector
  end

  if Code.ensure_loaded?(Nx) do
    def new(tensor) when is_struct(tensor, Nx.Tensor) do
      if Nx.rank(tensor) != 1 do
        raise ArgumentError, "expected rank to be 1"
      end

      dim = Nx.size(tensor)
      validate_dim!(dim)
      bin = tensor |> Nx.as_type(:f16) |> Nx.to_binary() |> f16_native_to_big()
      from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)
    end

    # Nx produces floats in the machine's byte order; the stored payload is big-endian.
    defp f16_native_to_big(binary) do
      if System.endianness() == :big do
        binary
      else
        for <<n::float-16-little <- binary>>, into: "", do: <<n::float-16-big>>
      end
    end
  end

  @doc """
  Creates a new half vector from its binary representation
  """
  @spec from_binary(binary()) :: t()
  def from_binary(binary) when is_binary(binary) do
    %__MODULE__{data: binary}
  end

  # Rejects element counts that do not fit in the 16-bit header field.
  defp validate_dim!(dim) when dim > @max_dim do
    raise ArgumentError, "expected at most #{@max_dim} dimensions, got #{dim}"
  end

  defp validate_dim!(_dim), do: :ok
end
48+
49+
defimpl Inspect, for: Pgvector.HalfVector do
  import Inspect.Algebra

  # Renders the vector as the `Pgvector.HalfVector.new/1` call that would rebuild it.
  def inspect(vec, opts) do
    elements = vec |> Pgvector.to_list() |> Inspect.List.inspect(opts)
    concat(["Pgvector.HalfVector.new(", elements, ")"])
  end
end

test/ecto_test.exs

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ defmodule Item do
33

44
schema "ecto_items" do
55
field :embedding, Pgvector.Ecto.Vector
6+
field :half_embedding, Pgvector.Ecto.HalfVector
67
end
78
end
89

@@ -15,43 +16,69 @@ defmodule EctoTest do
1516
setup_all do
1617
Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector", [])
1718
Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS ecto_items", [])
18-
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3))", [])
19+
Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE ecto_items (id bigserial primary key, embedding vector(3), half_embedding halfvec(3))", [])
1920
create_items()
2021
:ok
2122
end
2223

2324
defp create_items do
24-
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1])})
25-
Repo.insert(%Item{embedding: [2, 2, 3]})
26-
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32)})
25+
Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1]), half_embedding: Pgvector.HalfVector.new([1, 1, 1])})
26+
Repo.insert(%Item{embedding: [2, 2, 3], half_embedding: [2, 2, 3]})
27+
Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32), half_embedding: Nx.tensor([1, 1, 2], type: :f16)})
2728
end
2829

29-
test "l2 distance" do
30+
test "vector l2 distance" do
3031
items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5)
3132
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
3233
assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
3334
end
3435

35-
test "max inner product" do
36+
test "vector max inner product" do
3637
items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5)
3738
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
3839
end
3940

40-
test "cosine distance" do
41+
test "vector cosine distance" do
4142
items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5)
4243
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
4344
end
4445

45-
test "cosine similarity" do
46+
test "vector cosine similarity" do
4647
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, [1, 1, 1])), limit: 5)
4748
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
4849
end
4950

50-
test "l1 distance" do
51+
test "vector l1 distance" do
5152
items = Repo.all(from i in Item, order_by: l1_distance(i.embedding, [1, 1, 1]), limit: 5)
5253
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
5354
end
5455

56+
test "halfvec l2 distance" do
57+
items = Repo.all(from i in Item, order_by: l2_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
58+
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
59+
assert Enum.map(items, fn v -> v.half_embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]]
60+
end
61+
62+
test "halfvec max inner product" do
63+
items = Repo.all(from i in Item, order_by: max_inner_product(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
64+
assert Enum.map(items, fn v -> v.id end) == [2, 3, 1]
65+
end
66+
67+
test "halfvec cosine distance" do
68+
items = Repo.all(from i in Item, order_by: cosine_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
69+
assert Enum.map(items, fn v -> v.id end) == [1, 2, 3]
70+
end
71+
72+
test "halfvec cosine similarity" do
73+
items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1]))), limit: 5)
74+
assert Enum.map(items, fn v -> v.id end) == [3, 2, 1]
75+
end
76+
77+
test "halfvec l1 distance" do
78+
items = Repo.all(from i in Item, order_by: l1_distance(i.half_embedding, Pgvector.HalfVector.new([1, 1, 1])), limit: 5)
79+
assert Enum.map(items, fn v -> v.id end) == [1, 3, 2]
80+
end
81+
5582
test "cast" do
5683
embedding = [1, 1, 1]
5784
items = Repo.all(from i in Item, where: i.embedding == ^embedding)

test/half_vector_test.exs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
defmodule HalfVectorTest do
  use ExUnit.Case

  # Passing an existing half vector to new/1 gives back an equal value.
  test "half vector" do
    original = Pgvector.HalfVector.new([1, 2, 3])
    assert original == Pgvector.HalfVector.new(original)
  end

  # List -> half vector -> list round trip.
  test "list" do
    values = [1.0, 2.0, 3.0]
    round_tripped = values |> Pgvector.HalfVector.new() |> Pgvector.to_list()
    assert values == round_tripped
  end

  # Tensor -> half vector -> tensor round trip.
  test "tensor" do
    source = Nx.tensor([1.0, 2.0, 3.0], type: :f16)
    round_tripped = source |> Pgvector.HalfVector.new() |> Pgvector.to_tensor()
    assert source == round_tripped
  end

  test "inspect" do
    rendered = inspect(Pgvector.HalfVector.new([1, 2, 3]))
    assert "Pgvector.HalfVector.new([1.0, 2.0, 3.0])" == rendered
  end

  test "equals" do
    base = Pgvector.HalfVector.new([1, 2, 3])
    assert base == Pgvector.HalfVector.new([1, 2, 3])
    refute base == Pgvector.HalfVector.new([1, 2, 4])
  end
end

0 commit comments

Comments
 (0)