Skip to content

Commit 6053597

Browse files
authored
Add in-place updating for gemm and gemv rrules (#49)
Also use approximate equality instead of strict equality in the various accumulation tests. These can easily fail when using floating point values with `==`.
1 parent 1df670e commit 6053597

File tree

4 files changed

+57
-41
lines changed

4 files changed

+57
-41
lines changed

src/rules/linalg/blas.jl

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,23 @@ end
6262
##### `BLAS.gemv`
6363
#####
6464

65-
function rrule(::typeof(BLAS.gemv), tA, α, A, x)
66-
Ω = BLAS.gemv(tA, α, A, x)
67-
∂α = ΔΩ -> dot(ΔΩ, Ω) / α
68-
∂A = ΔΩ -> uppercase(tA) == 'N' ? α * ΔΩ * x' : α * x * ΔΩ'
69-
∂x = ΔΩ -> gemv(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ΔΩ)
70-
return Ω, (DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂x))
65+
function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T},
66+
x::AbstractVector{T}) where T<:BlasFloat
67+
y = gemv(tA, α, A, x)
68+
if uppercase(tA) === 'N'
69+
∂A = Rule(ȳ -> α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā))
70+
∂x = Rule(ȳ -> gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄))
71+
else
72+
∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā))
73+
∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄))
74+
end
75+
return y, (DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x)
7176
end
7277

73-
function rrule(f::typeof(BLAS.gemv), tA, A, x)
74-
Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
75-
return Ω, (dtA, dA, dx)
78+
function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T},
79+
x::AbstractVector{T}) where T<:BlasFloat
80+
y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x)
81+
return y, (dtA, dA, dx)
7682
end
7783

7884
#####
@@ -82,25 +88,33 @@ end
8288
function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
8389
A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
8490
C = gemm(tA, tB, α, A, B)
85-
∂α = C̄ -> sum(C̄ .* C) / α
91+
β = one(T)
8692
if uppercase(tA) === 'N'
8793
if uppercase(tB) === 'N'
88-
∂A = C̄ -> gemm('N', 'T', α, C̄, B)
89-
∂B = C̄ -> gemm('T', 'N', α, A, C̄)
94+
∂A = Rule(C̄ -> gemm('N', 'T', α, C̄, B),
95+
(Ā, C̄) -> gemm!('N', 'T', α, C̄, B, β, Ā))
96+
∂B = Rule(C̄ -> gemm('T', 'N', α, A, C̄),
97+
(B̄, C̄) -> gemm!('T', 'N', α, A, C̄, β, B̄))
9098
else
91-
∂A = C̄ -> gemm('N', 'N', α, C̄, B)
92-
∂B = C̄ -> gemm('T', 'N', α, C̄, A)
99+
∂A = Rule(C̄ -> gemm('N', 'N', α, C̄, B),
100+
(Ā, C̄) -> gemm!('N', 'N', α, C̄, B, β, Ā))
101+
∂B = Rule(C̄ -> gemm('T', 'N', α, C̄, A),
102+
(B̄, C̄) -> gemm!('T', 'N', α, C̄, A, β, B̄))
93103
end
94104
else
95105
if uppercase(tB) === 'N'
96-
∂A = C̄ -> gemm('N', 'T', α, B, C̄)
97-
∂B = C̄ -> gemm('N', 'N', α, A, C̄)
106+
∂A = Rule(C̄ -> gemm('N', 'T', α, B, C̄),
107+
(Ā, C̄) -> gemm!('N', 'T', α, B, C̄, β, Ā))
108+
∂B = Rule(C̄ -> gemm('N', 'N', α, A, C̄),
109+
(B̄, C̄) -> gemm!('N', 'N', α, A, C̄, β, B̄))
98110
else
99-
∂A = C̄ -> gemm('T', 'T', α, B, C̄)
100-
∂B = C̄ -> gemm('T', 'T', α, C̄, A)
111+
∂A = Rule(C̄ -> gemm('T', 'T', α, B, C̄),
112+
(Ā, C̄) -> gemm!('T', 'T', α, B, C̄, β, Ā))
113+
∂B = Rule(C̄ -> gemm('T', 'T', α, C̄, A),
114+
(B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄))
101115
end
102116
end
103-
return C, (DNERule(), DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂B))
117+
return C, (DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B)
104118
end
105119

106120
function rrule(::typeof(gemm), tA::Char, tB::Char,

test/rules/blas.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using LinearAlgebra.BLAS: gemm
2-
31
@testset "BLAS" begin
42
@testset "gemm" begin
53
rng = MersenneTwister(1)
@@ -9,18 +7,21 @@ using LinearAlgebra.BLAS: gemm
97
A = randn(rng, tA === 'N' ? (m, n) : (n, m))
108
B = randn(rng, tB === 'N' ? (n, p) : (p, n))
119
C = gemm(tA, tB, α, A, B)
12-
fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
13-
@test C ≈ fAB
14-
@test dtA isa ChainRules.DNERule
15-
@test dtB isa ChainRules.DNERule
16-
for (f, x, dx) in [(X->gemm(tA, tB, X, A, B), α, dα),
17-
(X->gemm(tA, tB, α, X, B), A, dA),
18-
(X->gemm(tA, tB, α, A, X), B, dB)]
19-
ȳ = randn(rng, size(C)...)
20-
x̄_ad = dx(ȳ)
21-
x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
22-
@test x̄_ad ≈ x̄_fd rtol=1e-9 atol=1e-9
23-
end
10+
ȳ = randn(rng, size(C)...)
11+
rrule_test(gemm, ȳ, (tA, nothing), (tB, nothing), (α, randn(rng)),
12+
(A, randn(rng, size(A))), (B, randn(rng, size(B))))
13+
end
14+
end
15+
@testset "gemv" begin
16+
rng = MersenneTwister(2)
17+
for n in 3:5, m in 3:5, t in ('N', 'T')
18+
α = randn(rng)
19+
A = randn(rng, m, n)
20+
x = randn(rng, t === 'N' ? n : m)
21+
y = α * (t === 'N' ? A : A') * x
22+
ȳ = randn(rng, size(y)...)
23+
rrule_test(gemv, ȳ, (t, nothing), (α, randn(rng)), (A, randn(rng, size(A))),
24+
(x, randn(rng, size(x))))
2425
end
2526
end
2627
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# TODO: more tests!
22

3-
using ChainRules, Test, FDM, LinearAlgebra, Random
3+
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random
44
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
55
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
66
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
77
DNE, Thunk, Casted, DNERule
88
using Base.Broadcast: broadcastable
9+
import LinearAlgebra: dot
910

1011
include("test_util.jl")
1112

test/test_util.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,41 +119,41 @@ function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
119119
end
120120

121121
function test_accumulation(x̄, dx, ȳ, partial)
122-
@test all(extern(ChainRules.add(x̄, partial)) .== extern(x̄) .+ extern(partial))
122+
@test all(extern(ChainRules.add(x̄, partial)) .≈ extern(x̄) .+ extern(partial))
123123
test_accumulate(x̄, dx, ȳ, partial)
124124
test_accumulate!(x̄, dx, ȳ, partial)
125125
test_store!(x̄, dx, ȳ, partial)
126126
return nothing
127127
end
128128

129129
function test_accumulate(x̄::Zero, dx, ȳ, partial)
130-
@test extern(accumulate(x̄, dx, ȳ)) == extern(partial)
130+
@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial)
131131
return nothing
132132
end
133133

134134
function test_accumulate(x̄::Number, dx, ȳ, partial)
135-
@test extern(accumulate(x̄, dx, ȳ)) == extern(x̄) + extern(partial)
135+
@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial)
136136
return nothing
137137
end
138138

139139
function test_accumulate(x̄::AbstractArray, dx, ȳ, partial)
140140
x̄_old = copy(x̄)
141-
@test all(extern(accumulate(x̄, dx, ȳ)) .== (extern(x̄) .+ extern(partial)))
141+
@test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial)))
142142
@test x̄ == x̄_old
143143
return nothing
144144
end
145145

146146
test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing
147147

148148
function test_accumulate!(x̄::Number, dx, ȳ, partial)
149-
@test accumulate!(x̄, dx, ȳ) == accumulate(x̄, dx, ȳ)
149+
@test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ)
150150
return nothing
151151
end
152152

153153
function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial)
154154
x̄_copy = copy(x̄)
155155
accumulate!(x̄_copy, dx, ȳ)
156-
@test extern(x̄_copy) == (extern(x̄) .+ extern(partial))
156+
@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial))
157157
return nothing
158158
end
159159

@@ -163,6 +163,6 @@ test_store!(x̄::Number, dx, ȳ, partial) = nothing
163163
function test_store!(x̄::AbstractArray, dx, ȳ, partial)
164164
x̄_copy = copy(x̄)
165165
store!(x̄_copy, dx, ȳ)
166-
@test all(x̄_copy .== extern(partial))
166+
@test all(x̄_copy .≈ extern(partial))
167167
return nothing
168168
end

0 commit comments

Comments
 (0)