Skip to content

Commit 34aec33

Browse files
authored
Restrict types accepted by frule muladd and make tests more comprehensive (#728)
* restrict muladd frule and add more tests * bump version
1 parent f0095e0 commit 34aec33

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.52.0"
3+
version = "1.52.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/arraymath.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,13 @@ end # VERSION
210210
##### `muladd`
211211
#####
212212

213-
function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
213+
function frule(
214+
(_, ΔA, ΔB, Δz),
215+
::typeof(muladd),
216+
A::AbstractVecOrMat{<:CommutativeMulNumber},
217+
B::AbstractVecOrMat{<:CommutativeMulNumber},
218+
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}
219+
)
214220
Ω = muladd(A, B, z)
215221
return Ω, ΔA * B .+ A * ΔB .+ Δz
216222
end

test/rulesets/Base/arraymath.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,44 +85,49 @@
8585

8686
@testset "muladd: $T" for T in (Float64, ComplexF64)
8787
@testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false]
88-
@testset "forward mode" begin
89-
@gpu test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z)
90-
end
9188
@testset "matrix * matrix" begin
9289
A = rand(T, 3, 3)
9390
B = rand(T, 3, 3)
9491
@gpu test_rrule(muladd, A, B, z)
9592
@gpu test_rrule(muladd, A', B, z)
9693
@gpu test_rrule(muladd, A , B', z)
94+
@gpu test_frule(muladd, A, B, z)
95+
@gpu test_frule(muladd, A', B, z)
96+
@gpu test_frule(muladd, A , B', z)
9797

9898
C = rand(T, 3, 5)
9999
D = rand(T, 5, 3)
100100
@gpu test_rrule(muladd, C, D, z)
101+
@gpu test_frule(muladd, C, D, z)
101102
end
102103
if ndims(z) <= 1
103104
@testset "matrix * vector" begin
104105
A, B = rand(T, 3, 3), rand(T, 3)
105106
test_rrule(muladd, A, B, z)
106107
test_rrule(muladd, A, B rand(T, 3,1), z)
108+
test_frule(muladd, A, B, z)
107109
end
108110
@testset "adjoint * matrix" begin
109111
At, B = rand(T, 3)', rand(T, 3, 3)
110112
test_rrule(muladd, At, B, z')
111113
test_rrule(muladd, At rand(T,1,3), B, z')
114+
test_frule(muladd, At, B, z')
112115
end
113116
end
114117
if ndims(z) == 0
115118
@testset "adjoint * vector" begin # like dot
116119
A, B = rand(T, 3)', rand(T, 3)
117120
test_rrule(muladd, A, B, z)
118121
test_rrule(muladd, A rand(T,1,3), B, z')
122+
test_frule(muladd, A, B, z)
119123
end
120124
end
121125
if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1)
122126
@testset "vector * adjoint" begin # outer product
123127
A, B = rand(T, 3), rand(T, 3)'
124128
test_rrule(muladd, A, B, z)
125129
test_rrule(muladd, A, B rand(T,1,3), z)
130+
test_frule(muladd, A, B, z)
126131
end
127132
end
128133
end

0 commit comments

Comments
 (0)