Skip to content

Commit 71b7a1a

Browse files
authored
fix: elkan means impl for some vector types (#472)
Signed-off-by: usamoi <usamoi@outlook.com>
1 parent d81473c commit 71b7a1a

File tree

3 files changed

+25
-173
lines changed

3 files changed

+25
-173
lines changed

crates/elkan_k_means/src/operator.rs

Lines changed: 13 additions & 163 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,7 @@ pub trait OperatorElkanKMeans: Operator {
77
type VectorNormalized: VectorOwned;
88

99
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]);
10-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Self::VectorNormalized;
1110
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32;
12-
fn elkan_k_means_distance2(
13-
lhs: <Self::VectorNormalized as VectorOwned>::Borrowed<'_>,
14-
rhs: &[Scalar<Self>],
15-
) -> F32;
1611
}
1712

1813
impl OperatorElkanKMeans for BVecf32Cos {
@@ -22,17 +17,9 @@ impl OperatorElkanKMeans for BVecf32Cos {
2217
vecf32::l2_normalize(vector)
2318
}
2419

25-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
26-
bvecf32::l2_normalize(vector)
27-
}
28-
2920
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
3021
vecf32::dot(lhs, rhs).acos()
3122
}
32-
33-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
34-
vecf32::dot(lhs.slice(), rhs).acos()
35-
}
3623
}
3724

3825
impl OperatorElkanKMeans for BVecf32Dot {
@@ -42,17 +29,9 @@ impl OperatorElkanKMeans for BVecf32Dot {
4229
vecf32::l2_normalize(vector)
4330
}
4431

45-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
46-
bvecf32::l2_normalize(vector)
47-
}
48-
4932
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
5033
vecf32::dot(lhs, rhs).acos()
5134
}
52-
53-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
54-
vecf32::dot(lhs.slice(), rhs).acos()
55-
}
5635
}
5736

5837
impl OperatorElkanKMeans for BVecf32Jaccard {
@@ -62,37 +41,19 @@ impl OperatorElkanKMeans for BVecf32Jaccard {
6241
vecf32::l2_normalize(vector)
6342
}
6443

65-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
66-
Vecf32Owned::new(vector.to_vec())
67-
}
68-
6944
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
7045
vecf32::sl2(lhs, rhs).sqrt()
7146
}
72-
73-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
74-
vecf32::sl2(lhs.slice(), rhs).sqrt()
75-
}
7647
}
7748

7849
impl OperatorElkanKMeans for BVecf32L2 {
7950
type VectorNormalized = Vecf32Owned;
8051

81-
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
82-
vecf32::l2_normalize(vector)
83-
}
84-
85-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
86-
Vecf32Owned::new(vector.to_vec())
87-
}
52+
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
8853

8954
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
9055
vecf32::sl2(lhs, rhs).sqrt()
9156
}
92-
93-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
94-
vecf32::sl2(lhs.slice(), rhs).sqrt()
95-
}
9657
}
9758

9859
impl OperatorElkanKMeans for SVecf32Cos {
@@ -102,19 +63,9 @@ impl OperatorElkanKMeans for SVecf32Cos {
10263
vecf32::l2_normalize(vector)
10364
}
10465

105-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
106-
let mut vector = vector.for_own();
107-
svecf32::l2_normalize(&mut vector);
108-
vector
109-
}
110-
11166
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
11267
vecf32::dot(lhs, rhs).acos()
11368
}
114-
115-
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
116-
svecf32::dot_2(lhs, rhs).acos()
117-
}
11869
}
11970

12071
impl OperatorElkanKMeans for SVecf32Dot {
@@ -124,161 +75,86 @@ impl OperatorElkanKMeans for SVecf32Dot {
12475
vecf32::l2_normalize(vector)
12576
}
12677

127-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
128-
let mut vector = vector.for_own();
129-
svecf32::l2_normalize(&mut vector);
130-
vector
131-
}
132-
13378
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
13479
vecf32::dot(lhs, rhs).acos()
13580
}
136-
137-
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
138-
svecf32::dot_2(lhs, rhs).acos()
139-
}
14081
}
14182

14283
impl OperatorElkanKMeans for SVecf32L2 {
14384
type VectorNormalized = Self::VectorOwned;
14485

14586
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
14687

147-
fn elkan_k_means_normalize2(vector: SVecf32Borrowed<'_>) -> SVecf32Owned {
148-
vector.for_own()
149-
}
150-
15188
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
15289
vecf32::sl2(lhs, rhs).sqrt()
15390
}
154-
155-
fn elkan_k_means_distance2(lhs: SVecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
156-
svecf32::sl2_2(lhs, rhs).sqrt()
157-
}
15891
}
15992

16093
impl OperatorElkanKMeans for Vecf16Cos {
16194
type VectorNormalized = Self::VectorOwned;
16295

163-
fn elkan_k_means_normalize(vector: &mut [F16]) {
96+
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
16497
vecf16::l2_normalize(vector)
16598
}
166-
167-
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
168-
let mut vector = vector.for_own();
169-
vecf16::l2_normalize(vector.slice_mut());
170-
vector
171-
}
172-
173-
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
99+
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
174100
vecf16::dot(lhs, rhs).acos()
175101
}
176-
177-
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
178-
vecf16::dot(lhs.slice(), rhs).acos()
179-
}
180102
}
181103

182104
impl OperatorElkanKMeans for Vecf16Dot {
183105
type VectorNormalized = Self::VectorOwned;
184106

185-
fn elkan_k_means_normalize(vector: &mut [F16]) {
107+
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
186108
vecf16::l2_normalize(vector)
187109
}
188110

189-
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
190-
let mut vector = vector.for_own();
191-
vecf16::l2_normalize(vector.slice_mut());
192-
vector
193-
}
194-
195-
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
111+
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
196112
vecf16::dot(lhs, rhs).acos()
197113
}
198-
199-
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
200-
vecf16::dot(lhs.slice(), rhs).acos()
201-
}
202114
}
203115

204116
impl OperatorElkanKMeans for Vecf16L2 {
205117
type VectorNormalized = Self::VectorOwned;
206118

207-
fn elkan_k_means_normalize(_: &mut [F16]) {}
208-
209-
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
210-
vector.for_own()
211-
}
119+
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
212120

213-
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
121+
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
214122
vecf16::sl2(lhs, rhs).sqrt()
215123
}
216-
217-
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
218-
vecf16::sl2(lhs.slice(), rhs).sqrt()
219-
}
220124
}
221125

222126
impl OperatorElkanKMeans for Vecf32Cos {
223127
type VectorNormalized = Self::VectorOwned;
224128

225-
fn elkan_k_means_normalize(vector: &mut [F32]) {
129+
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
226130
vecf32::l2_normalize(vector)
227131
}
228132

229-
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
230-
let mut vector = vector.for_own();
231-
vecf32::l2_normalize(vector.slice_mut());
232-
vector
233-
}
234-
235-
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
133+
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
236134
vecf32::dot(lhs, rhs).acos()
237135
}
238-
239-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
240-
vecf32::dot(lhs.slice(), rhs).acos()
241-
}
242136
}
243137

244138
impl OperatorElkanKMeans for Vecf32Dot {
245139
type VectorNormalized = Self::VectorOwned;
246140

247-
fn elkan_k_means_normalize(vector: &mut [F32]) {
141+
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
248142
vecf32::l2_normalize(vector)
249143
}
250144

251-
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
252-
let mut vector = vector.for_own();
253-
vecf32::l2_normalize(vector.slice_mut());
254-
vector
255-
}
256-
257-
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
145+
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
258146
vecf32::dot(lhs, rhs).acos()
259147
}
260-
261-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
262-
vecf32::dot(lhs.slice(), rhs).acos()
263-
}
264148
}
265149

266150
impl OperatorElkanKMeans for Vecf32L2 {
267151
type VectorNormalized = Self::VectorOwned;
268152

269-
fn elkan_k_means_normalize(_: &mut [F32]) {}
270-
271-
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
272-
vector.for_own()
273-
}
153+
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
274154

275155
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
276156
vecf32::sl2(lhs, rhs).sqrt()
277157
}
278-
279-
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
280-
vecf32::sl2(lhs.slice(), rhs).sqrt()
281-
}
282158
}
283159

284160
impl OperatorElkanKMeans for Veci8Cos {
@@ -288,17 +164,9 @@ impl OperatorElkanKMeans for Veci8Cos {
288164
vecf32::l2_normalize(vector)
289165
}
290166

291-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
292-
vector.normalize()
293-
}
294-
295167
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
296168
vecf32::dot(lhs, rhs).acos()
297169
}
298-
299-
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
300-
veci8::dot_2(lhs, rhs).acos()
301-
}
302170
}
303171

304172
impl OperatorElkanKMeans for Veci8Dot {
@@ -308,35 +176,17 @@ impl OperatorElkanKMeans for Veci8Dot {
308176
vecf32::l2_normalize(vector)
309177
}
310178

311-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
312-
vector.normalize()
313-
}
314-
315179
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
316180
vecf32::dot(lhs, rhs).acos()
317181
}
318-
319-
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
320-
veci8::dot_2(lhs, rhs).acos()
321-
}
322182
}
323183

324184
impl OperatorElkanKMeans for Veci8L2 {
325185
type VectorNormalized = Self::VectorOwned;
326186

327-
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
328-
vecf32::l2_normalize(vector)
329-
}
330-
331-
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
332-
vector.normalize()
333-
}
187+
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
334188

335189
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
336190
vecf32::sl2(lhs, rhs).sqrt()
337191
}
338-
339-
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
340-
veci8::l2_2(lhs, rhs).sqrt()
341-
}
342192
}

crates/ivf/src/ivf_naive.rs

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -146,11 +146,11 @@ pub fn make<O: Op, S: Source<O>>(path: &Path, options: IndexOptions, source: &S)
146146
let mut idx = vec![0usize; n as usize];
147147
idx.par_iter_mut().enumerate().for_each(|(i, x)| {
148148
rayon::check();
149-
let vector = storage.vector(i as u32);
150-
let vector = O::elkan_k_means_normalize2(vector);
149+
let mut vector = storage.vector(i as u32).to_vec();
150+
O::elkan_k_means_normalize(&mut vector);
151151
let mut result = (F32::infinity(), 0);
152152
for i in 0..nlist as usize {
153-
let dis = O::elkan_k_means_distance2(vector.for_borrow(), &centroids[i]);
153+
let dis = O::elkan_k_means_distance(&vector, &centroids[i]);
154154
result = std::cmp::min(result, (dis, i));
155155
}
156156
*x = result.1;
@@ -242,11 +242,12 @@ pub fn basic<O: Op>(
242242
nprobe: u32,
243243
mut filter: impl Filter,
244244
) -> BinaryHeap<Reverse<Element>> {
245-
let target = O::elkan_k_means_normalize2(vector);
245+
let mut target = vector.to_vec();
246+
O::elkan_k_means_normalize(&mut target);
246247
let mut lists = Vec::with_capacity(mmap.nlist as usize);
247248
for i in 0..mmap.nlist {
248249
let centroid = mmap.centroids(i);
249-
let distance = O::elkan_k_means_distance2(target.for_borrow(), centroid);
250+
let distance = O::elkan_k_means_distance(&target, centroid);
250251
lists.push((distance, i));
251252
}
252253
if nprobe < mmap.nlist {
@@ -274,11 +275,12 @@ pub fn vbase<'a, O: Op>(
274275
nprobe: u32,
275276
mut filter: impl Filter + 'a,
276277
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
277-
let target = O::elkan_k_means_normalize2(vector);
278+
let mut target = vector.to_vec();
279+
O::elkan_k_means_normalize(&mut target);
278280
let mut lists = Vec::with_capacity(mmap.nlist as usize);
279281
for i in 0..mmap.nlist {
280282
let centroid = mmap.centroids(i);
281-
let distance = O::elkan_k_means_distance2(target.for_borrow(), centroid);
283+
let distance = O::elkan_k_means_distance(&target, centroid);
282284
lists.push((distance, i));
283285
}
284286
if nprobe < mmap.nlist {

crates/ivf/src/ivf_pq.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,11 +150,11 @@ pub fn make<O: Op, S: Source<O>>(path: &Path, options: IndexOptions, source: &S)
150150
let mut idx = vec![0usize; n as usize];
151151
idx.par_iter_mut().enumerate().for_each(|(i, x)| {
152152
rayon::check();
153-
let vector = storage.vector(i as u32);
154-
let vector = O::elkan_k_means_normalize2(vector);
153+
let mut vector = storage.vector(i as u32).to_vec();
154+
O::elkan_k_means_normalize(&mut vector);
155155
let mut result = (F32::infinity(), 0);
156156
for i in 0..nlist as usize {
157-
let dis = O::elkan_k_means_distance2(vector.for_borrow(), &centroids[i]);
157+
let dis = O::elkan_k_means_distance(&vector, &centroids[i]);
158158
result = std::cmp::min(result, (dis, i));
159159
}
160160
*x = result.1;

0 commit comments

Comments
 (0)