Skip to content

Commit f78f028

Browse files
authored
Simplify implementation of Delta (#378)
* Simplify implementation of `Delta` * Fix bug in `pairwise!` implementation * Update Project.toml
1 parent f5bcae4 commit f78f028

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.24"
3+
version = "0.10.25"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/distances/delta.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
# Delta is not following the PreMetric rules since d(x, x) == 1
22
struct Delta <: Distances.UnionPreMetric end
33

4-
@inline Distances.eval_op(::Delta, a::Real, b::Real) = a == b
5-
@inline Distances.eval_reduce(::Delta, a, b) = a && b
6-
@inline Distances.eval_start(::Delta, a, b) = true
7-
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
8-
@inline (dist::Delta)(a::Number, b::Number) = a == b
4+
(dist::Delta)(a::Number, b::Number) = a == b
5+
Base.@propagate_inbounds function (dist::Delta)(
6+
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
7+
)
8+
@boundscheck if length(a) != length(b)
9+
throw(
10+
DimensionMismatch(
11+
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
12+
),
13+
)
14+
end
15+
return a == b
16+
end
917

1018
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool

src/distances/pairwise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
88

99
function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
10-
return broadcast!(d, out, X, Y')
10+
return broadcast!(d, out, X, permutedims(Y))
1111
end
1212

1313
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)

0 commit comments

Comments
 (0)