Skip to content

Commit efb990a

Browse files
committed
Added support for sparsevec type
1 parent 478966e commit efb990a

File tree

4 files changed

+147
-5
lines changed

4 files changed

+147
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.2.0 (unreleased)
22

3-
- Added support for `halfvec` type
3+
- Added support for `halfvec` and `sparsevec` types
44
- Fixed error with MSVC
55

66
## 0.1.1 (2022-11-13)

include/pqxx.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include "halfvec.hpp"
10+
#include "sparsevec.hpp"
1011
#include "vector.hpp"
1112
#include <pqxx/pqxx>
1213
#include <sstream>
@@ -101,4 +102,72 @@ template <> struct string_traits<pgvector::HalfVector> {
101102
static_cast<std::vector<float>>(value));
102103
}
103104
};
105+
106+
template <> std::string const type_name<pgvector::SparseVector>{"sparsevec"};
107+
108+
template <> struct nullness<pgvector::SparseVector> : pqxx::no_null<pgvector::SparseVector> {};
109+
110+
template <> struct string_traits<pgvector::SparseVector> {
111+
static constexpr bool converts_to_string{true};
112+
113+
// TODO add from_string
114+
static constexpr bool converts_from_string{false};
115+
116+
static zview to_buf(char* begin, char* end, const pgvector::SparseVector& value) {
117+
char *const next = into_buf(begin, end, value);
118+
return zview{begin, next - begin - 1};
119+
}
120+
121+
static char* into_buf(char* begin, char* end, const pgvector::SparseVector& value) {
122+
int dimensions = value.dimensions();
123+
auto indices = value.indices();
124+
auto values = value.values();
125+
size_t nnz = indices.size();
126+
127+
// important! size_buffer cannot throw an exception on overflow
128+
// so perform this check before writing any data
129+
if (nnz > 16000) {
130+
throw conversion_overrun{"sparsevec cannot have more than 16000 dimensions"};
131+
}
132+
133+
char *here = begin;
134+
*here++ = '{';
135+
136+
for (size_t i = 0; i < nnz; i++) {
137+
if (i != 0) {
138+
*here++ = ',';
139+
}
140+
141+
here = string_traits<int>::into_buf(here, end, indices[i] + 1) - 1;
142+
*here++ = ':';
143+
here = string_traits<float>::into_buf(here, end, values[i]) - 1;
144+
}
145+
146+
*here++ = '}';
147+
*here++ = '/';
148+
here = string_traits<int>::into_buf(here, end, dimensions) - 1;
149+
*here++ = '\0';
150+
151+
return here;
152+
}
153+
154+
static size_t size_buffer(const pgvector::SparseVector& value) noexcept {
155+
int dimensions = value.dimensions();
156+
auto indices = value.indices();
157+
auto values = value.values();
158+
size_t nnz = indices.size();
159+
160+
// cannot throw an exception here on overflow
161+
// so throw in into_buf
162+
163+
size_t size = 4; // {, }, /, and \0
164+
size += string_traits<int>::size_buffer(dimensions);
165+
for (size_t i = 0; i < nnz; i++) {
166+
size += 2; // : and ,
167+
size += string_traits<int>::size_buffer(indices[i]);
168+
size += string_traits<float>::size_buffer(values[i]);
169+
}
170+
return size;
171+
}
172+
};
104173
} // namespace pqxx

include/sparsevec.hpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*!
2+
* pgvector-cpp v0.1.1
3+
* https://github.com/pgvector/pgvector-cpp
4+
* MIT License
5+
*/
6+
7+
#pragma once
8+
9+
#include <ostream>
10+
#include <vector>
11+
12+
namespace pgvector {
13+
class SparseVector {
14+
public:
15+
SparseVector() = default;
16+
17+
SparseVector(int dimensions, const std::vector<int>& indices, const std::vector<float>& values) {
18+
if (values.size() != indices.size()) {
19+
throw std::invalid_argument("indices and values must be the same length");
20+
}
21+
dimensions_ = dimensions;
22+
indices_ = indices;
23+
values_ = values;
24+
}
25+
26+
SparseVector(const std::vector<float>& value) {
27+
dimensions_ = value.size();
28+
for (size_t i = 0; i < value.size(); i++) {
29+
float v = value[i];
30+
if (v != 0) {
31+
indices_.push_back(i);
32+
values_.push_back(v);
33+
}
34+
}
35+
}
36+
37+
int dimensions() const {
38+
return dimensions_;
39+
}
40+
41+
const std::vector<int>& indices() const {
42+
return indices_;
43+
}
44+
45+
const std::vector<float>& values() const {
46+
return values_;
47+
}
48+
49+
friend bool operator==(const SparseVector& lhs, const SparseVector& rhs) {
50+
return lhs.dimensions_ == rhs.dimensions_ && lhs.indices_ == rhs.indices_ && lhs.values_ == rhs.values_;
51+
}
52+
53+
friend std::ostream& operator<<(std::ostream& os, const SparseVector& value) {
54+
os << "{";
55+
for (size_t i = 0; i < value.indices_.size(); i++) {
56+
if (i > 0) {
57+
os << ",";
58+
}
59+
os << value.indices_[i] + 1;
60+
os << ":";
61+
os << value.values_[i];
62+
}
63+
os << "}/";
64+
os << value.dimensions_;
65+
return os;
66+
}
67+
68+
private:
69+
int dimensions_;
70+
std::vector<int> indices_;
71+
std::vector<float> values_;
72+
};
73+
} // namespace pgvector

test/pqxx_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,16 @@ void test_sparsevec(pqxx::connection &conn) {
7878
before_each(conn);
7979

8080
pqxx::work tx{conn};
81-
auto embedding = "{1:1,2:2,3:3}/3";
82-
auto embedding2 = "{1:4,2:5,3:6}/3";
81+
auto embedding = pgvector::SparseVector({1, 2, 3});
82+
auto embedding2 = pgvector::SparseVector({4, 5, 6});
8383
tx.exec_params("INSERT INTO items (sparse_embedding) VALUES ($1), ($2), ($3)",
8484
embedding, embedding2, std::nullopt);
8585

8686
pqxx::result res{tx.exec_params(
8787
"SELECT sparse_embedding FROM items ORDER BY sparse_embedding <-> $1", embedding2)};
8888
assert(res.size() == 3);
89-
assert(res[0][0].as<std::string>() == embedding2);
90-
assert(res[1][0].as<std::string>() == embedding);
89+
assert(res[0][0].as<std::string>() == "{1:4,2:5,3:6}/3");
90+
assert(res[1][0].as<std::string>() == "{1:1,2:2,3:3}/3");
9191
assert(!res[2][0].as<std::optional<std::string>>().has_value());
9292
tx.commit();
9393
}

0 commit comments

Comments
 (0)