@@ -4,6 +4,11 @@ struct Foo
4
4
y:: Float64
5
5
end
6
6
7
+ mutable struct MFoo
8
+ x:: Float64
9
+ y
10
+ end
11
+
7
12
# For testing Primal + Tangent performance
8
13
struct Bar
9
14
x:: Float64
@@ -452,14 +457,40 @@ end
452
457
end
453
458
454
459
@testset " == and hash" begin
455
- @test MutableTangent {Any } (; x= 1.0 ) == MutableTangent {MDemo} (; x= 1.0 )
456
- @test MutableTangent {MDemo} (; x= 1.0 ) == MutableTangent {Any } (; x= 1.0 )
457
- @test MutableTangent {Any } (; x= 2.0 ) != MutableTangent {MDemo} (; x= 1.0 )
458
- @test MutableTangent {MDemo} (; x= 1.0 ) != MutableTangent {Any } (; x= 2.0 )
460
+ @test MutableTangent {MDemo } (; x= 1f0 ) == MutableTangent {MDemo} (; x= 1.0 )
461
+ @test MutableTangent {MDemo} (; x= 1.0 ) == MutableTangent {MDemo } (; x= 1f0 )
462
+ @test MutableTangent {MDemo } (; x= 2.0 ) != MutableTangent {MDemo} (; x= 1.0 )
463
+ @test MutableTangent {MDemo} (; x= 1.0 ) != MutableTangent {MDemo } (; x= 2.0 )
459
464
460
465
nt = (; x= 1.0 )
461
466
@test MutableTangent {typeof(nt)} (nt) != MutableTangent {MDemo} (; x= 1.0 )
462
467
463
- @test hash (MutableTangent {Any} (; x= 1.0 )) == hash (MutableTangent {MDemo} (; x= 1.0 ))
468
+ @test hash (MutableTangent {MDemo} (; x= 1f0 )) == hash (MutableTangent {MDemo} (; x= 1.0 ))
469
+ end
470
+
471
+ @testset " Mutation" begin
472
+ v = MutableTangent {MFoo} (x= 1.5 , y= 2.4 )
473
+ v. x = 1.6
474
+ @test v == MutableTangent {MFoo} (x= 1.6 , y= 2.4 )
475
+ v. y = [1.0 , 2.0 ] # change type, because primal can change type
476
+ @test v == MutableTangent {MFoo} (x= 1.6 , y= [1.0 , 2.0 ])
477
+ end
478
+ end
479
+
480
+ @testset " map" begin
481
+ @testset " Tangent" begin
482
+ ∂foo = Tangent {Foo} (x= 1.5 , y= 2.4 )
483
+ @test map (v-> 2 * v, ∂foo) == Tangent {Foo} (x= 3.0 , y= 4.8 )
484
+
485
+ ∂foo = Tangent {Foo} (x= 1.5 )
486
+ @test map (v-> 2 * v, ∂foo) == Tangent {Foo} (x= 3.0 )
487
+ end
488
+ @testset " MutableTangent" begin
489
+ ∂foo = MutableTangent {MFoo} (x= 1.5 , y= 2.4 )
490
+ ∂foo2 = map (v-> 2 * v, ∂foo)
491
+ @test ∂foo2 == MutableTangent {MFoo} (x= 3.0 , y= 4.8 )
492
+ # Check can still be mutated to new typ
493
+ ∂foo2. y= [1.0 , 2.0 ]
494
+ @test ∂foo2 == MutableTangent {MFoo} (x= 3.0 , y= [1.0 , 2.0 ])
464
495
end
465
496
end
0 commit comments