Skip to content

Commit 90004b5

Browse files
authored
25-1-2 Vector index fixes (#19634) (#20477)
2 parents 91312c6 + a927874 commit 90004b5

File tree

58 files changed

+3635
-1591
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+3635
-1591
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
#include "kmeans_clusters.h"
2+
3+
#include <library/cpp/dot_product/dot_product.h>
4+
#include <library/cpp/l1_distance/l1_distance.h>
5+
#include <library/cpp/l2_distance/l2_distance.h>
6+
7+
#include <span>
8+
9+
namespace NKikimr::NKMeans {
10+
11+
template <typename TRes>
12+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const float* lhs, const float* rhs, size_t length)
13+
{
14+
auto r = TriWayDotProduct(lhs, rhs, length);
15+
return {static_cast<TRes>(r.LL), static_cast<TRes>(r.LR), static_cast<TRes>(r.RR)};
16+
}
17+
18+
template <typename TRes>
19+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const i8* lhs, const i8* rhs, size_t length)
20+
{
21+
const auto ll = DotProduct(lhs, lhs, length);
22+
const auto lr = DotProduct(lhs, rhs, length);
23+
const auto rr = DotProduct(rhs, rhs, length);
24+
return {static_cast<TRes>(ll), static_cast<TRes>(lr), static_cast<TRes>(rr)};
25+
}
26+
27+
template <typename TRes>
28+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const ui8* lhs, const ui8* rhs, size_t length)
29+
{
30+
const auto ll = DotProduct(lhs, lhs, length);
31+
const auto lr = DotProduct(lhs, rhs, length);
32+
const auto rr = DotProduct(rhs, rhs, length);
33+
return {static_cast<TRes>(ll), static_cast<TRes>(lr), static_cast<TRes>(rr)};
34+
}
35+
36+
// TODO(mbkkt) maybe compute floating sum in double? Needs benchmark
37+
template <typename TCoord>
38+
struct TMetric {
39+
using TCoord_ = TCoord;
40+
using TSum = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, i64>;
41+
};
42+
43+
template <typename TCoord>
44+
struct TCosineSimilarity : TMetric<TCoord> {
45+
using TSum = typename TMetric<TCoord>::TSum;
46+
// double used to avoid precision issues
47+
using TRes = double;
48+
49+
static TRes Init()
50+
{
51+
return std::numeric_limits<TRes>::max();
52+
}
53+
54+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
55+
{
56+
const auto r = CosineImpl<TRes>(reinterpret_cast<const TCoord*>(cluster),
57+
reinterpret_cast<const TCoord*>(embedding), dimensions);
58+
// sqrt(ll) * sqrt(rr) computed instead of sqrt(ll * rr) to avoid precision issues
59+
const auto norm = std::sqrt(r.LL) * std::sqrt(r.RR);
60+
const TRes similarity = norm != 0 ? static_cast<TRes>(r.LR) / static_cast<TRes>(norm) : 0;
61+
return -similarity;
62+
}
63+
};
64+
65+
template <typename TCoord>
66+
struct TL1Distance : TMetric<TCoord> {
67+
using TSum = typename TMetric<TCoord>::TSum;
68+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, ui64>;
69+
70+
static TRes Init()
71+
{
72+
return std::numeric_limits<TRes>::max();
73+
}
74+
75+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
76+
{
77+
const auto distance = L1Distance(reinterpret_cast<const TCoord*>(cluster),
78+
reinterpret_cast<const TCoord*>(embedding), dimensions);
79+
return distance;
80+
}
81+
};
82+
83+
template <typename TCoord>
84+
struct TL2Distance : TMetric<TCoord> {
85+
using TSum = typename TMetric<TCoord>::TSum;
86+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, ui64>;
87+
88+
static TRes Init()
89+
{
90+
return std::numeric_limits<TRes>::max();
91+
}
92+
93+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
94+
{
95+
const auto distance = L2SqrDistance(reinterpret_cast<const TCoord*>(cluster),
96+
reinterpret_cast<const TCoord*>(embedding), dimensions);
97+
return distance;
98+
}
99+
};
100+
101+
template <typename TCoord>
102+
struct TMaxInnerProductSimilarity : TMetric<TCoord> {
103+
using TSum = typename TMetric<TCoord>::TSum;
104+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, i64>;
105+
106+
static TRes Init()
107+
{
108+
return std::numeric_limits<TRes>::max();
109+
}
110+
111+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
112+
{
113+
const TRes similarity = DotProduct(reinterpret_cast<const TCoord*>(cluster),
114+
reinterpret_cast<const TCoord*>(embedding), dimensions);
115+
return -similarity;
116+
}
117+
};
118+
119+
template <typename TMetric>
120+
class TClusters: public IClusters {
121+
// If less than 1% of vectors are reassigned to new clusters we want to stop
122+
static constexpr double MinVectorsNeedsReassigned = 0.01;
123+
124+
using TCoord = TMetric::TCoord_;
125+
using TSum = TMetric::TSum;
126+
using TEmbedding = TVector<TSum>;
127+
128+
const ui32 Dimensions = 0;
129+
const ui32 MaxRounds = 0;
130+
const ui8 TypeByte = 0;
131+
132+
TVector<TString> Clusters;
133+
TVector<ui64> ClusterSizes;
134+
TVector<TEmbedding> NextClusters;
135+
TVector<ui64> NextClusterSizes;
136+
137+
ui32 Round = 0;
138+
139+
public:
140+
TClusters(ui32 dimensions, ui32 maxRounds, ui8 typeByte)
141+
: Dimensions(dimensions)
142+
, MaxRounds(maxRounds)
143+
, TypeByte(typeByte)
144+
{
145+
}
146+
147+
void SetRound(ui32 round) override {
148+
Round = round;
149+
}
150+
151+
TString Debug() const override {
152+
auto sb = TStringBuilder() << "K: " << Clusters.size();
153+
if (MaxRounds) {
154+
sb << " Round: " << Round << " / " << MaxRounds;
155+
}
156+
return sb;
157+
}
158+
159+
const TVector<TString>& GetClusters() const override {
160+
return Clusters;
161+
}
162+
163+
const TVector<ui64>& GetClusterSizes() const override {
164+
return ClusterSizes;
165+
}
166+
167+
const TVector<ui64>& GetNextClusterSizes() const override {
168+
return NextClusterSizes;
169+
}
170+
171+
virtual void SetClusterSize(ui32 num, ui64 size) override {
172+
ClusterSizes.at(num) = size;
173+
}
174+
175+
void Clear() override {
176+
Clusters.clear();
177+
ClusterSizes.clear();
178+
NextClusterSizes.clear();
179+
NextClusters.clear();
180+
Round = 0;
181+
}
182+
183+
bool SetClusters(TVector<TString> && newClusters) override {
184+
if (newClusters.size() == 0) {
185+
return false;
186+
}
187+
for (const auto& cluster: newClusters) {
188+
if (!IsExpectedSize(cluster)) {
189+
return false;
190+
}
191+
}
192+
Clusters = std::move(newClusters);
193+
ClusterSizes.clear();
194+
ClusterSizes.resize(Clusters.size());
195+
NextClusterSizes.clear();
196+
NextClusterSizes.resize(Clusters.size());
197+
NextClusters.clear();
198+
NextClusters.resize(Clusters.size());
199+
for (auto& aggregate : NextClusters) {
200+
aggregate.resize(Dimensions, 0);
201+
}
202+
return true;
203+
}
204+
205+
bool RecomputeClusters() override {
206+
ui64 vectorCount = 0;
207+
ui64 reassignedCount = 0;
208+
for (size_t i = 0; auto& aggregate : NextClusters) {
209+
auto newSize = NextClusterSizes[i];
210+
vectorCount += newSize;
211+
212+
auto clusterSize = ClusterSizes[i];
213+
reassignedCount += clusterSize < newSize ? newSize - clusterSize : 0;
214+
215+
if (newSize != 0) {
216+
this->Fill(Clusters[i], aggregate.data(), newSize);
217+
}
218+
++i;
219+
}
220+
221+
Y_ENSURE(reassignedCount <= vectorCount);
222+
if (Clusters.size() == 1) {
223+
return true;
224+
}
225+
226+
bool last = Round >= MaxRounds;
227+
if (!last && Round > 1) {
228+
const auto changes = static_cast<double>(reassignedCount) / static_cast<double>(vectorCount);
229+
last = changes < MinVectorsNeedsReassigned;
230+
}
231+
if (!last) {
232+
return false;
233+
}
234+
return true;
235+
}
236+
237+
void RemoveEmptyClusters() override {
238+
size_t w = 0;
239+
for (size_t r = 0; r < ClusterSizes.size(); ++r) {
240+
if (ClusterSizes[r] != 0) {
241+
ClusterSizes[w] = ClusterSizes[r];
242+
Clusters[w] = std::move(Clusters[r]);
243+
++w;
244+
}
245+
}
246+
ClusterSizes.erase(ClusterSizes.begin() + w, ClusterSizes.end());
247+
Clusters.erase(Clusters.begin() + w, Clusters.end());
248+
}
249+
250+
bool NextRound() override {
251+
bool isLast = RecomputeClusters();
252+
ClusterSizes = std::move(NextClusterSizes);
253+
RemoveEmptyClusters();
254+
if (isLast) {
255+
NextClusters.clear();
256+
return true;
257+
}
258+
++Round;
259+
NextClusterSizes.clear();
260+
NextClusterSizes.resize(Clusters.size());
261+
NextClusters.clear();
262+
NextClusters.resize(Clusters.size());
263+
for (auto& aggregate : NextClusters) {
264+
aggregate.resize(Dimensions, 0);
265+
}
266+
return false;
267+
}
268+
269+
std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) override {
270+
Y_ENSURE(embeddingPos < row.size());
271+
const auto embedding = row.at(embeddingPos).AsRef();
272+
if (!IsExpectedSize(embedding)) {
273+
return {};
274+
}
275+
276+
auto min = TMetric::Init();
277+
std::optional<ui32> closest = {};
278+
for (size_t i = 0; const auto& cluster : Clusters) {
279+
auto distance = TMetric::Distance(cluster.data(), embedding.data(), Dimensions);
280+
if (distance < min) {
281+
min = distance;
282+
closest = i;
283+
}
284+
++i;
285+
}
286+
return closest;
287+
}
288+
289+
void AggregateToCluster(ui32 pos, const TArrayRef<const char>& embedding, ui64 weight) override {
290+
auto& aggregate = NextClusters.at(pos);
291+
auto* coords = aggregate.data();
292+
Y_ENSURE(IsExpectedSize(embedding));
293+
for (auto coord : this->GetCoords(embedding.data())) {
294+
*coords++ += (TSum)coord * weight;
295+
}
296+
NextClusterSizes.at(pos) += weight;
297+
}
298+
299+
bool IsExpectedSize(const TArrayRef<const char>& data) override {
300+
return data.size() == 1 + sizeof(TCoord) * Dimensions;
301+
}
302+
303+
private:
304+
auto GetCoords(const char* coords) {
305+
return std::span{reinterpret_cast<const TCoord*>(coords), Dimensions};
306+
}
307+
308+
auto GetData(char* data) {
309+
return std::span{reinterpret_cast<TCoord*>(data), Dimensions};
310+
}
311+
312+
void Fill(TString& d, TSum* embedding, ui64& c) {
313+
Y_ENSURE(c > 0);
314+
const auto count = static_cast<TSum>(c);
315+
auto data = GetData(d.MutRef().data());
316+
for (auto& coord : data) {
317+
coord = *embedding / count;
318+
embedding++;
319+
}
320+
}
321+
};
322+
323+
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error) {
324+
if (settings.vector_dimension() < 1) {
325+
error = "Dimension of vector should be at least one";
326+
return nullptr;
327+
}
328+
329+
const ui8 typeVal = (ui8)settings.vector_type();
330+
const ui32 dim = settings.vector_dimension();
331+
332+
auto handleMetric = [&]<typename T>() -> std::unique_ptr<IClusters> {
333+
switch (settings.metric()) {
334+
case Ydb::Table::VectorIndexSettings::SIMILARITY_INNER_PRODUCT:
335+
return std::make_unique<TClusters<TMaxInnerProductSimilarity<T>>>(dim, maxRounds, typeVal);
336+
case Ydb::Table::VectorIndexSettings::SIMILARITY_COSINE:
337+
case Ydb::Table::VectorIndexSettings::DISTANCE_COSINE:
338+
// We don't need to have separate implementation for distance,
339+
// because clusters will be same as for similarity
340+
return std::make_unique<TClusters<TCosineSimilarity<T>>>(dim, maxRounds, typeVal);
341+
case Ydb::Table::VectorIndexSettings::DISTANCE_MANHATTAN:
342+
return std::make_unique<TClusters<TL1Distance<T>>>(dim, maxRounds, typeVal);
343+
case Ydb::Table::VectorIndexSettings::DISTANCE_EUCLIDEAN:
344+
return std::make_unique<TClusters<TL2Distance<T>>>(dim, maxRounds, typeVal);
345+
default:
346+
error = "Wrong similarity";
347+
break;
348+
}
349+
return nullptr;
350+
};
351+
352+
switch (settings.vector_type()) {
353+
case Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT:
354+
return handleMetric.template operator()<float>();
355+
case Ydb::Table::VectorIndexSettings::VECTOR_TYPE_UINT8:
356+
return handleMetric.template operator()<ui8>();
357+
case Ydb::Table::VectorIndexSettings::VECTOR_TYPE_INT8:
358+
return handleMetric.template operator()<i8>();
359+
case Ydb::Table::VectorIndexSettings::VECTOR_TYPE_BIT:
360+
error = "TODO(mbkkt) bit vector type is not supported";
361+
break;
362+
default:
363+
error = "Wrong vector type";
364+
break;
365+
}
366+
367+
return nullptr;
368+
}
369+
370+
}

0 commit comments

Comments
 (0)