Skip to content

Commit a2529fc

Browse files
committed
Extract KMeans TClusters into a separate file to share it with schemeshard (#19412)
1 parent f32ba67 commit a2529fc

File tree

8 files changed

+686
-680
lines changed

8 files changed

+686
-680
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
#include "kmeans_clusters.h"
2+
3+
#include <library/cpp/dot_product/dot_product.h>
4+
#include <library/cpp/l1_distance/l1_distance.h>
5+
#include <library/cpp/l2_distance/l2_distance.h>
6+
7+
#include <span>
8+
9+
namespace NKikimr::NKMeans {
10+
11+
template <typename TRes>
12+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const float* lhs, const float* rhs, size_t length)
13+
{
14+
auto r = TriWayDotProduct(lhs, rhs, length);
15+
return {static_cast<TRes>(r.LL), static_cast<TRes>(r.LR), static_cast<TRes>(r.RR)};
16+
}
17+
18+
template <typename TRes>
19+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const i8* lhs, const i8* rhs, size_t length)
20+
{
21+
const auto ll = DotProduct(lhs, lhs, length);
22+
const auto lr = DotProduct(lhs, rhs, length);
23+
const auto rr = DotProduct(rhs, rhs, length);
24+
return {static_cast<TRes>(ll), static_cast<TRes>(lr), static_cast<TRes>(rr)};
25+
}
26+
27+
template <typename TRes>
28+
Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const ui8* lhs, const ui8* rhs, size_t length)
29+
{
30+
const auto ll = DotProduct(lhs, lhs, length);
31+
const auto lr = DotProduct(lhs, rhs, length);
32+
const auto rr = DotProduct(rhs, rhs, length);
33+
return {static_cast<TRes>(ll), static_cast<TRes>(lr), static_cast<TRes>(rr)};
34+
}
35+
36+
// TODO(mbkkt) maybe compute floating sum in double? Needs benchmark
37+
template <typename TCoord>
38+
struct TMetric {
39+
using TCoord_ = TCoord;
40+
using TSum = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, i64>;
41+
};
42+
43+
template <typename TCoord>
44+
struct TCosineSimilarity : TMetric<TCoord> {
45+
using TSum = typename TMetric<TCoord>::TSum;
46+
// double used to avoid precision issues
47+
using TRes = double;
48+
49+
static TRes Init()
50+
{
51+
return std::numeric_limits<TRes>::max();
52+
}
53+
54+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
55+
{
56+
const auto r = CosineImpl<TRes>(reinterpret_cast<const TCoord*>(cluster),
57+
reinterpret_cast<const TCoord*>(embedding), dimensions);
58+
// sqrt(ll) * sqrt(rr) computed instead of sqrt(ll * rr) to avoid precision issues
59+
const auto norm = std::sqrt(r.LL) * std::sqrt(r.RR);
60+
const TRes similarity = norm != 0 ? static_cast<TRes>(r.LR) / static_cast<TRes>(norm) : 0;
61+
return -similarity;
62+
}
63+
};
64+
65+
template <typename TCoord>
66+
struct TL1Distance : TMetric<TCoord> {
67+
using TSum = typename TMetric<TCoord>::TSum;
68+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, ui64>;
69+
70+
static TRes Init()
71+
{
72+
return std::numeric_limits<TRes>::max();
73+
}
74+
75+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
76+
{
77+
const auto distance = L1Distance(reinterpret_cast<const TCoord*>(cluster),
78+
reinterpret_cast<const TCoord*>(embedding), dimensions);
79+
return distance;
80+
}
81+
};
82+
83+
template <typename TCoord>
84+
struct TL2Distance : TMetric<TCoord> {
85+
using TSum = typename TMetric<TCoord>::TSum;
86+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, ui64>;
87+
88+
static TRes Init()
89+
{
90+
return std::numeric_limits<TRes>::max();
91+
}
92+
93+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
94+
{
95+
const auto distance = L2SqrDistance(reinterpret_cast<const TCoord*>(cluster),
96+
reinterpret_cast<const TCoord*>(embedding), dimensions);
97+
return distance;
98+
}
99+
};
100+
101+
template <typename TCoord>
102+
struct TMaxInnerProductSimilarity : TMetric<TCoord> {
103+
using TSum = typename TMetric<TCoord>::TSum;
104+
using TRes = std::conditional_t<std::is_floating_point_v<TCoord>, TCoord, i64>;
105+
106+
static TRes Init()
107+
{
108+
return std::numeric_limits<TRes>::max();
109+
}
110+
111+
static auto Distance(const char* cluster, const char* embedding, ui32 dimensions)
112+
{
113+
const TRes similarity = DotProduct(reinterpret_cast<const TCoord*>(cluster),
114+
reinterpret_cast<const TCoord*>(embedding), dimensions);
115+
return -similarity;
116+
}
117+
};
118+
119+
template <typename TMetric>
120+
class TClusters: public IClusters {
121+
// If less than 1% of vectors are reassigned to new clusters we want to stop
122+
static constexpr double MinVectorsNeedsReassigned = 0.01;
123+
124+
using TCoord = TMetric::TCoord_;
125+
using TSum = TMetric::TSum;
126+
using TEmbedding = TVector<TSum>;
127+
128+
ui32 InitK = 0;
129+
ui32 K = 0;
130+
const ui32 Dimensions = 0;
131+
132+
TVector<TString> Clusters;
133+
TVector<ui64> ClusterSizes;
134+
135+
struct TAggregatedCluster {
136+
TEmbedding Cluster;
137+
ui64 Size = 0;
138+
};
139+
TVector<TAggregatedCluster> AggregatedClusters;
140+
141+
ui32 Round = 0;
142+
ui32 MaxRounds = 0;
143+
144+
public:
145+
TClusters(ui32 dimensions)
146+
: Dimensions(dimensions)
147+
{
148+
}
149+
150+
void Init(ui32 k, ui32 maxRounds) override {
151+
InitK = k;
152+
K = k;
153+
MaxRounds = maxRounds;
154+
}
155+
156+
ui32 GetK() const override {
157+
return K;
158+
}
159+
160+
TString Debug() const override {
161+
if (!MaxRounds) {
162+
return TStringBuilder() << "K: " << K;
163+
}
164+
return TStringBuilder() << "K: " << K << " Round: " << Round << " / " << MaxRounds;
165+
}
166+
167+
const TVector<TString>& GetClusters() const override {
168+
return Clusters;
169+
}
170+
171+
void Clear() override {
172+
K = InitK;
173+
Clusters.clear();
174+
ClusterSizes.clear();
175+
AggregatedClusters.clear();
176+
Round = 0;
177+
}
178+
179+
bool SetClusters(TVector<TString> && newClusters) override {
180+
Clusters = newClusters;
181+
if (Clusters.size() == 0) {
182+
return false;
183+
}
184+
if (Clusters.size() < K) {
185+
// if this datashard have less than K valid embeddings for this parent
186+
// lets make single centroid for it
187+
K = 1;
188+
Clusters.resize(K);
189+
}
190+
if (!K) {
191+
K = InitK = newClusters.size();
192+
}
193+
Y_ENSURE(Clusters.size() == K);
194+
return true;
195+
}
196+
197+
void InitAggregatedClusters() override {
198+
AggregatedClusters.resize(K);
199+
ClusterSizes.resize(K, 0);
200+
for (auto& aggregate : AggregatedClusters) {
201+
aggregate.Cluster.resize(Dimensions, 0);
202+
}
203+
Round = 1;
204+
}
205+
206+
bool RecomputeClusters() override {
207+
Y_ENSURE(K >= 1);
208+
ui64 vectorCount = 0;
209+
ui64 reassignedCount = 0;
210+
for (size_t i = 0; auto& aggregate : AggregatedClusters) {
211+
vectorCount += aggregate.Size;
212+
213+
auto& clusterSize = ClusterSizes[i];
214+
reassignedCount += clusterSize < aggregate.Size ? aggregate.Size - clusterSize : 0;
215+
clusterSize = aggregate.Size;
216+
217+
if (aggregate.Size != 0) {
218+
this->Fill(Clusters[i], aggregate.Cluster.data(), aggregate.Size);
219+
Y_ENSURE(aggregate.Size == 0);
220+
}
221+
++i;
222+
}
223+
Y_ENSURE(vectorCount >= K);
224+
Y_ENSURE(reassignedCount <= vectorCount);
225+
if (K == 1) {
226+
return true;
227+
}
228+
229+
bool last = Round >= MaxRounds;
230+
if (!last && Round > 1) {
231+
const auto changes = static_cast<double>(reassignedCount) / static_cast<double>(vectorCount);
232+
last = changes < MinVectorsNeedsReassigned;
233+
}
234+
if (!last) {
235+
++Round;
236+
return false;
237+
}
238+
239+
size_t w = 0;
240+
for (size_t r = 0; r < ClusterSizes.size(); ++r) {
241+
if (ClusterSizes[r] != 0) {
242+
ClusterSizes[w] = ClusterSizes[r];
243+
Clusters[w] = std::move(Clusters[r]);
244+
++w;
245+
}
246+
}
247+
ClusterSizes.erase(ClusterSizes.begin() + w, ClusterSizes.end());
248+
Clusters.erase(Clusters.begin() + w, Clusters.end());
249+
return true;
250+
}
251+
252+
std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) override {
253+
Y_ENSURE(embeddingPos < row.size());
254+
const auto embedding = row.at(embeddingPos).AsRef();
255+
if (!IsExpectedSize(embedding)) {
256+
return {};
257+
}
258+
259+
auto min = TMetric::Init();
260+
std::optional<ui32> closest = {};
261+
for (size_t i = 0; const auto& cluster : Clusters) {
262+
auto distance = TMetric::Distance(cluster.data(), embedding.data(), Dimensions);
263+
if (distance < min) {
264+
min = distance;
265+
closest = i;
266+
}
267+
++i;
268+
}
269+
return closest;
270+
}
271+
272+
void AggregateToCluster(ui32 pos, const char* embedding) override {
273+
auto& aggregate = AggregatedClusters[pos];
274+
auto* coords = aggregate.Cluster.data();
275+
for (auto coord : this->GetCoords(embedding)) {
276+
*coords++ += coord;
277+
}
278+
++aggregate.Size;
279+
}
280+
281+
bool IsExpectedSize(TArrayRef<const char> data) override {
282+
return data.size() == 1 + sizeof(TCoord) * Dimensions;
283+
}
284+
285+
private:
286+
auto GetCoords(const char* coords) {
287+
return std::span{reinterpret_cast<const TCoord*>(coords), Dimensions};
288+
}
289+
290+
auto GetData(char* data) {
291+
return std::span{reinterpret_cast<TCoord*>(data), Dimensions};
292+
}
293+
294+
void Fill(TString& d, TSum* embedding, ui64& c) {
295+
Y_ENSURE(c > 0);
296+
const auto count = static_cast<TSum>(std::exchange(c, 0));
297+
auto data = GetData(d.MutRef().data());
298+
for (auto& coord : data) {
299+
coord = *embedding / count;
300+
*embedding++ = 0;
301+
}
302+
}
303+
};
304+
305+
// Builds the TClusters instantiation matching the vector type and metric in
// `settings`. On failure returns nullptr and stores the reason in `error`:
// dimension below one, unsupported vector type, or unknown metric.
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, TString& error) {
    using TSettings = Ydb::Table::VectorIndexSettings;

    if (settings.vector_dimension() < 1) {
        error = "Dimension of vector should be at least one";
        return nullptr;
    }

    const ui32 dimensions = settings.vector_dimension();

    // Selects the metric for an already chosen coordinate type.
    auto makeForCoord = [&]<typename TCoord>() -> std::unique_ptr<IClusters> {
        switch (settings.metric()) {
            case TSettings::SIMILARITY_INNER_PRODUCT:
                return std::make_unique<TClusters<TMaxInnerProductSimilarity<TCoord>>>(dimensions);
            case TSettings::SIMILARITY_COSINE:
            case TSettings::DISTANCE_COSINE:
                // We don't need to have separate implementation for distance,
                // because clusters will be same as for similarity
                return std::make_unique<TClusters<TCosineSimilarity<TCoord>>>(dimensions);
            case TSettings::DISTANCE_MANHATTAN:
                return std::make_unique<TClusters<TL1Distance<TCoord>>>(dimensions);
            case TSettings::DISTANCE_EUCLIDEAN:
                return std::make_unique<TClusters<TL2Distance<TCoord>>>(dimensions);
            default:
                break;
        }
        error = "Wrong similarity";
        return nullptr;
    };

    // Selects the coordinate type; every unsupported case falls through to nullptr.
    switch (settings.vector_type()) {
        case TSettings::VECTOR_TYPE_FLOAT:
            return makeForCoord.template operator()<float>();
        case TSettings::VECTOR_TYPE_UINT8:
            return makeForCoord.template operator()<ui8>();
        case TSettings::VECTOR_TYPE_INT8:
            return makeForCoord.template operator()<i8>();
        case TSettings::VECTOR_TYPE_BIT:
            error = "TODO(mbkkt) bit vector type is not supported";
            break;
        default:
            error = "Wrong vector type";
            break;
    }

    return nullptr;
}
350+
351+
}

ydb/core/base/kmeans_clusters.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
3+
#include <ydb/core/scheme/scheme_tablecell.h>
4+
5+
#include <ydb/public/api/protos/ydb_table.pb.h>
6+
7+
namespace NKikimr::NKMeans {
8+
9+
class IClusters {
10+
public:
11+
virtual ~IClusters() = default;
12+
13+
virtual void Init(ui32 k, ui32 maxRounds) = 0;
14+
15+
virtual ui32 GetK() const = 0;
16+
17+
virtual TString Debug() const = 0;
18+
19+
virtual const TVector<TString>& GetClusters() const = 0;
20+
21+
virtual void Clear() = 0;
22+
23+
virtual bool SetClusters(TVector<TString> && newClusters) = 0;
24+
25+
virtual void InitAggregatedClusters() = 0;
26+
27+
virtual bool RecomputeClusters() = 0;
28+
29+
virtual std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) = 0;
30+
31+
virtual void AggregateToCluster(ui32 pos, const char* embedding) = 0;
32+
33+
virtual bool IsExpectedSize(TArrayRef<const char> data) = 0;
34+
};
35+
36+
// Creates the cluster state matching the vector type, metric and dimension in `settings`.
// Returns nullptr and fills `error` when the dimension is below one or the
// vector type / metric is not supported.
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, TString& error);
37+
38+
}

0 commit comments

Comments
 (0)