File tree Expand file tree Collapse file tree 4 files changed +16
-1
lines changed Expand file tree Collapse file tree 4 files changed +16
-1
lines changed Original file line number Diff line number Diff line change 1
1
name = " ChainRulesCore"
2
2
uuid = " d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3
- version = " 1.11.6 "
3
+ version = " 1.12 "
4
4
5
5
[deps ]
6
6
Compat = " 34da2185-b29b-5c13-b0c7-acf172513d20"
Original file line number Diff line number Diff line change 379
379
380
380
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
381
381
382
+ # UniformScaling can represent its own cotangent
383
+ ProjectTo (x:: UniformScaling ) = ProjectTo {UniformScaling} (; λ= ProjectTo (x. λ))
384
+ ProjectTo (x:: UniformScaling{Bool} ) = ProjectTo (false )
385
+ (pr:: ProjectTo{UniformScaling} )(dx:: UniformScaling ) = UniformScaling (pr. λ (dx. λ))
386
+ (pr:: ProjectTo{UniformScaling} )(dx:: Tangent{<:UniformScaling} ) = UniformScaling (pr. λ (dx. λ))
387
+
382
388
# Row vectors
383
389
ProjectTo (x:: AdjointAbsVec ) = ProjectTo {Adjoint} (; parent= ProjectTo (parent (x)))
384
390
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
Original file line number Diff line number Diff line change @@ -37,6 +37,8 @@ Base.sum(z::AbstractZero; dims=:) = z
37
37
Base. reshape (z:: AbstractZero , size... ) = z
38
38
Base. reverse (z:: AbstractZero , args... ; kwargs... ) = z
39
39
40
+ (:: Type{<:UniformScaling} )(z:: AbstractZero ) = z
41
+
40
42
"""
41
43
ZeroTangent() <: AbstractZero
42
44
Original file line number Diff line number Diff line change @@ -208,6 +208,13 @@ struct NoSuperType end
208
208
# #### `LinearAlgebra`
209
209
# ####
210
210
211
+ @testset " UniformScaling" begin
212
+ @test ProjectTo (I)(123 ) === NoTangent ()
213
+ @test ProjectTo (2 * I)(I * 3im ) === 0.0 * I
214
+ @test ProjectTo ((4 + 5im ) * I)(Tangent {typeof(im * I)} (; λ = 6 )) === (6.0 + 0.0im ) * I
215
+ @test ProjectTo (7 * I)(Tangent {typeof(2I)} ()) == ZeroTangent ()
216
+ end
217
+
211
218
@testset " LinearAlgebra: $adj vectors" for adj in [transpose, adjoint]
212
219
# adjoint vectors
213
220
padj = ProjectTo (adj ([1 , 2 , 3 ]))
You can’t perform that action at this time.
0 commit comments