Skip to content

Commit af839e3

Browse files
committed
update deps & fix NN plug-ins
1 parent 3f7b455 commit af839e3

File tree

4 files changed

+73
-35
lines changed

4 files changed

+73
-35
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ keywords = ["manifold learning", "dimensionality reduction", "nonlinear"]
44
license = "MIT"
55
desc = "A Julia package for nonlinear dimensionality reduction"
66
repository = "https://github.com/JuliaStats/ManifoldLearning.jl.git"
7-
version = "0.6.1"
7+
version = "0.6.2"
88

99
[deps]
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -19,7 +19,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1919
[compat]
2020
Combinatorics = "^1"
2121
LightGraphs = "^1.3"
22-
MultivariateStats = "~0.7"
22+
MultivariateStats = "^0.8"
2323
SimpleWeightedGraphs = "^1.1"
2424
StatsBase = "~0.33"
2525
julia = "1"

misc/knn-test.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using ManifoldLearning
2+
3+
include("nearestneighbors.jl")
4+
5+
k=13
6+
X, L = ManifoldLearning.swiss_roll()
7+
8+
# Use default distance matrix based method to find nearest neighbors
9+
M1 = fit(Isomap, X)
10+
Y1 = transform(M1)
11+
12+
# Use NearestNeighbors package to find nearest neighbors
13+
M2 = fit(Isomap, X, nntype=KDTree)
14+
Y2 = transform(M2)
15+
16+
# Use FLANN package to find nearest neighbors
17+
M3 = fit(Isomap, X, nntype=FLANNTree)
18+
Y3 = transform(M3)
19+
20+
using Plots
21+
plot(
22+
plot(X[1,:], X[2,:], X[3,:], zcolor=L, m=2, t=:scatter3d, leg=false, title="Swiss Roll"),
23+
plot(Y1[1,:], Y1[2,:], c=L, m=2, t=:scatter, title="Distance Matrix"),
24+
plot(Y2[1,:], Y2[2,:], c=L, m=2, t=:scatter, title="NearestNeighbors"),
25+
plot(Y3[1,:], Y3[2,:], c=L, m=2, t=:scatter, title="FLANN")
26+
, leg=false)

misc/nearestneighbors.jl

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
using Plots
1+
using ManifoldLearning
2+
import StatsBase
23

3-
import FLANN
4-
function knn_flann(X::AbstractMatrix{T}, k::Int=12) where T<:Real
5-
params = FLANN.FLANNParameters()
6-
E, D = FLANN.knn(X, X, k+1, params)
7-
sqrt.(@view D[2:end, :]), @view E[2:end, :]
4+
import NearestNeighbors
5+
struct KDTree <: ManifoldLearning.AbstractNearestNeighbors
6+
k::Integer
7+
fitted::NearestNeighbors.KDTree
88
end
9+
Base.show(io::IO, NN::KDTree) = print(io, "KDTree(k=$(NN.k))")
10+
StatsBase.fit(::Type{KDTree}, X::AbstractMatrix{T}, k::Integer) where {T<:Real} = KDTree(k, NearestNeighbors.KDTree(X))
11+
function ManifoldLearning.knn(NN::KDTree, X::AbstractVecOrMat{T}; self=false) where {T<:Real}
12+
m, n = size(X)
13+
k = NN.k
14+
@assert n > k "Number of observations must be more than $(k)"
915

10-
import NearestNeighbors
11-
function knn_nearestneighbors(X::AbstractMatrix{T}, k::Int=12) where T<:Real
12-
n = size(X,2)
13-
kdtree = NearestNeighbors.KDTree(X)
14-
idxs, dist = NearestNeighbors.knn(kdtree, X, k+1, true)
16+
idxs, dist = NearestNeighbors.knn(NN.fitted, X, k+1, true)
1517
D = Array{T}(undef, k, n)
1618
E = Array{Int32}(undef, k, n)
1719
for i in eachindex(idxs)
@@ -21,25 +23,18 @@ function knn_nearestneighbors(X::AbstractMatrix{T}, k::Int=12) where T<:Real
2123
return D, E
2224
end
2325

24-
using ManifoldLearning
25-
k=13
26-
X, L = ManifoldLearning.swiss_roll()
27-
28-
# Use default distance matrix based method to find nearest neighbors
29-
M1 = fit(Isomap, X)
30-
Y1 = transform(M1)
31-
32-
# Use NearestNeighbors package to find nearest neighbors
33-
M2 = fit(Isomap, X, knn=knn_nearestneighbors)
34-
Y2 = transform(M2)
35-
36-
# Use FLANN package to find nearest neighbors
37-
M3 = fit(Isomap, X, knn=knn_flann)
38-
Y3 = transform(M3)
39-
40-
plot(
41-
plot(X[1,:], X[2,:], X[3,:], zcolor=L, m=2, t=:scatter3d, leg=false, title="Swiss Roll"),
42-
plot(Y1[1,:], Y1[2,:], c=L, m=2, t=:scatter, title="Distance Matrix"),
43-
plot(Y2[1,:], Y2[2,:], c=L, m=2, t=:scatter, title="NearestNeighbors"),
44-
plot(Y3[1,:], Y3[2,:], c=L, m=2, t=:scatter, title="FLANN")
45-
, leg=false)
26+
import FLANN
27+
struct FLANNTree{T <: Real} <: ManifoldLearning.AbstractNearestNeighbors
28+
k::Integer
29+
index::FLANN.FLANNIndex{T}
30+
end
31+
Base.show(io::IO, NN::FLANNTree) = print(io, "FLANNTree(k=$(NN.k))")
32+
function StatsBase.fit(::Type{FLANNTree}, X::AbstractMatrix{T}, k::Integer) where {T<:Real}
33+
params = FLANN.FLANNParameters()
34+
idx = FLANN.flann(X, params)
35+
FLANNTree(k, idx)
36+
end
37+
function ManifoldLearning.knn(NN::FLANNTree, X::AbstractVecOrMat{T}; self=false) where {T<:Real}
38+
E, D = FLANN.knn(NN.index, X, NN.k+1)
39+
sqrt.(@view D[2:end, :]), @view E[2:end, :]
40+
end

src/nearestneighbors.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
1+
"""
2+
AbstractNearestNeighbors
3+
4+
Abstract type for nearest neighbor plug-in implementations.
5+
"""
16
abstract type AbstractNearestNeighbors end
27

8+
"""
9+
knn(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}; kwargs...) -> (D,E)
10+
11+
Perform construction of the distance matrix `D` and nearest neighbor weighted graph `E` from the `NN` object.
12+
"""
13+
function knn(NN::AbstractNearestNeighbors, X::AbstractVecOrMat{T}; kwargs...) where T<:Real end
14+
15+
"""
16+
BruteForce
17+
18+
Calculate NN using pairwise distance matrix.
19+
"""
320
struct BruteForce{T<:Real} <: AbstractNearestNeighbors
421
k::Integer
522
fitted::AbstractMatrix{T}

0 commit comments

Comments
 (0)