Skip to content

Commit 1d076fa

Browse files
authored
Projection for UniformScaling (#533)
* I does his own stunts * version
1 parent 328ac3c commit 1d076fa

File tree

4 files changed

+16
-1
lines changed

4 files changed

+16
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.11.6"
3+
version = "1.12"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,12 @@ end
379379

380380
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
381381

382+
# UniformScaling can represent its own cotangent
383+
ProjectTo(x::UniformScaling) = ProjectTo{UniformScaling}(; λ=ProjectTo(x.λ))
384+
ProjectTo(x::UniformScaling{Bool}) = ProjectTo(false)
385+
(pr::ProjectTo{UniformScaling})(dx::UniformScaling) = UniformScaling(pr.λ(dx.λ))
386+
(pr::ProjectTo{UniformScaling})(dx::Tangent{<:UniformScaling}) = UniformScaling(pr.λ(dx.λ))
387+
382388
# Row vectors
383389
ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x)))
384390
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.

src/tangent_types/abstract_zero.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Base.sum(z::AbstractZero; dims=:) = z
3737
Base.reshape(z::AbstractZero, size...) = z
3838
Base.reverse(z::AbstractZero, args...; kwargs...) = z
3939

40+
(::Type{<:UniformScaling})(z::AbstractZero) = z
41+
4042
"""
4143
ZeroTangent() <: AbstractZero
4244

test/projection.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ struct NoSuperType end
208208
##### `LinearAlgebra`
209209
#####
210210

211+
@testset "UniformScaling" begin
212+
@test ProjectTo(I)(123) === NoTangent()
213+
@test ProjectTo(2 * I)(I * 3im) === 0.0 * I
214+
@test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ = 6)) === (6.0 + 0.0im) * I
215+
@test ProjectTo(7 * I)(Tangent{typeof(2I)}()) == ZeroTangent()
216+
end
217+
211218
@testset "LinearAlgebra: $adj vectors" for adj in [transpose, adjoint]
212219
# adjoint vectors
213220
padj = ProjectTo(adj([1, 2, 3]))

0 commit comments

Comments
 (0)