@@ -13,16 +13,19 @@ struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
    tree::NearestNeighbors.KDTree
end
show(io::IO, NN::KDTree) = print(io, "KDTree")
+ size(NN::KDTree) = (length(NN.tree.data[1]), length(NN.tree.data))
+
fit(::Type{KDTree}, X::AbstractMatrix{T}) where {T<:Real} =
    KDTree(X, NearestNeighbors.KDTree(X))
- size(NN::KDTree) = (length(NN.fitted.data[1]), length(NN.fitted.data))
+
function knn(NN::KDTree, X::AbstractVecOrMat{T}, k::Integer;
             self::Bool=false, weights::Bool=true, kwargs...) where {T<:Real}
    m, n = size(X)
    @assert n > k "Number of observations must be more than $(k)"
    A, D = NearestNeighbors.knn(NN.tree, X, k, true)
    return A, D
end
+
function inradius(NN::KDTree, X::AbstractVecOrMat{T}, r::Real;
                  weights::Bool=false, kwargs...) where {T<:Real}
    m, n = size(X)
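A quick way to exercise the NearestNeighbors-backed wrapper above, as a minimal sketch: it assumes this file's `fit`, `knn` and `inradius` methods extend the generics that ManifoldLearning dispatches on, and it uses random 3-D data.

    X = rand(3, 100)             # 100 points in 3 dimensions, one per column
    NN = fit(KDTree, X)          # store the data and build a NearestNeighbors.KDTree
    A, D = knn(NN, X, 5)         # per-point neighbor indices and distances, 5 each
    R, W = inradius(NN, X, 0.5)  # neighbor indices within radius 0.5 (W empty unless weights=true)
    @assert length(A) == 100 && length(A[1]) == 5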
@@ -42,20 +45,50 @@ function inradius(NN::KDTree, X::AbstractVecOrMat{T}, r::Real;
    return A, W
end

- #=
# Wrapper around FLANN functionality
using FLANN: FLANN
struct FLANNTree{T <: Real} <: ManifoldLearning.AbstractNearestNeighbors
+     d::Int
    index::FLANN.FLANNIndex{T}
end
show(io::IO, NN::FLANNTree) = print(io, "FLANNTree")
+ size(NN::FLANNTree) = (NN.d, length(NN.index))
+
function fit(::Type{FLANNTree}, X::AbstractMatrix{T}) where {T<:Real}
    params = FLANN.FLANNParameters()
    idx = FLANN.flann(X, params)
-     FLANNTree(idx)
+     FLANNTree(size(X,1), idx)
+ end
+
+ function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Integer;
+              self::Bool=false, weights::Bool=false, kwargs...) where {T<:Real}
+     m, n = size(X)
+     E, D = FLANN.knn(NN.index, X, k+1)
+     idxs = (1:k) .+ (!self)
+
+     A = Vector{Vector{Int}}(undef, n)
+     W = Vector{Vector{T}}(undef, (weights ? n : 0))
+     for (i, (es, ds)) in enumerate(zip(eachcol(E), eachcol(D)))
+         A[i] = es[idxs]
+         if weights
+             W[i] = sqrt.(ds[idxs])
+         end
+     end
+     return A, W
end
- function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Int; self=false) where {T<:Real}
-     E, D = FLANN.knn(NN.index, X, NN.k+1)
-     sqrt.(@view D[2:end, :]), @view E[2:end, :]
+
+ function inradius(NN::FLANNTree, X::AbstractVecOrMat{T}, r::Real;
+                   weights::Bool=false, kwargs...) where {T<:Real}
+     m, n = size(X)
+     A = Vector{Vector{Int}}(undef, n)
+     W = Vector{Vector{T}}(undef, (weights ? n : 0))
+     for (i, x) in enumerate(eachcol(X))
+         E, D = FLANN.inrange(NN.index, x, r)
+         A[i] = E
+         if weights
+             W[i] = D
+         end
+     end
+     return A, W
end
- =#
+
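These wrappers are meant to be handed to ManifoldLearning solvers as drop-in nearest-neighbor backends. A minimal end-to-end sketch, assuming ManifoldLearning's `fit` takes an `nntype` keyword to select the `AbstractNearestNeighbors` implementation and that `predict` returns the embedding, with the package's swiss roll demo data:

    using ManifoldLearning
    X, _ = ManifoldLearning.swiss_roll()           # swiss roll demo data, points in columns
    M = fit(Isomap, X; k=12, nntype=KDTree)        # neighbor search via the KD-tree wrapper
    # M = fit(Isomap, X; k=12, nntype=FLANNTree)   # or the FLANN-backed wrapper
    Y = predict(M)                                 # low-dimensional embedding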