Skip to content

Commit 15b1452

Browse files
mcabbottMichael Abbott
andauthored
RFC: rules related to 3-arg * (#412)
* gradients for 3-arg *, untested * fixup, add tests * version bound defn * fix depwarns & indents * update InplaceableThunk * fix complaints * simplify + update * tweak * version * tweak Co-authored-by: Michael Abbott <me@escbook>
1 parent 6aed351 commit 15b1452

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.18.1"
3+
version = "1.19"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/arraymath.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,77 @@ function rrule(
128128
return A * B, times_pullback
129129
end
130130

131+
#####
132+
##### fused 3-argument *
133+
#####
134+
135+
if VERSION > v"1.7.0-"
136+
137+
@eval using LinearAlgebra: mat_mat_scalar, mat_vec_scalar, StridedMaybeAdjOrTransMat
138+
139+
function rrule(
140+
::typeof(mat_mat_scalar),
141+
A::StridedMaybeAdjOrTransMat{<:CommutativeMulNumber},
142+
B::StridedMaybeAdjOrTransMat{<:CommutativeMulNumber},
143+
γ::CommutativeMulNumber
144+
)
145+
project_A = ProjectTo(A)
146+
project_B = ProjectTo(B)
147+
project_γ = ProjectTo(γ)
148+
C = mat_mat_scalar(A, B, γ)
149+
function mat_mat_scalar_back(Ȳraw)
150+
= unthunk(Ȳraw)
151+
Athunk = InplaceableThunk(
152+
dA -> mul!(dA, Ȳ, B', conj(γ), true),
153+
@thunk(project_A(mat_mat_scalar(Ȳ, B', conj(γ)))),
154+
)
155+
Bthunk = InplaceableThunk(
156+
dB -> mul!(dB, A', Ȳ, conj(γ), true),
157+
@thunk(project_B(mat_mat_scalar(A', Ȳ, conj(γ)))),
158+
)
159+
γthunk = @thunk if iszero(γ)
160+
# Could save A*B on the forward pass, but it's messy.
161+
# This ought to be rare, should guarantee the same type:
162+
project_γ(dot(mat_mat_scalar(A, B, oneunit(γ)), Ȳ) / one(γ))
163+
else
164+
project_γ(dot(C, Ȳ) / conj(γ))
165+
end
166+
return (NoTangent(), Athunk, Bthunk, γthunk)
167+
end
168+
return C, mat_mat_scalar_back
169+
end
170+
171+
function rrule(
172+
::typeof(mat_vec_scalar),
173+
A::StridedMaybeAdjOrTransMat{<:CommutativeMulNumber},
174+
b::StridedVector{<:CommutativeMulNumber},
175+
γ::CommutativeMulNumber
176+
)
177+
project_A = ProjectTo(A)
178+
project_b = ProjectTo(b)
179+
project_γ = ProjectTo(γ)
180+
y = mat_vec_scalar(A, b, γ)
181+
function mat_vec_scalar_back(dy_raw)
182+
dy = unthunk(dy_raw)
183+
Athunk = InplaceableThunk(
184+
dA -> mul!(dA, dy, b', conj(γ), true),
185+
@thunk(project_A(*(dy, b', conj(γ)))),
186+
)
187+
Bthunk = InplaceableThunk(
188+
db -> mul!(db, A', dy, conj(γ), true),
189+
@thunk(project_b(*(A', dy, conj(γ)))),
190+
)
191+
γthunk = @thunk if iszero(γ)
192+
project_γ(dot(mat_vec_scalar(A, b, oneunit(γ)), dy))
193+
else
194+
project_γ(dot(y, dy) / conj(γ))
195+
end
196+
return (NoTangent(), Athunk, Bthunk, γthunk)
197+
end
198+
return y, mat_vec_scalar_back
199+
end
200+
201+
end # VERSION
131202

132203
#####
133204
##### `muladd`

test/rulesets/Base/arraymath.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,28 @@
111111
end
112112
end
113113

114+
if VERSION > v"1.7.0-"
115+
@eval using LinearAlgebra: mat_mat_scalar, mat_vec_scalar, StridedMaybeAdjOrTransMat
116+
117+
@testset "3-arg *, $T" for T in [Float64, ComplexF64]
118+
119+
test_rrule(mat_mat_scalar, rand(T,4,4), rand(T,4,4), rand(T))
120+
test_rrule(mat_mat_scalar, rand(T,4,4), rand(T,4,4), 0.0)
121+
test_rrule(mat_mat_scalar, rand(T,4,4)' rand(T,4,4), rand(T,4,4), rand(T))
122+
123+
test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), rand(T))
124+
test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), 0.0)
125+
126+
T == ComplexF64 && continue
127+
# Test with γ of a wider type
128+
A, B, b, γ = rand(3,3), rand(3,3), rand(3), rand()
129+
dZ, dz = rand(3,3), rand(3)
130+
unthunk(rrule(mat_mat_scalar, A, B, γ + 0im)[2](dZ)[4]) unthunk(rrule(mat_mat_scalar, A, B, γ)[2](dZ)[4])
131+
unthunk(rrule(mat_mat_scalar, A, B, 0 + 0im)[2](dZ)[4]) unthunk(rrule(mat_mat_scalar, A, B, 0)[2](dZ)[4])
132+
unthunk(rrule(mat_vec_scalar, A, b, γ + 0im)[2](dz)[4]) unthunk(rrule(mat_vec_scalar, A, b, γ)[2](dz)[4])
133+
end
134+
end # VERSION
135+
114136
@testset "$f" for f in (/, \)
115137
@testset "Matrix" begin
116138
for n in 3:5, m in 3:5

0 commit comments

Comments
 (0)