Skip to content

Commit fcf8a90

Browse files
committed
fixed docs & KDTree plug-in
1 parent b56b091 commit fcf8a90

File tree

3 files changed

+37
-22
lines changed

3 files changed

+37
-22
lines changed

docs/src/interface.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,16 @@ ManifoldLearning.AbstractNearestNeighbors
4646
The above interface requires implementation of the following methods:
4747

4848
```@docs
49-
ManifoldLearning.knn(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
50-
ManifoldLearning.inradius(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, r::Real) where T<:Real
49+
ManifoldLearning.knn(NN::ManifoldLearning.AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
50+
ManifoldLearning.inradius(NN::ManifoldLearning.AbstractNearestNeighbors, X::AbstractVecOrMat{T}, r::Real) where T<:Real
5151
```
5252

5353
Following auxiliary methods available for any implementation of
5454
`AbstractNearestNeighbors`-derived type:
5555

5656
```@docs
57-
ManifoldLearning.adjacency_list(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
58-
ManifoldLearning.adjacency_list(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, r::Real) where T<:Real
59-
ManifoldLearning.adjacency_matrix(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
60-
ManifoldLearning.adjacency_matrix(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}, r::Real) where T<:Real
57+
ManifoldLearning.adjacency_list(NN::ManifoldLearning.AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
58+
ManifoldLearning.adjacency_matrix(NN::ManifoldLearning.AbstractNearestNeighbors, X::AbstractVecOrMat{T}, k::Integer) where T<:Real
6159
```
6260

6361
The default implementation uses inefficient ``O(n^2)`` algorithm for nearest

misc/nearestneighbors.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,48 @@
1-
# Additional wrappers for performant calculations of nearest neighbors
1+
# Additional wrappers for calculations of nearest neighbors
22

33
using ManifoldLearning
4-
import Base: show
4+
using LinearAlgebra: norm
5+
import Base: show, size
56
import StatsAPI: fit
6-
import ManifoldLearning: knn
7+
import ManifoldLearning: knn, inradius
78

89
# Wrapper around NearestNeighbors functionality
910
using NearestNeighbors: NearestNeighbors
1011
struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
11-
fitted::NearestNeighbors.KDTree
12+
fitted::AbstractMatrix
13+
tree::NearestNeighbors.KDTree
1214
end
1315
show(io::IO, NN::KDTree) = print(io, "KDTree")
14-
fit(::Type{KDTree}, X::AbstractMatrix{T}) where {T<:Real} = KDTree(NearestNeighbors.KDTree(X))
15-
function knn(NN::KDTree, X::AbstractVecOrMat{T}, k::Int; self=false) where {T<:Real}
16+
fit(::Type{KDTree}, X::AbstractMatrix{T}) where {T<:Real} =
17+
KDTree(X, NearestNeighbors.KDTree(X))
18+
size(NN::KDTree) = (length(NN.fitted.data[1]), length(NN.fitted.data))
19+
function knn(NN::KDTree, X::AbstractVecOrMat{T}, k::Integer;
20+
self::Bool=false, weights::Bool=true, kwargs...) where {T<:Real}
1621
m, n = size(X)
1722
@assert n > k "Number of observations must be more then $(k)"
18-
19-
idxs, dist = NearestNeighbors.knn(NN.fitted, X, k+1, true)
20-
D = Array{T}(undef, k, n)
21-
E = Array{Int32}(undef, k, n)
22-
for i in eachindex(idxs)
23-
E[:, i] = idxs[i][2:end]
24-
D[:, i] = dist[i][2:end]
23+
A, D = NearestNeighbors.knn(NN.tree, X, k, true)
24+
return A, D
25+
end
26+
function inradius(NN::KDTree, X::AbstractVecOrMat{T}, r::Real;
27+
weights::Bool=false, kwargs...) where {T<:Real}
28+
m, n = size(X)
29+
A = NearestNeighbors.inrange(NN.tree, X, r)
30+
W = Vector{Vector{T}}(undef, (weights ? n : 0))
31+
if weights
32+
for (i, ii) in enumerate(A)
33+
W[i] = T[]
34+
if length(ii) > 0
35+
for v in eachcol(NN.fitted[:, ii])
36+
d = norm(X[:,i] - v)
37+
push!(W[i], d)
38+
end
39+
end
40+
end
2541
end
26-
return D, E
42+
return A, W
2743
end
2844

45+
#=
2946
# Wrapper around FLANN functionality
3047
using FLANN: FLANN
3148
struct FLANNTree{T <: Real} <: ManifoldLearning.AbstractNearestNeighbors
@@ -41,4 +58,4 @@ function knn(NN::FLANNTree, X::AbstractVecOrMat{T}, k::Int; self=false) where {T
4158
E, D = FLANN.knn(NN.index, X, NN.k+1)
4259
sqrt.(@view D[2:end, :]), @view E[2:end, :]
4360
end
44-
61+
=#

src/nearestneighbors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function knn(NN::BruteForce{T}, X::AbstractVecOrMat{T}, k::Integer;
112112
end
113113

114114
function inradius(NN::BruteForce{T}, X::AbstractVecOrMat{T}, r::Real;
115-
self::Bool=false, weights::Bool=false, kwargs...) where T<:Real
115+
self::Bool=false, weights::Bool=false, kwargs...) where T<:Real
116116
# construct distance matrix
117117
D = pairwise((x,y)->norm(x-y), eachcol(NN.fitted), eachcol(X))
118118

0 commit comments

Comments
 (0)