# Additional wrappers for calculations of nearest neighbors
using ManifoldLearning
4
- import Base: show
4
+ using LinearAlgebra: norm
5
+ import Base: show, size
5
6
import StatsAPI: fit
6
- import ManifoldLearning: knn
7
+ import ManifoldLearning: knn, inradius
7
8
8
9
# Wrapper around NearestNeighbors functionality
9
10
using NearestNeighbors: NearestNeighbors
10
11
"""
Nearest-neighbors index backed by `NearestNeighbors.KDTree`.

Implements the `ManifoldLearning.AbstractNearestNeighbors` interface.
The raw data matrix is kept alongside the fitted tree so that exact
distances can be recomputed from the original observations (see
`inradius`). Columns of `fitted` are the observations — presumably
d×n with points as columns, matching `NearestNeighbors` conventions.
"""
struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
    fitted::AbstractMatrix       # original data matrix (observations in columns)
    tree::NearestNeighbors.KDTree  # spatial index built over `fitted`
end
13
15
# Compact display for the wrapper; the tree contents are not printed.
function show(io::IO, NN::KDTree)
    print(io, "KDTree")
end
14
"""
    fit(::Type{KDTree}, X::AbstractMatrix)

Build a `KDTree` wrapper over the real-valued data matrix `X`,
retaining `X` itself together with the `NearestNeighbors.KDTree`
index constructed from it.
"""
function fit(::Type{KDTree}, X::AbstractMatrix{T}) where {T<:Real}
    return KDTree(X, NearestNeighbors.KDTree(X))
end
# Dimensions of the fitted data as (number of features, number of observations).
#
# Fix: `NN.fitted` is the raw `AbstractMatrix` (see the struct), which has no
# `.data` field — `NN.fitted.data` belonged to the wrapped
# `NearestNeighbors.KDTree` layout and would throw here. `size` on the matrix
# returns the same (d, n) tuple directly.
size(NN::KDTree) = size(NN.fitted)
"""
    knn(NN::KDTree, X, k; self=false, weights=true, kwargs...)

Find the `k` nearest neighbors of every column of `X` in the fitted tree.

Returns a pair `(A, D)` where `A[i]` holds the neighbor indices of point
`i` and `D[i]` the corresponding distances, sorted nearest-first.

The `self` and `weights` keywords are accepted for interface
compatibility with `ManifoldLearning` and are currently ignored here.

Throws `ArgumentError` if `X` does not contain more than `k` observations.
"""
function knn(NN::KDTree, X::AbstractVecOrMat{T}, k::Integer;
             self::Bool=false, weights::Bool=true, kwargs...) where {T<:Real}
    m, n = size(X)
    # Explicit validation: `@assert` may be compiled out at higher
    # optimization levels, so raise a real error instead.
    n > k || throw(ArgumentError("Number of observations ($n) must be more than $k"))
    # `true` asks NearestNeighbors to return results sorted by distance.
    A, D = NearestNeighbors.knn(NN.tree, X, k, true)
    return A, D
end
26
+ function inradius (NN:: KDTree , X:: AbstractVecOrMat{T} , r:: Real ;
27
+ weights:: Bool = false , kwargs... ) where {T<: Real }
28
+ m, n = size (X)
29
+ A = NearestNeighbors. inrange (NN. tree, X, r)
30
+ W = Vector {Vector{T}} (undef, (weights ? n : 0 ))
31
+ if weights
32
+ for (i, ii) in enumerate (A)
33
+ W[i] = T[]
34
+ if length (ii) > 0
35
+ for v in eachcol (NN. fitted[:, ii])
36
+ d = norm (X[:,i] - v)
37
+ push! (W[i], d)
38
+ end
39
+ end
40
+ end
25
41
end
26
- return D, E
42
+ return A, W
27
43
end
28
44
45
+ #=
29
46
# Wrapper around FLANN functionality
30
47
using FLANN: FLANN
31
48
struct FLANNTree{T <: Real} <: ManifoldLearning.AbstractNearestNeighbors
@@ -41,4 +58,4 @@ function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Int; self=false) where {T
41
58
E, D = FLANN.knn(NN.index, X, NN.k+1)
42
59
sqrt.(@view D[2:end, :]), @view E[2:end, :]
43
60
end
44
-
61
+ =#
0 commit comments