@@ -6,12 +6,15 @@ Random.seed!(123)
6
6
N = 300
7
7
8
8
# Use Float32 since Zygote defaults to Float32
9
- x = rand (Float32, N)
9
+ x1 = rand (Float32, N)
10
+ x2 = rand (Float32, N)
11
+
10
12
v = rand (Float32, N)
11
13
12
14
# Save original values of x and v to make sure they are not ever mutated
13
- x0 = copy (x)
14
- v0 = copy (v)
15
+ _x1 = copy (x1)
16
+ _x2 = copy (x2)
17
+ _v = copy (v)
15
18
16
19
a, b = rand (Float32, 2 )
17
20
@@ -23,81 +26,106 @@ _f(x) = A * (x .^ 2)
23
26
include (" update_coeffs_testutils.jl" )
24
27
f = WrapFunc (_f, 1.0f0 , 1.0f0 )
25
28
26
- @test auto_vecjac (f, x , v) ≈ Zygote. jacobian (f, x )[1 ]' * v
27
- @test auto_vecjac! (zero (x ), f, x , v) ≈ auto_vecjac (f, x , v)
28
- @test num_vecjac! (zero (x ), f, copy (x ), v) ≈ num_vecjac (f, copy (x ), v)
29
- @test auto_vecjac (f, x , v) ≈ num_vecjac (f, copy (x ), copy (v)) rtol = 1e-2
29
+ @test auto_vecjac (f, x1 , v) ≈ Zygote. jacobian (f, x1 )[1 ]' * v
30
+ @test auto_vecjac! (zero (x1 ), f, x1 , v) ≈ auto_vecjac (f, x1 , v)
31
+ @test num_vecjac! (zero (x1 ), f, copy (x1 ), v) ≈ num_vecjac (f, copy (x1 ), v)
32
+ @test auto_vecjac (f, x1 , v) ≈ num_vecjac (f, copy (x1 ), copy (v)) rtol = 1e-2
30
33
31
34
# Compute Jacobian via Zygote
32
35
33
36
@info " VecJac AutoZygote"
34
37
35
- L = VecJac (f, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoZygote ())
38
+ p, t = rand (Float32, 2 )
39
+ L = VecJac (f, copy (x1), p, t; autodiff = AutoZygote ())
40
+ update_coefficients! (L, v, p, t)
36
41
37
- Jx = Zygote. jacobian (f, x)[1 ]
38
- Jv = Zygote. jacobian (f, v)[1 ]
42
+ update_coefficients! (f, v, p, t)
43
+ J1 = Zygote. jacobian (f, x1)[1 ]
44
+ J2 = Zygote. jacobian (f, x2)[1 ]
39
45
40
- @test L * x ≈ Jx' * x
41
- @test L * v ≈ Jx' * v
42
- y= zero (x); @test mul! (y, L, v) ≈ Jx' * v
43
- y= zero (x); @test mul! (y, L, v) ≈ Jx' * v
46
+ # test operator application
47
+ @test L * v ≈ J1' * v
48
+ @test L (v, p, t) ≈ J1' * v
49
+ y= zeros (N); @test mul! (y, L, v) ≈ J1' * v
50
+ y= zeros (N); @test L (y, v, p, t) ≈ J1' * v
44
51
45
- @test L (x, 1.0f0 , 1.0f0 ) ≈ Jx' * x
46
- y= zero (x); @test L (y, x, 1.0f0 , 1.0f0 ) ≈ Jx' * x
47
- @test L (v, 1.0f0 , 1.0f0 ) ≈ Jv' * v
48
- y= zero (v); @test L (y, v, 1.0f0 , 1.0f0 ) ≈ Jv' * v
52
+ # use kwarg VJP_input = x2
53
+ @test L (v, p, t; VJP_input = x2) ≈ J2' * v
54
+ y= zeros (N); @test L (y, v, p, t; VJP_input = x2) ≈ J2' * v
49
55
50
- update_coefficients! (L, v, 3.0 , 4.0 )
56
+ # update_coefficients
57
+ p, t = rand (Float32, 2 )
58
+ L = update_coefficients (L, v, p, t; JVP_input = x2)
51
59
52
- Jx = Zygote. jacobian (f, x)[1 ]
53
- Jv = Zygote. jacobian (f, v)[1 ]
60
+ update_coefficients! (f, v, p, t)
61
+ J1 = Zygote. jacobian (f, x1)[1 ]
62
+ J2 = Zygote. jacobian (f, x2)[1 ]
54
63
55
- @test L * x ≈ Jv' * x
56
- @test L * v ≈ Jv' * v
57
- y= zero (x); @test mul! (y, L, v) ≈ Jv' * v
58
- y= zero (x); @test mul! (y, L, v) ≈ Jv' * v
64
+ # @show p, t
65
+ # @show f.p, f.t
66
+ # @show L.op.f.p, L.op.f.t
59
67
60
- @test L (x, 3.0f0 , 4.0f0 ) ≈ Jx' * x
61
- y= zero (x); @test L (y, x, 3.0f0 , 4.0f0 ) ≈ Jx' * x
62
- @test L (v, 3.0f0 , 4.0f0 ) ≈ Jv' * v
63
- y= zero (v); @test L (y, v, 3.0f0 , 4.0f0 ) ≈ Jv' * v
68
+ @test L * v ≈ J2' * v
69
+ @test L (v, p, t) ≈ J2' * v
70
+ y= zeros (N); @test mul! (y, L, v) ≈ J2' * v
71
+ y= zeros (N); @test L (y, v, p, t) ≈ J2' * v
72
+
73
+ # use kwarg VJP_input = x1
74
+ @test L (v, p, t; VJP_input = x1) ≈ J1' * v
75
+ y= zeros (N); @test L (y, v, p, t; VJP_input = x1) ≈ J1' * v
64
76
65
77
@info " VecJac AutoFiniteDiff"
66
78
67
- L = VecJac (f, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoFiniteDiff ())
79
+ p, t = rand (Float32, 2 )
80
+ L = VecJac (f, copy (x1), 1.0f0 , 1.0f0 ; autodiff = AutoFiniteDiff ())
81
+ update_coefficients! (L, v, p, t)
82
+ update_coefficients! (f, v, p, t)
83
+
84
+ @test L * v ≈ num_vecjac (f, copy (x1), v)
85
+ @test L (v, p, t) ≈ num_vecjac (f, copy (x1), v)
86
+ y= zeros (N); @test mul! (y, L, v) ≈ num_vecjac (f, copy (x1), v)
87
+ y= zeros (N); @test L (y, v, p, t) ≈ num_vecjac (f, copy (x1), v)
68
88
69
- @test L * x ≈ num_vecjac (f, copy (x), x)
70
- @test L * v ≈ num_vecjac (f, copy (x ), v)
71
- y= zero (x ); @test mul! (y, L, v ) ≈ num_vecjac (f, copy (x ), v)
89
+ # use kwarg VJP_input = x2
90
+ @test L (v, p, t; VJP_input = x2) ≈ num_vecjac (f, copy (x2 ), v)
91
+ y= zeros (N ); @test L (y, v, p, t; VJP_input = x2 ) ≈ num_vecjac (f, copy (x2 ), v)
72
92
73
- update_coefficients! (L, v, 3.0 , 4.0 )
74
- @test mul! (y, L, x) ≈ num_vecjac (f, copy (v), x)
75
- _y = copy (y); @test mul! (y, L, x, a, b) ≈ a * num_vecjac (f,copy (v),x) + b * _y
93
+ # update_coefficients
94
+ p, t = rand (Float32, 2 )
95
+ L = update_coefficients (L, v, p, t; JVP_input = x2)
96
+ update_coefficients! (f, v, p, t)
76
97
77
- update_coefficients! (f, v, 5.0 , 6.0 )
78
- @test L (y, v, 5.0 , 6.0 ) ≈ num_vecjac (f, copy (v), v)
98
+ @test L * v ≈ num_vecjac (f, copy (x2), v)
99
+ @test L (v, p, t) ≈ num_vecjac (f, copy (x2), v)
100
+ y= zeros (N); @test mul! (y, L, v) ≈ num_vecjac (f, copy (x2), v)
101
+ y= zeros (N); @test L (y, v, p, t) ≈ num_vecjac (f, copy (x2), v)
102
+
103
+ # use kwarg VJP_input = x2
104
+ @test L (v, p, t; VJP_input = x1) ≈ num_vecjac (f, copy (x1), v)
105
+ y= zeros (N); @test L (y, v, p, t; VJP_input = x1) ≈ num_vecjac (f, copy (x1), v)
79
106
80
107
# Test that x and v were not mutated
81
- @test x ≈ x0
82
- @test v ≈ v0
108
+ @test x1 ≈ _x1
109
+ @test x2 ≈ _x2
110
+ @test v ≈ v
83
111
84
112
@info " Base.resize!"
85
113
86
114
# Resize test
87
115
f2 (x) = 2 x
88
116
f2 (y, x) = (copy! (y, x); lmul! (2 , y); y)
89
117
118
+ x = rand (Float32, N)
90
119
for M in (100 , 400 )
91
120
local L = VecJac (f2, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoZygote ())
92
121
resize! (L, M)
93
122
94
123
_x = resize! (copy (x), M)
95
124
_u = rand (M)
96
- J2 = Zygote. jacobian (f2, _x)[1 ]
125
+ local J2 = Zygote. jacobian (f2, _x)[1 ]
97
126
98
- update_coefficients! (L, _x , 1.0f0 , 1.0f0 )
127
+ update_coefficients! (L, _u , 1.0f0 , 1.0f0 ; VJP_input = _x )
99
128
@test L * _u ≈ J2' * _u rtol= 1e-6
100
- _v = zeros (M); @test mul! (_v, L, _u) ≈ J2' * _u rtol= 1e-6
129
+ local _v = zeros (M); @test mul! (_v, L, _u) ≈ J2' * _u rtol= 1e-6
101
130
end
102
-
103
131
#
0 commit comments