Skip to content

Commit 34a4d71

Browse files
committed
Improved SparseVector tests
1 parent 76a0a70 commit 34a4d71

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

test/sparse_vector_test.exs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@ defmodule SparseVectorTest do
22
use ExUnit.Case
33

44
test "sparse vector" do
5-
vector = Pgvector.SparseVector.new([1, 2, 3])
5+
vector = Pgvector.SparseVector.new([1, 0, 2, 0, 3, 0])
66
assert vector == vector |> Pgvector.SparseVector.new()
77
end
88

99
test "list" do
10-
list = [1.0, 2.0, 3.0]
10+
list = [1.0, 0.0, 2.0, 0.0, 3.0, 0.0]
1111
assert list == list |> Pgvector.SparseVector.new() |> Pgvector.to_list()
1212
end
1313

1414
test "tensor" do
15-
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32)
15+
tensor = Nx.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0], type: :f32)
1616
assert tensor == tensor |> Pgvector.SparseVector.new() |> Pgvector.to_tensor()
1717
end
1818

@@ -22,23 +22,23 @@ defmodule SparseVectorTest do
2222
end
2323

2424
test "dimensions" do
25-
vector = Pgvector.SparseVector.new([1, 2, 3])
26-
assert 3 == vector |> Pgvector.SparseVector.dimensions()
25+
vector = Pgvector.SparseVector.new([1, 0, 2, 0, 3, 0])
26+
assert 6 == vector |> Pgvector.SparseVector.dimensions()
2727
end
2828

2929
test "indices" do
30-
vector = Pgvector.SparseVector.new([1, 2, 3])
31-
assert [0, 1, 2] == vector |> Pgvector.SparseVector.indices()
30+
vector = Pgvector.SparseVector.new([1, 0, 2, 0, 3, 0])
31+
assert [0, 2, 4] == vector |> Pgvector.SparseVector.indices()
3232
end
3333

3434
test "values" do
35-
vector = Pgvector.SparseVector.new([1, 2, 3])
35+
vector = Pgvector.SparseVector.new([1, 0, 2, 0, 3, 0])
3636
assert [1, 2, 3] == vector |> Pgvector.SparseVector.values()
3737
end
3838

3939
test "inspect" do
40-
vector = Pgvector.SparseVector.new([1, 2, 3])
41-
assert "Pgvector.SparseVector.new(%{0 => 1.0, 1 => 2.0, 2 => 3.0}, 3)" == inspect(vector)
40+
vector = Pgvector.SparseVector.new([1, 0, 2, 0, 3, 0])
41+
assert "Pgvector.SparseVector.new(%{0 => 1.0, 2 => 2.0, 4 => 3.0}, 6)" == inspect(vector)
4242
end
4343

4444
test "equals" do

0 commit comments

Comments
 (0)