Skip to content

Commit 970fce4

Browse files
mcabbottmzgubic
andauthored
Some rules for UniformScaling (#571)
* rules for uniformscaling + - Matrix * use CRC 1.12 * constructor * have to press save before pressing commit * version * tweak * return Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * trick Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent 3355ae7 commit 970fce4

File tree

5 files changed

+136
-2
lines changed

5 files changed

+136
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.25"
3+
version = "1.26"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1212
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313

1414
[compat]
15-
ChainRulesCore = "1.11.5"
15+
ChainRulesCore = "1.12"
1616
ChainRulesTestUtils = "1.5"
1717
Compat = "3.35"
1818
FiniteDifferences = "0.12.20"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("rulesets/LinearAlgebra/matfun.jl")
4141
include("rulesets/LinearAlgebra/structured.jl")
4242
include("rulesets/LinearAlgebra/symmetric.jl")
4343
include("rulesets/LinearAlgebra/factorization.jl")
44+
include("rulesets/LinearAlgebra/uniformscaling.jl")
4445

4546
include("rulesets/Random/random.jl")
4647

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
#####
3+
##### constructor
4+
#####
5+
6+
function rrule(::Type{T}, x::Number) where {T<:UniformScaling}
7+
UniformScaling_back(dx) = (NoTangent(), ProjectTo(x)(unthunk(dx).λ))
8+
return T(x), UniformScaling_back
9+
end
10+
11+
#####
12+
##### `+`
13+
#####
14+
15+
function frule((_, Δx, ΔJ), ::typeof(+), x::AbstractMatrix, J::UniformScaling)
16+
return x + J, Δx + (zero(J) + ΔJ) # This (0 + ΔJ) allows for ΔJ::Tangent{UniformScaling}
17+
end
18+
19+
function frule((_, ΔJ, Δx), ::typeof(+), J::UniformScaling, x::AbstractMatrix)
20+
return J + x, (zero(J) + ΔJ) + Δx
21+
end
22+
23+
function rrule(::typeof(+), x::AbstractMatrix, J::UniformScaling)
24+
project_x = ProjectTo(x)
25+
project_J = ProjectTo(J)
26+
function plus_back(dy)
27+
dx = unthunk(dy)
28+
return (NoTangent(), project_x(dx), project_J(I * tr(dx)))
29+
end
30+
return x + J, plus_back
31+
end
32+
33+
function rrule(::typeof(+), J::UniformScaling, x::AbstractMatrix)
34+
y, back = rrule(+, x, J)
35+
function plus_back_2(dy)
36+
df, dx, dJ = back(dy)
37+
return (df, dJ, dx)
38+
end
39+
return y, plus_back_2
40+
end
41+
42+
#####
43+
##### `-`
44+
#####
45+
46+
function frule((_, Δx, ΔJ), ::typeof(-), x::AbstractMatrix, J::UniformScaling)
47+
return x - J, Δx - (zero(J) + ΔJ)
48+
end
49+
50+
function frule((_, ΔJ, Δx), ::typeof(-), J::UniformScaling, x::AbstractMatrix)
51+
return J - x, (zero(J) + ΔJ) - Δx
52+
end
53+
54+
function rrule(::typeof(-), x::AbstractMatrix, J::UniformScaling)
55+
y, back = rrule(+, x, -J)
56+
project_J = ProjectTo(J)
57+
function minus_back_1(dy)
58+
df, dx, dJ = back(dy)
59+
return (df, dx, project_J(-dJ)) # re-project as -true isa Int
60+
end
61+
return y, minus_back_1
62+
end
63+
64+
function rrule(::typeof(-), J::UniformScaling, x::AbstractMatrix)
65+
project_x = ProjectTo(x)
66+
project_J = ProjectTo(J)
67+
function minus_back_2(dy)
68+
dx = -unthunk(dy)
69+
return (NoTangent(), project_J(-tr(dx) * I), project_x(dx))
70+
end
71+
return J - x, minus_back_2
72+
end
73+
74+
#####
75+
##### `Matrix`
76+
#####
77+
78+
function rrule(::Type{T}, I::UniformScaling{<:Bool}, (m, n)) where {T<:AbstractMatrix}
79+
Matrix_back_I(dy) = (NoTangent(), NoTangent(), NoTangent())
80+
return T(I, m, n), Matrix_back_I
81+
end
82+
83+
function rrule(::Type{T}, J::UniformScaling, (m, n)) where {T<:AbstractMatrix}
84+
project_J = ProjectTo(J)
85+
function Matrix_back_I(dy)
86+
dJ = if m == n
87+
project_J(I * tr(unthunk(dy)))
88+
else
89+
project_J(I * sum(diag(unthunk(dy))))
90+
end
91+
return (NoTangent(), dJ, NoTangent())
92+
end
93+
return T(J, m, n), Matrix_back_I
94+
end
95+
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
@testset "UniformScaling rules" begin
2+
3+
@testset "constructor" begin
4+
test_rrule(UniformScaling, rand())
5+
end
6+
7+
@testset "+" begin
8+
# Forward
9+
test_frule(+, rand(3,3), I * rand(ComplexF64))
10+
test_frule(+, I, rand(3,3))
11+
12+
# Reverse
13+
test_rrule(+, rand(3,3), I)
14+
test_rrule(+, rand(3,3), I * rand(ComplexF64))
15+
test_rrule(+, I, rand(3,3))
16+
test_rrule(+, I * rand(), rand(ComplexF64, 3,3))
17+
end
18+
19+
@testset "-" begin
20+
# Forward
21+
test_frule(-, rand(3,3), I * rand(ComplexF64))
22+
test_frule(-, I, rand(3,3))
23+
24+
# Reverse
25+
test_rrule(-, rand(3,3), I)
26+
test_rrule(-, rand(3,3), I * rand(ComplexF64))
27+
test_rrule(-, I, rand(3,3))
28+
test_rrule(-, I * rand(), rand(ComplexF64, 3,3))
29+
end
30+
31+
@testset "Matrix" begin
32+
test_rrule(Matrix, I, (2, 2))
33+
test_rrule(Matrix{ComplexF64}, rand()*I, (3, 3))
34+
test_rrule(Matrix, rand(ComplexF64)*I, (2, 4))
35+
end
36+
37+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ end
7171
include_test("rulesets/LinearAlgebra/factorization.jl")
7272
include_test("rulesets/LinearAlgebra/blas.jl")
7373
include_test("rulesets/LinearAlgebra/lapack.jl")
74+
include_test("rulesets/LinearAlgebra/uniformscaling.jl")
7475

7576
println()
7677

0 commit comments

Comments
 (0)