1- using  Plots
1+ using  ManifoldLearning
2+ import  StatsBase
23
3- import  FLANN
4- function  knn_flann (X:: AbstractMatrix{T} , k:: Int = 12 ) where  T<: Real 
5-     params =  FLANN. FLANNParameters ()
6-     E, D =  FLANN. knn (X, X, k+ 1 , params)
7-     sqrt .(@view  D[2 : end , :]), @view  E[2 : end , :]
4+ import  NearestNeighbors
5+ struct  KDTree <:  ManifoldLearning.AbstractNearestNeighbors 
6+     k:: Integer 
7+     fitted:: NearestNeighbors.KDTree 
88end 
9+ Base. show (io:: IO , NN:: KDTree ) =  print (io, " KDTree(k=$(NN. k) )"  )
10+ StatsBase. fit (:: Type{KDTree} , X:: AbstractMatrix{T} , k:: Integer ) where  {T<: Real } =  KDTree (k, NearestNeighbors. KDTree (X))
11+ function  ManifoldLearning. knn (NN:: KDTree , X:: AbstractVecOrMat{T} ; self= false ) where  {T<: Real }
12+     m, n =  size (X)
13+     k =  NN. k
14+     @assert  n >  k " Number of observations must be more then $(k) " 
915
10- import  NearestNeighbors
11- function  knn_nearestneighbors (X:: AbstractMatrix{T} , k:: Int = 12 ) where  T<: Real 
12-     n =  size (X,2 )
13-     kdtree =  NearestNeighbors. KDTree (X)
14-     idxs, dist =  NearestNeighbors. knn (kdtree, X, k+ 1 , true )
16+     idxs, dist =  NearestNeighbors. knn (NN. fitted, X, k+ 1 , true )
1517    D =  Array {T} (undef, k, n)
1618    E =  Array {Int32} (undef, k, n)
1719    for  i in  eachindex (idxs)
@@ -21,25 +23,18 @@ function knn_nearestneighbors(X::AbstractMatrix{T}, k::Int=12) where T<:Real
2123    return  D, E
2224end 
2325
24- using  ManifoldLearning
25- k= 13 
26- X, L =  ManifoldLearning. swiss_roll ()
27- 
28- #  Use default distance matrix based method to find nearest neighbors
29- M1 =  fit (Isomap, X)
30- Y1 =  transform (M1)
31- 
32- #  Use NearestNeighbors package to find nearest neighbors
33- M2 =  fit (Isomap, X, knn= knn_nearestneighbors)
34- Y2 =  transform (M2)
35- 
36- #  Use FLANN package to find nearest neighbors
37- M3 =  fit (Isomap, X, knn= knn_flann)
38- Y3 =  transform (M3)
39- 
40- plot (
41-     plot (X[1 ,:], X[2 ,:], X[3 ,:], zcolor= L, m= 2 , t= :scatter3d , leg= false , title= " Swiss Roll"  ),
42-     plot (Y1[1 ,:], Y1[2 ,:], c= L, m= 2 , t= :scatter , title= " Distance Matrix"  ),
43-     plot (Y2[1 ,:], Y2[2 ,:], c= L, m= 2 , t= :scatter , title= " NearestNeighbors"  ),
44-     plot (Y3[1 ,:], Y3[2 ,:], c= L, m= 2 , t= :scatter , title= " FLANN"  )
45- , leg= false )
26+ import  FLANN
27+ struct  FLANNTree{T <:  Real } <:  ManifoldLearning.AbstractNearestNeighbors 
28+     k:: Integer 
29+     index:: FLANN.FLANNIndex{T} 
30+ end 
31+ Base. show (io:: IO , NN:: FLANNTree ) =  print (io, " FLANNTree(k=$(NN. k) )"  )
32+ function  StatsBase. fit (:: Type{FLANNTree} , X:: AbstractMatrix{T} , k:: Integer ) where  {T<: Real }
33+     params =  FLANNParameters ()
34+     idx =  FLANN. flann (X, params)
35+     FLANNTree (k, idx)
36+ end 
37+ function  ManifoldLearning. knn (NN:: FLANNTree , X:: AbstractVecOrMat{T} ; self= false ) where  {T<: Real }
38+     E, D =  FLANN. knn (NN. index, X, NN. k+ 1 )
39+     sqrt .(@view  D[2 : end , :]), @view  E[2 : end , :]
40+ end 
0 commit comments