1
- using Plots
1
+ using ManifoldLearning
2
+ import StatsBase
2
3
3
- import FLANN
4
- function knn_flann (X:: AbstractMatrix{T} , k:: Int = 12 ) where T<: Real
5
- params = FLANN. FLANNParameters ()
6
- E, D = FLANN. knn (X, X, k+ 1 , params)
7
- sqrt .(@view D[2 : end , :]), @view E[2 : end , :]
4
+ import NearestNeighbors
5
+ struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
6
+ k:: Integer
7
+ fitted:: NearestNeighbors.KDTree
8
8
end
9
+ Base. show (io:: IO , NN:: KDTree ) = print (io, " KDTree(k=$(NN. k) )" )
10
+ StatsBase. fit (:: Type{KDTree} , X:: AbstractMatrix{T} , k:: Integer ) where {T<: Real } = KDTree (k, NearestNeighbors. KDTree (X))
11
+ function ManifoldLearning. knn (NN:: KDTree , X:: AbstractVecOrMat{T} ; self= false ) where {T<: Real }
12
+ m, n = size (X)
13
+ k = NN. k
14
+ @assert n > k " Number of observations must be more then $(k) "
9
15
10
- import NearestNeighbors
11
- function knn_nearestneighbors (X:: AbstractMatrix{T} , k:: Int = 12 ) where T<: Real
12
- n = size (X,2 )
13
- kdtree = NearestNeighbors. KDTree (X)
14
- idxs, dist = NearestNeighbors. knn (kdtree, X, k+ 1 , true )
16
+ idxs, dist = NearestNeighbors. knn (NN. fitted, X, k+ 1 , true )
15
17
D = Array {T} (undef, k, n)
16
18
E = Array {Int32} (undef, k, n)
17
19
for i in eachindex (idxs)
@@ -21,25 +23,18 @@ function knn_nearestneighbors(X::AbstractMatrix{T}, k::Int=12) where T<:Real
21
23
return D, E
22
24
end
23
25
24
- using ManifoldLearning
25
- k= 13
26
- X, L = ManifoldLearning. swiss_roll ()
27
-
28
- # Use default distance matrix based method to find nearest neighbors
29
- M1 = fit (Isomap, X)
30
- Y1 = transform (M1)
31
-
32
- # Use NearestNeighbors package to find nearest neighbors
33
- M2 = fit (Isomap, X, knn= knn_nearestneighbors)
34
- Y2 = transform (M2)
35
-
36
- # Use FLANN package to find nearest neighbors
37
- M3 = fit (Isomap, X, knn= knn_flann)
38
- Y3 = transform (M3)
39
-
40
- plot (
41
- plot (X[1 ,:], X[2 ,:], X[3 ,:], zcolor= L, m= 2 , t= :scatter3d , leg= false , title= " Swiss Roll" ),
42
- plot (Y1[1 ,:], Y1[2 ,:], c= L, m= 2 , t= :scatter , title= " Distance Matrix" ),
43
- plot (Y2[1 ,:], Y2[2 ,:], c= L, m= 2 , t= :scatter , title= " NearestNeighbors" ),
44
- plot (Y3[1 ,:], Y3[2 ,:], c= L, m= 2 , t= :scatter , title= " FLANN" )
45
- , leg= false )
26
+ import FLANN
27
+ struct FLANNTree{T <: Real } <: ManifoldLearning.AbstractNearestNeighbors
28
+ k:: Integer
29
+ index:: FLANN.FLANNIndex{T}
30
+ end
31
+ Base. show (io:: IO , NN:: FLANNTree ) = print (io, " FLANNTree(k=$(NN. k) )" )
32
+ function StatsBase. fit (:: Type{FLANNTree} , X:: AbstractMatrix{T} , k:: Integer ) where {T<: Real }
33
+ params = FLANNParameters ()
34
+ idx = FLANN. flann (X, params)
35
+ FLANNTree (k, idx)
36
+ end
37
+ function ManifoldLearning. knn (NN:: FLANNTree , X:: AbstractVecOrMat{T} ; self= false ) where {T<: Real }
38
+ E, D = FLANN. knn (NN. index, X, NN. k+ 1 )
39
+ sqrt .(@view D[2 : end , :]), @view E[2 : end , :]
40
+ end
0 commit comments