Skip to content

Commit 750015b

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
feat: + cluster metrics
1 parent 0803532 commit 750015b

File tree

15 files changed

+477
-16
lines changed

15 files changed

+477
-16
lines changed

src/algorithm/neighbour/bbd_tree.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ impl<T: RealNumber> BBDTree<T> {
7171
) -> T {
7272
let k = centroids.len();
7373

74-
counts.iter_mut().for_each(|x| *x = 0);
74+
counts.iter_mut().for_each(|v| *v = 0);
7575
let mut candidates = vec![0; k];
7676
for i in 0..k {
7777
candidates[i] = i;
78-
sums[i].iter_mut().for_each(|x| *x = T::zero());
78+
sums[i].iter_mut().for_each(|v| *v = T::zero());
7979
}
8080

8181
self.filter(
@@ -124,7 +124,7 @@ impl<T: RealNumber> BBDTree<T> {
124124
if !BBDTree::prune(
125125
&self.nodes[node].center,
126126
&self.nodes[node].radius,
127-
&centroids,
127+
centroids,
128128
closest,
129129
candidates[i],
130130
) {
@@ -135,7 +135,7 @@ impl<T: RealNumber> BBDTree<T> {
135135

136136
// Recurse if there's at least two
137137
if newk > 1 {
138-
let result = self.filter(
138+
return self.filter(
139139
self.nodes[node].lower.unwrap(),
140140
centroids,
141141
&mut new_candidates,
@@ -152,7 +152,6 @@ impl<T: RealNumber> BBDTree<T> {
152152
counts,
153153
membership,
154154
);
155-
return result;
156155
}
157156
}
158157

@@ -198,7 +197,7 @@ impl<T: RealNumber> BBDTree<T> {
198197
}
199198
}
200199

201-
return lhs >= T::two() * rhs;
200+
lhs >= T::two() * rhs
202201
}
203202

204203
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
@@ -336,7 +335,7 @@ mod tests {
336335
use crate::linalg::naive::dense_matrix::DenseMatrix;
337336

338337
#[test]
339-
fn fit_predict_iris() {
338+
fn bbdtree_iris() {
340339
let data = DenseMatrix::from_2d_array(&[
341340
&[5.1, 3.5, 1.4, 0.2],
342341
&[4.9, 3.0, 1.4, 0.2],

src/cluster/kmeans.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
189189
/// Predict clusters for `x`
190190
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
191191
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
192-
let (n, _) = x.shape();
192+
let (n, m) = x.shape();
193193
let mut result = M::zeros(1, n);
194194

195+
let mut row = vec![T::zero(); m];
196+
195197
for i in 0..n {
196198
let mut min_dist = T::max_value();
197199
let mut best_cluster = 0;
198200

199201
for j in 0..self.k {
200-
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
202+
x.copy_row_as_vec(i, &mut row);
203+
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
201204
if dist < min_dist {
202205
min_dist = dist;
203206
best_cluster = j;
@@ -211,19 +214,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
211214

212215
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
213216
let mut rng = rand::thread_rng();
214-
let (n, _) = data.shape();
217+
let (n, m) = data.shape();
215218
let mut y = vec![0; n];
216219
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
217220

218221
let mut d = vec![T::max_value(); n];
219222

223+
let mut row = vec![T::zero(); m];
224+
220225
// pick the next center
221226
for j in 1..k {
222227
// Loop over the samples and compare them to the most recent center. Store
223228
// the distance from each sample to its closest center in scores.
224229
for i in 0..n {
225230
// compute the distance between this sample and the current center
226-
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
231+
data.copy_row_as_vec(i, &mut row);
232+
let dist = Euclidian::squared_distance(&row, &centroid);
227233

228234
if dist < d[i] {
229235
d[i] = dist;
@@ -237,20 +243,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
237243
}
238244
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
239245
let mut cost = T::zero();
240-
let index = 0;
241-
for index in 0..n {
246+
let mut index = 0;
247+
while index < n {
242248
cost = cost + d[index];
243249
if cost >= cutoff {
244250
break;
245251
}
252+
index += 1;
246253
}
247254

248-
centroid = data.get_row_as_vec(index);
255+
data.copy_row_as_vec(index, &mut centroid);
249256
}
250257

251258
for i in 0..n {
259+
data.copy_row_as_vec(i, &mut row);
252260
// compute the distance between this sample and the current center
253-
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
261+
let dist = Euclidian::squared_distance(&row, &centroid);
254262

255263
if dist < d[i] {
256264
d[i] = dist;

src/dataset/digits.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//! # Optical Recognition of Handwritten Digits Data Set
2+
//!
3+
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
4+
//! |-|-|-|-|
5+
//! | 1797 | 64 | No | Classification, Clusteing |
6+
//!
7+
//! [Digits dataset](https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits) contains normalized bitmaps of handwritten digits (0-9) from a preprinted form.
8+
//! This multivariate dataset is frequently used to demonstrate various machine learning algorithms.
9+
//!
10+
//! All input attributes are integers in the range 0..16.
11+
//!
12+
use crate::dataset::deserialize_data;
13+
use crate::dataset::Dataset;
14+
15+
/// Get dataset
16+
pub fn load_dataset() -> Dataset<f32, f32> {
17+
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
18+
{
19+
Err(why) => panic!("Can't deserialize digits.xy. {}", why),
20+
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
21+
};
22+
23+
Dataset {
24+
data: x,
25+
target: y,
26+
num_samples: num_samples,
27+
num_features: num_features,
28+
feature_names: vec![
29+
"sepal length (cm)",
30+
"sepal width (cm)",
31+
"petal length (cm)",
32+
"petal width (cm)",
33+
]
34+
.iter()
35+
.map(|s| s.to_string())
36+
.collect(),
37+
target_names: vec!["setosa", "versicolor", "virginica"]
38+
.iter()
39+
.map(|s| s.to_string())
40+
.collect(),
41+
description: "Digits dataset: https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits".to_string(),
42+
}
43+
}
44+
45+
#[cfg(test)]
46+
mod tests {
47+
48+
use super::super::*;
49+
use super::*;
50+
51+
#[test]
52+
#[ignore]
53+
fn refresh_digits_dataset() {
54+
// run this test to generate digits.xy file.
55+
let dataset = load_dataset();
56+
assert!(serialize_data(&dataset, "digits.xy").is_ok());
57+
}
58+
59+
#[test]
60+
fn digits_dataset() {
61+
let dataset = load_dataset();
62+
assert_eq!(dataset.data.len(), 1797 * 64);
63+
assert_eq!(dataset.target.len(), 1797);
64+
assert_eq!(dataset.num_features, 64);
65+
assert_eq!(dataset.num_samples, 1797);
66+
}
67+
}

src/dataset/digits.xy

456 KB
Binary file not shown.

src/dataset/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
pub mod boston;
55
pub mod breast_cancer;
66
pub mod diabetes;
7+
pub mod digits;
78
pub mod iris;
89

910
use crate::math::num::RealNumber;

src/linalg/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,20 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
110110
/// * `row` - row number
111111
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
112112

113+
/// Copies a vector with elements of the `row`'th row into `result`
114+
/// * `row` - row number
115+
/// * `result` - receiver for the row
116+
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
117+
113118
/// Get a vector with elements of the `col`'th column
114119
/// * `col` - column number
115120
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
116121

122+
/// Copies a vector with elements of the `col`'th column into `result`
123+
/// * `col` - column number
124+
/// * `result` - receiver for the col
125+
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>);
126+
117127
/// Set an element at `col`, `row` to `x`
118128
fn set(&mut self, row: usize, col: usize, x: T);
119129

src/linalg/naive/dense_matrix.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ pub struct DenseMatrix<T: RealNumber> {
5454
values: Vec<T>,
5555
}
5656

57+
/// Column-major, dense matrix. See [Simple Dense Matrix](../index.html).
58+
#[derive(Debug)]
59+
pub struct DenseMatrixIterator<'a, T: RealNumber> {
60+
cur_c: usize,
61+
cur_r: usize,
62+
max_c: usize,
63+
max_r: usize,
64+
m: &'a DenseMatrix<T>,
65+
}
66+
5767
impl<T: RealNumber> fmt::Display for DenseMatrix<T> {
5868
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5969
let mut rows: Vec<Vec<f64>> = Vec::new();
@@ -162,6 +172,36 @@ impl<T: RealNumber> DenseMatrix<T> {
162172
values: values,
163173
}
164174
}
175+
176+
/// Creates new column vector (_1xN_ matrix) from a vector.
177+
/// * `values` - values to initialize the matrix.
178+
pub fn iter<'a>(&'a self) -> DenseMatrixIterator<'a, T> {
179+
DenseMatrixIterator {
180+
cur_c: 0,
181+
cur_r: 0,
182+
max_c: self.ncols,
183+
max_r: self.nrows,
184+
m: &self,
185+
}
186+
}
187+
}
188+
189+
impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
190+
type Item = T;
191+
192+
fn next(&mut self) -> Option<T> {
193+
if self.cur_r * self.max_c + self.cur_c >= self.max_c * self.max_r {
194+
None
195+
} else {
196+
let v = self.m.get(self.cur_r, self.cur_c);
197+
self.cur_c += 1;
198+
if self.cur_c >= self.max_c {
199+
self.cur_c = 0;
200+
self.cur_r += 1;
201+
}
202+
Some(v)
203+
}
204+
}
165205
}
166206

167207
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
@@ -339,6 +379,12 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
339379
result
340380
}
341381

382+
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>) {
383+
for c in 0..self.ncols {
384+
result[c] = self.get(row, c);
385+
}
386+
}
387+
342388
fn get_col_as_vec(&self, col: usize) -> Vec<T> {
343389
let mut result = vec![T::zero(); self.nrows];
344390
for r in 0..self.nrows {
@@ -347,6 +393,12 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
347393
result
348394
}
349395

396+
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
397+
for r in 0..self.nrows {
398+
result[r] = self.get(r, col);
399+
}
400+
}
401+
350402
fn set(&mut self, row: usize, col: usize, x: T) {
351403
self.values[col * self.nrows + row] = x;
352404
}
@@ -852,6 +904,13 @@ mod tests {
852904
);
853905
}
854906

907+
#[test]
908+
fn iter() {
909+
let vec = vec![1., 2., 3., 4., 5., 6.];
910+
let m = DenseMatrix::from_array(3, 2, &vec);
911+
assert_eq!(vec, m.iter().collect::<Vec<f32>>());
912+
}
913+
855914
#[test]
856915
fn v_stack() {
857916
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);

src/linalg/nalgebra_bindings.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,26 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
102102
self.row(row).iter().map(|v| *v).collect()
103103
}
104104

105+
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>) {
106+
let mut r = 0;
107+
for e in self.row(row).iter() {
108+
result[r] = *e;
109+
r += 1;
110+
}
111+
}
112+
105113
fn get_col_as_vec(&self, col: usize) -> Vec<T> {
106114
self.column(col).iter().map(|v| *v).collect()
107115
}
108116

117+
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
118+
let mut r = 0;
119+
for e in self.column(col).iter() {
120+
result[r] = *e;
121+
r += 1;
122+
}
123+
}
124+
109125
fn set(&mut self, row: usize, col: usize, x: T) {
110126
*self.get_mut((row, col)).unwrap() = x;
111127
}
@@ -563,6 +579,17 @@ mod tests {
563579
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
564580
}
565581

582+
#[test]
583+
fn copy_row_col_as_vec() {
584+
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
585+
let mut v = vec![0f32; 3];
586+
587+
m.copy_row_as_vec(1, &mut v);
588+
assert_eq!(v, vec!(4., 5., 6.));
589+
m.copy_col_as_vec(1, &mut v);
590+
assert_eq!(v, vec!(2., 5., 8.));
591+
}
592+
566593
#[test]
567594
fn element_add_sub_mul_div() {
568595
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);

0 commit comments

Comments
 (0)