15
15
//! let blobs = generator::make_blobs(100, 2, 3);
16
16
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
17
17
//! // Fit the algorithm and predict cluster labels
18
- //! let labels = DBSCAN::fit(&x, Distances::euclidian(),
19
- //! DBSCANParameters::default().with_eps(3.0)).
18
+ //! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
20
19
//! and_then(|dbscan| dbscan.predict(&x));
21
20
//!
22
21
//! println!("{:?}", labels);
@@ -33,9 +32,11 @@ use std::iter::Sum;
33
32
use serde:: { Deserialize , Serialize } ;
34
33
35
34
use crate :: algorithm:: neighbour:: { KNNAlgorithm , KNNAlgorithmName } ;
35
+ use crate :: api:: { Predictor , UnsupervisedEstimator } ;
36
36
use crate :: error:: Failed ;
37
37
use crate :: linalg:: { row_iter, Matrix } ;
38
- use crate :: math:: distance:: Distance ;
38
+ use crate :: math:: distance:: euclidian:: Euclidian ;
39
+ use crate :: math:: distance:: { Distance , Distances } ;
39
40
use crate :: math:: num:: RealNumber ;
40
41
use crate :: tree:: decision_tree_classifier:: which_max;
41
42
@@ -50,7 +51,11 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
50
51
51
52
#[ derive( Debug , Clone ) ]
52
53
/// DBSCAN clustering algorithm parameters
53
- pub struct DBSCANParameters < T : RealNumber > {
54
+ pub struct DBSCANParameters < T : RealNumber , D : Distance < Vec < T > , T > > {
55
+ /// A function that defines the distance between each pair of points in the training data.
56
+ /// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
57
+ /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
58
+ pub distance : D ,
54
59
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
55
60
pub min_samples : usize ,
56
61
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
@@ -59,7 +64,18 @@ pub struct DBSCANParameters<T: RealNumber> {
59
64
pub algorithm : KNNAlgorithmName ,
60
65
}
61
66
62
- impl < T : RealNumber > DBSCANParameters < T > {
67
+ impl < T : RealNumber , D : Distance < Vec < T > , T > > DBSCANParameters < T , D > {
68
+ /// Sets the function that defines the distance between each pair of points in the training data.
69
+ /// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
70
+ /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
71
+ pub fn with_distance < DD : Distance < Vec < T > , T > > ( self , distance : DD ) -> DBSCANParameters < T , DD > {
72
+ DBSCANParameters {
73
+ distance,
74
+ min_samples : self . min_samples ,
75
+ eps : self . eps ,
76
+ algorithm : self . algorithm ,
77
+ }
78
+ }
63
79
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
64
80
pub fn with_min_samples ( mut self , min_samples : usize ) -> Self {
65
81
self . min_samples = min_samples;
@@ -86,25 +102,41 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
86
102
}
87
103
}
88
104
89
- impl < T : RealNumber > Default for DBSCANParameters < T > {
105
+ impl < T : RealNumber > Default for DBSCANParameters < T , Euclidian > {
90
106
fn default ( ) -> Self {
91
107
DBSCANParameters {
108
+ distance : Distances :: euclidian ( ) ,
92
109
min_samples : 5 ,
93
110
eps : T :: half ( ) ,
94
111
algorithm : KNNAlgorithmName :: CoverTree ,
95
112
}
96
113
}
97
114
}
98
115
116
+ impl < T : RealNumber + Sum , M : Matrix < T > , D : Distance < Vec < T > , T > >
117
+ UnsupervisedEstimator < M , DBSCANParameters < T , D > > for DBSCAN < T , D >
118
+ {
119
+ fn fit ( x : & M , parameters : DBSCANParameters < T , D > ) -> Result < Self , Failed > {
120
+ DBSCAN :: fit ( x, parameters)
121
+ }
122
+ }
123
+
124
+ impl < T : RealNumber , M : Matrix < T > , D : Distance < Vec < T > , T > > Predictor < M , M :: RowVector >
125
+ for DBSCAN < T , D >
126
+ {
127
+ fn predict ( & self , x : & M ) -> Result < M :: RowVector , Failed > {
128
+ self . predict ( x)
129
+ }
130
+ }
131
+
99
132
impl < T : RealNumber + Sum , D : Distance < Vec < T > , T > > DBSCAN < T , D > {
100
133
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
101
134
/// * `x` - training instances to cluster
102
135
/// * `k` - number of clusters
103
136
/// * `parameters` - clustering parameters, including the distance function
104
137
pub fn fit < M : Matrix < T > > (
105
138
x : & M ,
106
- distance : D ,
107
- parameters : DBSCANParameters < T > ,
139
+ parameters : DBSCANParameters < T , D > ,
108
140
) -> Result < DBSCAN < T , D > , Failed > {
109
141
if parameters. min_samples < 1 {
110
142
return Err ( Failed :: fit ( & "Invalid minPts" . to_string ( ) ) ) ;
@@ -121,7 +153,9 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
121
153
let n = x. shape ( ) . 0 ;
122
154
let mut y = vec ! [ unassigned; n] ;
123
155
124
- let algo = parameters. algorithm . fit ( row_iter ( x) . collect ( ) , distance) ?;
156
+ let algo = parameters
157
+ . algorithm
158
+ . fit ( row_iter ( x) . collect ( ) , parameters. distance ) ?;
125
159
126
160
for ( i, e) in row_iter ( x) . enumerate ( ) {
127
161
if y[ i] == unassigned {
@@ -195,7 +229,6 @@ mod tests {
195
229
use super :: * ;
196
230
use crate :: linalg:: naive:: dense_matrix:: DenseMatrix ;
197
231
use crate :: math:: distance:: euclidian:: Euclidian ;
198
- use crate :: math:: distance:: Distances ;
199
232
200
233
#[ test]
201
234
fn fit_predict_dbscan ( ) {
@@ -215,16 +248,7 @@ mod tests {
215
248
216
249
let expected_labels = vec ! [ 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , -1.0 ] ;
217
250
218
- let dbscan = DBSCAN :: fit (
219
- & x,
220
- Distances :: euclidian ( ) ,
221
- DBSCANParameters {
222
- min_samples : 5 ,
223
- eps : 1.0 ,
224
- algorithm : KNNAlgorithmName :: CoverTree ,
225
- } ,
226
- )
227
- . unwrap ( ) ;
251
+ let dbscan = DBSCAN :: fit ( & x, DBSCANParameters :: default ( ) . with_eps ( 1.0 ) ) . unwrap ( ) ;
228
252
229
253
let predicted_labels = dbscan. predict ( & x) . unwrap ( ) ;
230
254
@@ -256,7 +280,7 @@ mod tests {
256
280
& [ 5.2 , 2.7 , 3.9 , 1.4 ] ,
257
281
] ) ;
258
282
259
- let dbscan = DBSCAN :: fit ( & x, Distances :: euclidian ( ) , Default :: default ( ) ) . unwrap ( ) ;
283
+ let dbscan = DBSCAN :: fit ( & x, Default :: default ( ) ) . unwrap ( ) ;
260
284
261
285
let deserialized_dbscan: DBSCAN < f64 , Euclidian > =
262
286
serde_json:: from_str ( & serde_json:: to_string ( & dbscan) . unwrap ( ) ) . unwrap ( ) ;
0 commit comments