Skip to content

Commit 39a491b

Browse files
committed
updated FLANN plug-in type
1 parent fcf8a90 commit 39a491b

File tree

1 file changed

+40
-7
lines changed

1 file changed

+40
-7
lines changed

misc/nearestneighbors.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@ struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
1313
tree::NearestNeighbors.KDTree
1414
end
1515
show(io::IO, NN::KDTree) = print(io, "KDTree")
16+
size(NN::KDTree) = (length(NN.fitted.data[1]), length(NN.fitted.data))
17+
1618
fit(::Type{KDTree}, X::AbstractMatrix{T}) where {T<:Real} =
1719
KDTree(X, NearestNeighbors.KDTree(X))
18-
size(NN::KDTree) = (length(NN.fitted.data[1]), length(NN.fitted.data))
20+
1921
function knn(NN::KDTree, X::AbstractVecOrMat{T}, k::Integer;
2022
self::Bool=false, weights::Bool=true, kwargs...) where {T<:Real}
2123
m, n = size(X)
2224
@assert n > k "Number of observations must be more then $(k)"
2325
A, D = NearestNeighbors.knn(NN.tree, X, k, true)
2426
return A, D
2527
end
28+
2629
function inradius(NN::KDTree, X::AbstractVecOrMat{T}, r::Real;
2730
weights::Bool=false, kwargs...) where {T<:Real}
2831
m, n = size(X)
@@ -42,20 +45,50 @@ function inradius(NN::KDTree, X::AbstractVecOrMat{T}, r::Real;
4245
return A, W
4346
end
4447

45-
#=
4648
# Wrapper around FLANN functionality
4749
using FLANN: FLANN
4850
struct FLANNTree{T <: Real} <: ManifoldLearning.AbstractNearestNeighbors
51+
d::Int
4952
index::FLANN.FLANNIndex{T}
5053
end
5154
show(io::IO, NN::FLANNTree) = print(io, "FLANNTree")
55+
size(NN::FLANNTree) = (NN.d, length(NN.index))
56+
5257
function fit(::Type{FLANNTree}, X::AbstractMatrix{T}) where {T<:Real}
5358
params = FLANN.FLANNParameters()
5459
idx = FLANN.flann(X, params)
55-
FLANNTree(idx)
60+
FLANNTree(size(X,1), idx)
61+
end
62+
63+
function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Integer;
64+
self::Bool=false, weights::Bool=false, kwargs...) where {T<:Real}
65+
m, n = size(X)
66+
E, D = FLANN.knn(NN.index, X, k+1)
67+
idxs = (1:k).+(!self)
68+
69+
A = Vector{Vector{Int}}(undef, n)
70+
W = Vector{Vector{T}}(undef, (weights ? n : 0))
71+
for (i,(es, ds)) in enumerate(zip(eachcol(E), eachcol(D)))
72+
A[i] = es[idxs]
73+
if weights
74+
W[i] = sqrt.(ds[idxs])
75+
end
76+
end
77+
return A, W
5678
end
57-
function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Int; self=false) where {T<:Real}
58-
E, D = FLANN.knn(NN.index, X, NN.k+1)
59-
sqrt.(@view D[2:end, :]), @view E[2:end, :]
79+
80+
function inradius(NN::FLANNTree, X::AbstractVecOrMat{T}, r::Real;
81+
weights::Bool=false, kwargs...) where {T<:Real}
82+
m, n = size(X)
83+
A = Vector{Vector{Int}}(undef, n)
84+
W = Vector{Vector{T}}(undef, (weights ? n : 0))
85+
for (i, x) in enumerate(eachcol(X))
86+
E, D = FLANN.inrange(NN.index, x, r)
87+
A[i] = E
88+
if weights
89+
W[i] = D
90+
end
91+
end
92+
return A, W
6093
end
61-
=#
94+

0 commit comments

Comments
 (0)