diff --git a/src/entities/KernelEval.jl b/src/entities/KernelEval.jl index d9587c9..db9e9fe 100644 --- a/src/entities/KernelEval.jl +++ b/src/entities/KernelEval.jl @@ -9,7 +9,7 @@ abstract type AbstractKernel end p::MvNormal{T,M} # TDB might already be covered in p.Σ.chol but having issues with SymPD (not particular to this AMP repo) """ Manually maintained square root concentration matrix for faster compute, TODO likely duplicate of existing Distrubtions.jl functionality. """ - sqrt_iΣ::iM = sqrt(inv(p.Σ)) + sqrt_iΣ::iM = sqrt(inv(cov(p))) """ Nonparametric weight value """ weight::Float64 = 1.0 end diff --git a/src/services/KernelEval.jl b/src/services/KernelEval.jl index eaecc3e..69ffdc2 100644 --- a/src/services/KernelEval.jl +++ b/src/services/KernelEval.jl @@ -24,12 +24,22 @@ Statistics.mean(m::MvNormalKernel) = m.μ # mean(m.p) # m.p.μ Statistics.cov(m::MvNormalKernel) = cov(m.p) # note also about m.sqrt_iΣ Statistics.std(m::MvNormalKernel) = sqrt(cov(m)) -updateKernelBW(k::MvNormalKernel,_bw) = (p=MvNormal(_bw); MvNormalKernel(;μ=k.μ,p,weight=k.weight)) +function updateKernelBW( + k::MvNormalKernel, + _bw, + isq_bw = inv(sqrt(_bw)) +) + p=MvNormal(_bw) + sqrt_iΣ = typeof(k.sqrt_iΣ)(isq_bw) + return MvNormalKernel(;μ=k.μ,p,sqrt_iΣ,weight=k.weight) +end +updateKernelBW(ekr::MvNormalKernel, ::Nothing) = ekr # avoid ifs for noops + function evaluate( M::AbstractManifold, ekr::MvNormalKernel, - p # on manifold point + p, # on manifold point ) # dim = manifold_dimension(M) @@ -98,10 +108,7 @@ function distanceMalahanobisSq( basis=DefaultOrthogonalBasis() ) δc = distanceMalahanobisCoordinates(M,K,q,basis) - # p = mean(K) - # ϵ = identity_element(M, q) - # X = get_vector(M, ϵ, δc, basis) - # return inner(M, p, X, X) + # return inner(M, p, X, X) # did not work as inner gave almost 2x the answer? return δc'*δc end @@ -114,7 +121,6 @@ function _distance( p=MvNormal(_p,SVector(ntuple((s)->1,manifold_dimension(M))...)) ), distFnc::Function=distanceMalahanobisSq, - # distFnc::Function=distanceMalahanobisSq, ) distFnc(M, kernel(p), q) end diff --git a/src/services/ManellicTree.jl b/src/services/ManellicTree.jl index b54245a..528dd4f 100644 --- a/src/services/ManellicTree.jl +++ b/src/services/ManellicTree.jl @@ -199,7 +199,7 @@ function Base.show( printstyled(io,"::TK ";color=:magenta) println(io) end - printstyled(io, " (depth) : ", floor(Int,log2(length(mt.tree_kernels))),"+1"; color=:light_black) + printstyled(io, " (depth) : 1+", floor(Int,log2(length(mt.tree_kernels))); color=:light_black) println(io) printstyled(io, " (blncd) : ", "true : _wip_";color=:light_black) println(io) @@ -600,6 +600,7 @@ function evaluate( mt::ManellicTree{M,D,N,HL}, pt, LOO::Bool = false, + force_kbw = nothing ) where {M,D,N,HL} # # force function barrier, just to be sure dyndispatch is limited # _F() = getfield(ApproxManifoldProducts,HL.name.name) @@ -616,8 +617,10 @@ function evaluate( # FIXME, is this assuming length(pts) and length(mt.leaf_kernels) are the same? # FIXME use consolidated getKernelLeaf instead ekr = mt.leaf_kernels[i] + ekr = updateKernelBW(ekr, force_kbw) # TODO remember special handling for partials in the future oneval = mt.weights[i] * evaluate(mt.manifold, ekr, pt) + # leave one out requires kernel weighting to removal of leave out weight oneval *= !LOO ? 1 : 1/(1-w[i]) sumval += oneval end @@ -645,6 +648,7 @@ function evaluateDensityAtPoints( ) # evaluate new sampling weights of points in out component + # TODO use agnostic-Dual tree or MonteCarloDualTree evaluation # vector for storing resulting weights smw = zeros(length(eval_at_points)) for (i,ev) in enumerate(eval_at_points) @@ -666,14 +670,15 @@ end function expectedLogL( mt::ManellicTree{M,D,N}, epts::AbstractVector, - LOO::Bool = false + LOO::Bool = false, + force_kbw = nothing ) where {M,D,N} T = Float64 - # TODO really slow brute force evaluation + # TODO really slow brute force evaluation, use agnostic-DualTree or MonteCarloDualTree eL = MVector{length(epts),T}(undef) for (i,p) in enumerate(epts) # LOO skip for leave-one-out - eL[i] = evaluate(mt, p, LOO) + eL[i] = evaluate(mt, p, LOO, force_kbw) end # set numerical tolerance floor zrs = findall(isapprox.(0,eL)) @@ -693,7 +698,8 @@ end entropy( mt::ManellicTree, -) = -expectedLogL(mt, getPoints(mt), true) + force_kbw = nothing +) = -expectedLogL(mt, getPoints(mt), true, force_kbw) (mt::ManellicTree)( diff --git a/src/services/ManifoldKernelDensity.jl b/src/services/ManifoldKernelDensity.jl index 42d47dc..998f7c2 100644 --- a/src/services/ManifoldKernelDensity.jl +++ b/src/services/ManifoldKernelDensity.jl @@ -120,9 +120,9 @@ function manikde!_manellic( # Cost function to optimize _cost(_pts, σ) = begin - # FIXME avoid rebuilding tree at each optim iteration!!! - mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape(σ,manifold_dimension(M),1), kernel=MvNormalKernel) - entropy(mtr) + # avoid rebuilding tree at each optim iteration!!! + # mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape(σ,manifold_dimension(M),1), kernel=MvNormalKernel) + entropy(mtree,reshape(σ,manifold_dimension(M),1)) end # optimize for best LOOCV bandwidth @@ -133,7 +133,7 @@ function manikde!_manellic( (s)->_cost(pts,[s^2;]), lcov[1], ucov[1], Optim.GoldenSection() ) - best_cov = [Optim.minimizer(res);] + best_cov = [Optim.minimizer(res);;] # reuse (heavy lift parts of) earlier tree build # return tree with correct bandwidth diff --git a/test/manellic/testManellicTree.jl b/test/manellic/testManellicTree.jl index bcc2328..93d3506 100755 --- a/test/manellic/testManellicTree.jl +++ b/test/manellic/testManellicTree.jl @@ -629,6 +629,22 @@ AMP.expectedLogL(mtree, pts) @test AMP.expectedLogL(mtree, pts) < Inf +# to enable faster bandwidth selection/optimization +ekr = ApproxManifoldProducts.getKernelLeaf(mtree,1,false) +ekr_ = ApproxManifoldProducts.updateKernelBW(ekr,SA[1.0;;]) + +@test typeof(ekr) == typeof(ekr_) + +# confirm that updating the bandwidths works properly +Σ = [0.1+0.5*rand();;] + +mtr = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=Σ,kernel=AMP.MvNormalKernel) +mtr_ = ApproxManifoldProducts.updateBandwidths(mtree, Σ) + +# +@test isapprox( mtr([0.0]), mtr_([0.0]); atol=1e-10) +@test isapprox( ApproxManifoldProducts.entropy(mtr), ApproxManifoldProducts.entropy(mtr_); atol=1e-10) + ## end @@ -711,8 +727,7 @@ Y = S .|> s->cost(pts,s^2) # should pass the optimal kbw somewhere in the given range @test any(0 .< diff(Y)) -# and optimize - +# and optimize with rebuild tree cost res = Optim.optimize( (s)->cost(pts,s^2), 0.05, 3.0, Optim.GoldenSection() @@ -720,6 +735,38 @@ res = Optim.optimize( best_cov = Optim.minimizer(res) @test isapprox(0.5, best_cov; atol=0.3) +bcov_ = deepcopy(best_cov) + +## Test more efficient updateKernelBW version + +cost2(σ) = begin + mtr = ApproxManifoldProducts.updateBandwidths(mtree_0, [σ;;]) + AMP.entropy(mtr) +end + +# and optimize with "update" kernel bandwith cost +res = Optim.optimize( + (s)->cost2(s^2), + 0.05, 3.0, Optim.GoldenSection() +) +@show best_cov = Optim.minimizer(res) + +@test isapprox(bcov_, best_cov; atol=1e-3) + +# mask bandwith by passing in an alternative + +cost3(σ) = begin + AMP.entropy(mtree_0, [σ;;]) +end + +# and optimize with "update" kernel bandwith cost +res = Optim.optimize( + (s)->cost3(s^2), + 0.05, 3.0, Optim.GoldenSection() +) +@show best_cov = Optim.minimizer(res) + +@test isapprox(bcov_, best_cov; atol=1e-3) ## @@ -734,10 +781,25 @@ end M = TranslationGroup(1) # pts = [[0.;],[0.1],[0.2;],[0.3;]] -pts = [1*randn(1) for _ in 1:64] +pts = [1*randn(1) for _ in 1:128] mkd = ApproxManifoldProducts.manikde!_manellic(M,pts) +best_cov = cov(ApproxManifoldProducts.getKernelLeaf(mkd.belief,1))[1] |> sqrt +@show best_cov + +@test isapprox(0.5, best_cov; atol=0.3) + +# remember broken code in get w bounds + +try + pts = [1*randn(1) for _ in 1:100] + mkd = ApproxManifoldProducts.manikde!_manellic(M,pts) +catch + @test_broken false +end + + ## end