Skip to content

Commit ffbaa5f

Browse files
authored
Use RealDot.realdot (#542)
* Use `RealDot.realdot` * More `realdot`
1 parent c2ab41b commit ffbaa5f

File tree

10 files changed

+18
-22
lines changed

10 files changed

+18
-22
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.11.6"
3+
version = "1.12.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
10+
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112

1213
[compat]
@@ -15,6 +16,7 @@ ChainRulesTestUtils = "1"
1516
Compat = "3.35"
1617
FiniteDifferences = "0.12.8"
1718
JuliaInterpreter = "0.8"
19+
RealDot = "0.1"
1820
StaticArrays = "1.2"
1921
julia = "1"
2022

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Compat
66
using LinearAlgebra
77
using LinearAlgebra.BLAS
88
using Random
9+
using RealDot: realdot
910
using Statistics
1011

1112
# Basically everything this package does is overloading these, so we make an exception

src/rulesets/Base/base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
function frule((_, Δz), ::typeof(hypot), z::Complex)
8080
Ω = hypot(z)
81-
∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω)
81+
∂Ω = realdot(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω)
8282
return Ω, ∂Ω
8383
end
8484

src/rulesets/Base/fastmath_able.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ let
6868
Ω = abs(x)
6969
# `ifelse` is applied only to denominator to ensure type-stability.
7070
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
71-
return Ω, _realconjtimes(signx, Δx)
71+
return Ω, realdot(signx, Δx)
7272
end
7373

7474
function rrule(::typeof(abs), x::Union{Real, Complex})
@@ -82,7 +82,7 @@ let
8282

8383
## abs2
8484
function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex})
85-
return abs2(z), 2 * _realconjtimes(z, Δz)
85+
return abs2(z), 2 * realdot(z, Δz)
8686
end
8787

8888
function rrule(::typeof(abs2), z::Union{Real, Complex})
@@ -146,7 +146,7 @@ let
146146
) where {T<:Union{Real,Complex}}
147147
Ω = hypot(x, y)
148148
n = ifelse(iszero(Ω), one(Ω), Ω)
149-
∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n
149+
∂Ω = (realdot(x, Δx) + realdot(y, Δy)) / n
150150
return Ω, ∂Ω
151151
end
152152

src/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ function frule(
111111
= unthunk(Δx)
112112
y = sum(abs2, x; dims=dims)
113113
∂y = if dims isa Colon
114-
2 * real(dot(x, ẋ))
114+
2 * realdot(x, ẋ)
115115
elseif VERSION v"1.2" # multi-iterator mapreduce introduced in v1.2
116116
mapreduce(+, x, ẋ; dims=dims) do xi, dxi
117-
2 * _realconjtimes(xi, dxi)
117+
2 * realdot(xi, dxi)
118118
end
119119
else
120-
2 * sum(_realconjtimes.(x, ẋ); dims=dims)
120+
2 * sum(realdot.(x, ẋ); dims=dims)
121121
end
122122
return y, ∂y
123123
end

src/rulesets/Base/utils.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
# real(conj(x) * y) avoiding computing the imaginary part if possible
2-
@inline _realconjtimes(x, y) = real(conj(x) * y)
3-
@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
4-
@inline _realconjtimes(x::Real, y::Complex) = x * real(y)
5-
@inline _realconjtimes(x::Complex, y::Real) = real(x) * y
6-
@inline _realconjtimes(x::Real, y::Real) = x * y
7-
81
# imag(conj(x) * y) avoiding computing the real part if possible
92
@inline _imagconjtimes(x, y) = imag(conj(x) * y)
103
@inline function _imagconjtimes(x::Complex, y::Complex)

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function frule((_, Δx), ::typeof(BLAS.nrm2), x)
4141
∂Ω = if x isa Real
4242
BLAS.dot(x, Δx) / s
4343
else
44-
sum(y -> _realconjtimes(y...), zip(x, Δx)) / s
44+
sum(y -> realdot(y...), zip(x, Δx)) / s
4545
end
4646
return Ω, ∂Ω
4747
end
@@ -72,7 +72,7 @@ end
7272

7373
function frule((_, Δx), ::typeof(BLAS.asum), x)
7474
∂Ω = sum(zip(x, Δx)) do (xi, Δxi)
75-
return _realconjtimes(_signcomp(xi), Δxi)
75+
return realdot(_signcomp(xi), Δxi)
7676
end
7777
return BLAS.asum(x), ∂Ω
7878
end

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ function _eigen_norm_phase_fwd!(∂V, A, V)
347347
@inbounds for i in axes(V, 2)
348348
v, ∂v = @views V[:, i], ∂V[:, i]
349349
# account for unit normalization
350-
∂c_norm = -real(dot(v, ∂v))
350+
∂c_norm = -realdot(v, ∂v)
351351
if eltype(V) <: Real
352352
∂c = ∂c_norm
353353
else

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function frule((_, ẋ), ::typeof(norm), x::Number, p::Real)
1414
zero(real(x)) * zero(real(Δx))
1515
else
1616
signx = x isa Real ? sign(x) : x * pinv(y)
17-
_realconjtimes(signx, Δx)
17+
realdot(signx, Δx)
1818
end
1919
return y, ∂y
2020
end
@@ -235,7 +235,7 @@ function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
235235
end
236236

237237
function _norm2_forward(x, Δx, y)
238-
∂y = real(dot(x, Δx)) * pinv(y)
238+
∂y = realdot(x, Δx) * pinv(y)
239239
return ∂y
240240
end
241241
function _norm2_back(x, y, Δy)
@@ -280,7 +280,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number})
280280
LinearAlgebra.__normalize!(y, nrm)
281281
function normalize_pullback(ȳ)
282282
Δy = unthunk(ȳ)
283-
∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm)
283+
∂x = (Δy .- realdot(y, Δy) .* y) .* pinv(nrm)
284284
return (NoTangent(), ∂x)
285285
end
286286
normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent())

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ function frule(
218218
# diag(U' * tmp) without computing matrix product
219219
∂λ = similar(λ)
220220
@inbounds for i in eachindex(λ)
221-
∂λ[i] = @views real(dot(U[:, i], tmp[:, i]))
221+
∂λ[i] = @views realdot(U[:, i], tmp[:, i])
222222
end
223223
return λ, ∂λ
224224
end

0 commit comments

Comments
 (0)