Skip to content

Commit 9e23809

Browse files
authored
Add specialized pairwise methods to *msd (#232)
1 parent c63dc14 commit 9e23809

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Distances"
22
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3-
version = "0.10.5"
3+
version = "0.10.6"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/metrics.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -731,10 +731,50 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
731731
r
732732
end
733733

734+
# MeanSqDeviation, RMSDeviation, NormRMSDeviation
735+
function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix, b::AbstractMatrix)
736+
_pairwise!(r, SqEuclidean(), a, b)
737+
# TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2
738+
s = size(a, 1)
739+
@simd for I in eachindex(r)
740+
@inbounds r[I] /= s
741+
end
742+
return r
743+
end
744+
_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix, b::AbstractMatrix) =
745+
sqrt!(_pairwise!(r, MeanSqDeviation(), a, b))
746+
function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix, b::AbstractMatrix)
747+
_pairwise!(r, RMSDeviation(), a, b)
748+
@views for (i, j) in zip(axes(r, 1), axes(a, 2))
749+
amin, amax = extrema(a[:,j])
750+
r[i,:] ./= amax - amin
751+
end
752+
return r
753+
end
754+
755+
function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix)
756+
_pairwise!(r, SqEuclidean(), a)
757+
# TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2
758+
s = size(a, 1)
759+
@simd for I in eachindex(r)
760+
@inbounds r[I] /= s
761+
end
762+
return r
763+
end
764+
_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix) =
765+
sqrt!(_pairwise!(r, MeanSqDeviation(), a))
766+
function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix)
767+
_pairwise!(r, RMSDeviation(), a)
768+
@views for (i, j) in zip(axes(r, 1), axes(a, 2))
769+
amin, amax = extrema(a[:,j])
770+
r[i,:] ./= amax - amin
771+
end
772+
return r
773+
end
774+
734775
# CosineDist
735776

736-
function _pairwise!(r::AbstractMatrix, ::CosineDist,
737-
a::AbstractMatrix, b::AbstractMatrix)
777+
function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix, b::AbstractMatrix)
738778
require_one_based_indexing(r, a, b)
739779
m, na, nb = get_pairwise_dims(r, a, b)
740780
inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(b)))) === eltype(r)
@@ -772,10 +812,7 @@ end
772812
# 2. pre-calculated `_centralize_colwise` avoids four times of redundant computations
773813
# of `_centralize` -- ~4x speed up
774814
_centralize_colwise(x::AbstractMatrix) = x .- mean(x, dims=1)
775-
function _pairwise!(r::AbstractMatrix, ::CorrDist,
776-
a::AbstractMatrix, b::AbstractMatrix)
815+
_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix, b::AbstractMatrix) =
777816
_pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
778-
end
779-
function _pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix)
817+
_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix) =
780818
_pairwise!(r, CosineDist(), _centralize_colwise(a))
781-
end

test/test_dists.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,12 @@ end
686686
test_pairwise(cityblock, X, Y, T)
687687
test_pairwise(TotalVariation(), X, Y, T)
688688
test_pairwise(totalvariation, X, Y, T)
689+
test_pairwise(MeanSqDeviation(), X, Y, T)
690+
test_pairwise(msd, X, Y, T)
691+
test_pairwise(RMSDeviation(), X, Y, T)
692+
test_pairwise(rmsd, X, Y, T)
693+
test_pairwise(NormRMSDeviation(), X, Y, T)
694+
test_pairwise(nrmsd, X, Y, T)
689695
test_pairwise(Chebyshev(), X, Y, T)
690696
test_pairwise(chebyshev, X, Y, T)
691697
test_pairwise(Minkowski(2.5), X, Y, T)

0 commit comments

Comments
 (0)