-
Notifications
You must be signed in to change notification settings - Fork 92
Assume commutative multiplication exactly when necessary #540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
c35f043
5e036ca
331c65f
405427f
b4e90dc
723ae92
f7293d2
41a7d27
55d64f1
0cb446f
964e424
d1e8c2a
13d0c84
55abdeb
c5f032d
057d7d6
ed16465
8fd8c41
9237994
0e6f76b
8be7c70
9475447
dc29655
3e6e838
c05b577
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -76,22 +76,38 @@ end | |||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule hypot(x::Real) sign(x) | ||||||||||||||||||||
|
||||||||||||||||||||
function frule((_, Δz), ::typeof(hypot), z::Complex) | ||||||||||||||||||||
function frule((_, Δz), ::typeof(hypot), z::Number) | ||||||||||||||||||||
Ω = hypot(z) | ||||||||||||||||||||
∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) | ||||||||||||||||||||
return Ω, ∂Ω | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
function rrule(::typeof(hypot), z::Complex) | ||||||||||||||||||||
function rrule(::typeof(hypot), z::Number) | ||||||||||||||||||||
Ω = hypot(z) | ||||||||||||||||||||
function hypot_pullback(ΔΩ) | ||||||||||||||||||||
return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) | ||||||||||||||||||||
end | ||||||||||||||||||||
return (Ω, hypot_pullback) | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule fma(x, y, z) (y, x, true) | ||||||||||||||||||||
@scalar_rule muladd(x, y, z) (y, x, true) | ||||||||||||||||||||
@scalar_rule fma(x, y::CommutativeMulNumber, z) (y, x, true) | ||||||||||||||||||||
function frule((_, Δx, Δy, Δz), ::typeof(fma), x::Number, y::Number, z::Number) | ||||||||||||||||||||
return fma(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) | ||||||||||||||||||||
end | ||||||||||||||||||||
function rrule(::typeof(fma), x::Number, y::Number, z::Number) | ||||||||||||||||||||
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) | ||||||||||||||||||||
fma_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) | ||||||||||||||||||||
fma(x, y, z), fma_pullback | ||||||||||||||||||||
end | ||||||||||||||||||||
@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, true) | ||||||||||||||||||||
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number) | ||||||||||||||||||||
return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) | ||||||||||||||||||||
end | ||||||||||||||||||||
function rrule(::typeof(muladd), x::Number, y::Number, z::Number) | ||||||||||||||||||||
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) | ||||||||||||||||||||
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) | ||||||||||||||||||||
muladd(x, y, z), muladd_pullback | ||||||||||||||||||||
end | ||||||||||||||||||||
Comment on lines
+106
to
+110
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. E.g. here I think this is very clear, the pattern of where the function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end but I'd like to make the corresponding
Suggested change
or perhaps And since it closes over |
||||||||||||||||||||
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) | ||||||||||||||||||||
@scalar_rule( | ||||||||||||||||||||
mod(x, y), | ||||||||||||||||||||
|
@@ -105,51 +121,51 @@ end | |||||||||||||||||||
@scalar_rule(ldexp(x, y), (2^y, NoTangent())) | ||||||||||||||||||||
|
||||||||||||||||||||
# Can't multiply though sqrt in acosh because of negative complex case for x | ||||||||||||||||||||
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1)) | ||||||||||||||||||||
@scalar_rule acoth(x) inv(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) | ||||||||||||||||||||
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1)) | ||||||||||||||||||||
@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) | ||||||||||||||||||||
@scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2))) | ||||||||||||||||||||
@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2))) | ||||||||||||||||||||
@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1)) | ||||||||||||||||||||
@scalar_rule atanh(x) inv(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2))) | ||||||||||||||||||||
@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1)) | ||||||||||||||||||||
@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2) | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2) | ||||||||||||||||||||
@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||||||||||||||||
@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2) | ||||||||||||||||||||
@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||||||||||||||||
@scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1) | ||||||||||||||||||||
@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||||||||||||||||
@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||||||||||||||||
@scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1) | ||||||||||||||||||||
@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2) | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule cot(x) -((1 + Ω ^ 2)) | ||||||||||||||||||||
@scalar_rule coth(x) -(csch(x) ^ 2) | ||||||||||||||||||||
@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||||||||||||||||
@scalar_rule csc(x) -Ω * cot(x) | ||||||||||||||||||||
@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x) | ||||||||||||||||||||
@scalar_rule csch(x) -(coth(x)) * Ω | ||||||||||||||||||||
@scalar_rule sec(x) Ω * tan(x) | ||||||||||||||||||||
@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x) | ||||||||||||||||||||
@scalar_rule sech(x) -(tanh(x)) * Ω | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule acot(x) -(inv(1 + x ^ 2)) | ||||||||||||||||||||
@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) | ||||||||||||||||||||
@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) | ||||||||||||||||||||
@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2) | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2)) | ||||||||||||||||||||
@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2) | ||||||||||||||||||||
@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||||||||||||||||
@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x) | ||||||||||||||||||||
@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x) | ||||||||||||||||||||
@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω | ||||||||||||||||||||
@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x) | ||||||||||||||||||||
@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x) | ||||||||||||||||||||
@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2)) | ||||||||||||||||||||
@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) | ||||||||||||||||||||
@scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1))) | ||||||||||||||||||||
@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2)) | ||||||||||||||||||||
@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2)) | ||||||||||||||||||||
@scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1)) | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x) | ||||||||||||||||||||
@scalar_rule cospi(x) -π * sinpi(x) | ||||||||||||||||||||
@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x) | ||||||||||||||||||||
@scalar_rule sinpi(x) π * cospi(x) | ||||||||||||||||||||
@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||||||||||||||||
@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x) | ||||||||||||||||||||
@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x) | ||||||||||||||||||||
@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x) | ||||||||||||||||||||
@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x) | ||||||||||||||||||||
@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule sinc(x) cosc(x) | ||||||||||||||||||||
@scalar_rule sinc(x::CommutativeMulNumber) cosc(x) | ||||||||||||||||||||
|
||||||||||||||||||||
# the position of the minus sign below warrants the correct type for π | ||||||||||||||||||||
if VERSION ≥ v"1.6" | ||||||||||||||||||||
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) | ||||||||||||||||||||
@scalar_rule sincospi(x::CommutativeMulNumber) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule( | ||||||||||||||||||||
|
@@ -160,7 +176,22 @@ end | |||||||||||||||||||
), | ||||||||||||||||||||
(!(islow | ishigh), islow, ishigh), | ||||||||||||||||||||
) | ||||||||||||||||||||
@scalar_rule x \ y (-(Ω / x), one(y) / x) | ||||||||||||||||||||
|
||||||||||||||||||||
@scalar_rule x::CommutativeMulNumber \ y::CommutativeMulNumber (-(x \ Ω), x \ one(y)) | ||||||||||||||||||||
function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) | ||||||||||||||||||||
Ω = x \ y | ||||||||||||||||||||
return Ω, x \ muladd(Δy, -Δx, Ω) | ||||||||||||||||||||
end | ||||||||||||||||||||
function rrule(::typeof(\), x::Number, y::Number) | ||||||||||||||||||||
Ω = x \ y | ||||||||||||||||||||
project_x = ProjectTo(x) | ||||||||||||||||||||
project_y = ProjectTo(y) | ||||||||||||||||||||
function backslash_pullback(ΔΩ) | ||||||||||||||||||||
∂x = x' \ ΔΩ | ||||||||||||||||||||
return NoTangent(), project_x(∂x), project_y(-∂x * Ω') | ||||||||||||||||||||
end | ||||||||||||||||||||
return Ω, backslash_pullback | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
function frule((_, ẏ), ::typeof(identity), x) | ||||||||||||||||||||
return (x, ẏ) | ||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I ask why you moved the minus?
If it was
-true * Ω' * ΔΩ * Ω'
then I think you'd save a copy (since this gets fused intomul!
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So that if
ΔΩ
is anAbstractZero
or aUniformScaling
, then the negation is cheaper.I didn't follow this. How is this fused into the
mul!
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this was not an important change, and I'm happy to remove)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I didn't think about those. For dense matrices there's a 4-arg method which fuses this:
But with
I
, no fusion, hencef2
is slower. Maybe*
should have some extra methods for cases withI
.