Skip to content

Commit bb9a05b

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: fixes a bug in DBSCAN, removes println's
1 parent c5a7bea commit bb9a05b

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

src/cluster/dbscan.rs

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,39 +161,60 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
161161
}
162162

163163
let mut k = 0;
164-
let unassigned = -2;
164+
let queued = -2;
165165
let outlier = -1;
166+
let undefined = -3;
166167

167168
let n = x.shape().0;
168-
let mut y = vec![unassigned; n];
169+
let mut y = vec![undefined; n];
169170

170171
let algo = parameters
171172
.algorithm
172173
.fit(row_iter(x).collect(), parameters.distance)?;
173174

174175
for (i, e) in row_iter(x).enumerate() {
175-
if y[i] == unassigned {
176+
if y[i] == undefined {
176177
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
177178
if neighbors.len() < parameters.min_samples {
178179
y[i] = outlier;
179180
} else {
180181
y[i] = k;
182+
181183
for j in 0..neighbors.len() {
182-
if y[neighbors[j].0] == unassigned {
183-
y[neighbors[j].0] = k;
184+
if y[neighbors[j].0] == undefined {
185+
y[neighbors[j].0] = queued;
186+
}
187+
}
184188

185-
let mut secondary_neighbors =
186-
algo.find_radius(neighbors[j].2, parameters.eps)?;
189+
while !neighbors.is_empty() {
190+
let neighbor = neighbors.pop().unwrap();
191+
let index = neighbor.0;
187192

188-
if secondary_neighbors.len() >= parameters.min_samples {
189-
neighbors.append(&mut secondary_neighbors);
190-
}
193+
if y[index] == outlier {
194+
y[index] = k;
191195
}
192196

193-
if y[neighbors[j].0] == outlier {
194-
y[neighbors[j].0] = k;
197+
if y[index] == undefined || y[index] == queued {
198+
y[index] = k;
199+
200+
let secondary_neighbors =
201+
algo.find_radius(neighbor.2, parameters.eps)?;
202+
203+
if secondary_neighbors.len() >= parameters.min_samples {
204+
for j in 0..secondary_neighbors.len() {
205+
let label = y[secondary_neighbors[j].0];
206+
if label == undefined {
207+
y[secondary_neighbors[j].0] = queued;
208+
}
209+
210+
if label == undefined || label == outlier {
211+
neighbors.push(secondary_neighbors[j]);
212+
}
213+
}
214+
}
195215
}
196216
}
217+
197218
k += 1;
198219
}
199220
}
@@ -250,19 +271,25 @@ mod tests {
250271
&[1.0, 2.0],
251272
&[1.1, 2.1],
252273
&[0.9, 1.9],
253-
&[1.2, 1.2],
274+
&[1.2, 2.2],
254275
&[0.8, 1.8],
255276
&[2.0, 1.0],
256277
&[2.1, 1.1],
257-
&[2.2, 1.2],
258278
&[1.9, 0.9],
279+
&[2.2, 1.2],
259280
&[1.8, 0.8],
260281
&[3.0, 5.0],
261282
]);
262283

263284
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
264285

265-
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
286+
let dbscan = DBSCAN::fit(
287+
&x,
288+
DBSCANParameters::default()
289+
.with_eps(0.5)
290+
.with_min_samples(2),
291+
)
292+
.unwrap();
266293

267294
let predicted_labels = dbscan.predict(&x).unwrap();
268295

src/dataset/generator.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
5959
let linspace_out = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_out);
6060
let linspace_in = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_in);
6161

62-
println!("{:?}", linspace_out);
63-
println!("{:?}", linspace_in);
6462
let noise = Normal::new(0.0, noise).unwrap();
6563
let mut rng = rand::thread_rng();
6664

@@ -117,7 +115,6 @@ mod tests {
117115
#[test]
118116
fn test_make_circles() {
119117
let dataset = make_circles(10, 0.5, 0.05);
120-
println!("{:?}", dataset.as_matrix());
121118
assert_eq!(
122119
dataset.data.len(),
123120
dataset.num_features * dataset.num_samples

src/decomposition/svd.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
//! &[5.2, 2.7, 3.9, 1.4],
3535
//! ]);
3636
//!
37-
//! let svd = SVD::fit(&iris, SVDParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
37+
//! let svd = SVD::fit(&iris, SVDParameters::default().
38+
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
3839
//!
3940
//! let iris_reduced = svd.transform(&iris).unwrap();
4041
//!

0 commit comments

Comments
 (0)