1
- # # Partial conditions
2
-
3
- struct ConditionsXNoByproduct{C,Y,A,K}
1
+ struct ConditionsX{C,K}
4
2
conditions:: C
5
- y:: Y
6
- args:: A
7
3
kwargs:: K
8
4
end
9
5
10
- function (conditions_x_nobyproduct:: ConditionsXNoByproduct )(x:: AbstractVector )
11
- (; conditions, y, args, kwargs) = conditions_x_nobyproduct
12
- return conditions (x, y, args... ; kwargs... )
13
- end
14
-
15
- struct ConditionsYNoByproduct{C,X,A,K}
6
+ struct ConditionsY{C,K}
16
7
conditions:: C
17
- x:: X
18
- args:: A
19
8
kwargs:: K
20
9
end
21
10
22
- function (conditions_y_nobyproduct:: ConditionsYNoByproduct )(y:: AbstractVector )
23
- (; conditions, x, args, kwargs) = conditions_y_nobyproduct
24
- return conditions (x, y, args... ; kwargs... )
11
+ function (cx:: ConditionsX )(x, y, args... )
12
+ return cx. conditions (x, y, args... ; cx. kwargs... )
25
13
end
26
14
27
- struct ConditionsXByproduct{C,Y,Z,A,K}
28
- conditions:: C
29
- y:: Y
30
- z:: Z
31
- args:: A
32
- kwargs:: K
33
- end
34
-
35
- function (conditions_x_byproduct:: ConditionsXByproduct )(x:: AbstractVector )
36
- (; conditions, y, z, args, kwargs) = conditions_x_byproduct
37
- return conditions (x, y, z, args... ; kwargs... )
15
+ function (cy:: ConditionsY )(y, x, args... ) # order switch
16
+ return cy. conditions (x, y, args... ; cy. kwargs... )
38
17
end
39
18
40
- struct ConditionsYByproduct{C,X,Z,A,K}
41
- conditions:: C
19
+ struct PushforwardOperator!{F,P,B,X,C,R}
20
+ f:: F
21
+ prep:: P
22
+ backend:: B
42
23
x:: X
43
- z:: Z
44
- args:: A
45
- kwargs:: K
46
- end
47
-
48
- function (conditions_y_byproduct:: ConditionsYByproduct )(y:: AbstractVector )
49
- (; conditions, x, z, args, kwargs) = conditions_y_byproduct
50
- return conditions (x, y, z, args... ; kwargs... )
51
- end
52
-
53
- function ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
54
- y = output (y_or_yz)
55
- if y_or_yz isa Tuple
56
- z = byproduct (y_or_yz)
57
- return ConditionsXByproduct (conditions, y, z, args, kwargs)
58
- else
59
- return ConditionsXNoByproduct (conditions, y, args, kwargs)
60
- end
61
- end
62
-
63
- function ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
64
- if y_or_yz isa Tuple
65
- z = byproduct (y_or_yz)
66
- return ConditionsYByproduct (conditions, x, z, args, kwargs)
67
- else
68
- return ConditionsYNoByproduct (conditions, x, args, kwargs)
69
- end
24
+ contexts:: C
25
+ res_backup:: R
70
26
end
71
27
72
- # # Lazy operators
73
-
74
- struct PushforwardOperator!{F,B,X,E,R}
28
+ struct PullbackOperator!{F,P,B,X,C,R}
75
29
f:: F
30
+ prep:: P
76
31
backend:: B
77
32
x:: X
78
- extras :: E
33
+ contexts :: C
79
34
res_backup:: R
80
35
end
81
36
37
+ function PushforwardOperator! (f, prep, backend, x, contexts)
38
+ res_backup = similar (f (x, map (unwrap, contexts)... ))
39
+ return PushforwardOperator! (f, prep, backend, x, contexts, res_backup)
40
+ end
41
+
42
+ function PullbackOperator! (f, prep, backend, x, contexts)
43
+ res_backup = similar (x)
44
+ return PullbackOperator! (f, prep, backend, x, contexts, res_backup)
45
+ end
46
+
82
47
function (po:: PushforwardOperator! )(res, v, α, β)
48
+ (; f, backend, x, contexts, prep, res_backup) = po
83
49
if iszero (β)
84
- pushforward! (po. f, res, po. backend, po. x, v, po. extras)
85
- res .= α .* res
50
+ pushforward! (f, (res,), prep, backend, x, (v,), contexts... )
51
+ if ! isone (α)
52
+ res .*= α
53
+ end
86
54
else
87
- po . res_backup . = res
88
- pushforward! (po . f, res, po . backend, po . x, v, po . extras )
89
- res . = α .* res .+ β .* po . res_backup
55
+ copyto! ( res_backup, res)
56
+ pushforward! (f, ( res,), prep, backend, x, (v,), contexts ... )
57
+ axpby! (β, res_backup, α, res)
90
58
end
91
59
return res
92
60
end
93
61
94
- struct PullbackOperator!{F,B,X,E,R}
95
- f:: F
96
- backend:: B
97
- x:: X
98
- extras:: E
99
- res_backup:: R
100
- end
101
-
102
62
function (po:: PullbackOperator! )(res, v, α, β)
63
+ (; f, backend, x, contexts, prep, res_backup) = po
103
64
if iszero (β)
104
- pullback! (po. f, res, po. backend, po. x, v, po. extras)
105
- res .= α .* res
65
+ pullback! (f, (res,), prep, backend, x, (v,), contexts... )
66
+ if ! isone (α)
67
+ res .*= α
68
+ end
106
69
else
107
- po . res_backup . = res
108
- pullback! (po . f, res, po . backend, po . x, v, po . extras )
109
- res . = α .* res .+ β .+ po . res_backup
70
+ copyto! ( res_backup, res)
71
+ pullback! (f, ( res,), prep, backend, x, (v,), contexts ... )
72
+ axpby! (β, res_backup, α, res)
110
73
end
111
74
return res
112
75
end
@@ -119,24 +82,25 @@ function build_A(
119
82
suggested_backend,
120
83
kwargs... ,
121
84
) where {lazy}
122
- (; conditions, linear_solver, conditions_y_backend) = implicit
85
+ (; conditions, conditions_y_backend) = implicit
123
86
y = output (y_or_yz)
124
87
n, m = length (x), length (y)
125
88
back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
126
- cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
89
+ cond_y = ConditionsY (conditions, kwargs)
90
+ contexts = (Constant (x), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
127
91
if lazy
128
- extras = prepare_pushforward_same_point (cond_y, back_y, y, zero (y))
92
+ prep = prepare_pushforward_same_point (cond_y, back_y, y, ( zero (y),), contexts ... )
129
93
A = LinearOperator (
130
94
eltype (y),
131
95
m,
132
96
m,
133
97
false ,
134
98
false ,
135
- PushforwardOperator! (cond_y, back_y, y, extras, similar (y) ),
99
+ PushforwardOperator! (cond_y, prep, back_y, y, contexts ),
136
100
typeof (y),
137
101
)
138
102
else
139
- J = jacobian (cond_y, back_y, y)
103
+ J = jacobian (cond_y, back_y, y, contexts ... )
140
104
A = factorize (J)
141
105
end
142
106
return A
@@ -150,24 +114,25 @@ function build_Aᵀ(
150
114
suggested_backend,
151
115
kwargs... ,
152
116
) where {lazy}
153
- (; conditions, linear_solver, conditions_y_backend) = implicit
117
+ (; conditions, conditions_y_backend) = implicit
154
118
y = output (y_or_yz)
155
119
n, m = length (x), length (y)
156
120
back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
157
- cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
121
+ cond_y = ConditionsY (conditions, kwargs)
122
+ contexts = (Constant (x), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
158
123
if lazy
159
- extras = prepare_pullback_same_point (cond_y, back_y, y, zero (y))
124
+ prep = prepare_pullback_same_point (cond_y, back_y, y, ( zero (y),), contexts ... )
160
125
Aᵀ = LinearOperator (
161
126
eltype (y),
162
127
m,
163
128
m,
164
129
false ,
165
130
false ,
166
- PullbackOperator! (cond_y, back_y, y, extras, similar (y) ),
131
+ PullbackOperator! (cond_y, prep, back_y, y, contexts ),
167
132
typeof (y),
168
133
)
169
134
else
170
- Jᵀ = transpose (jacobian (cond_y, back_y, y))
135
+ Jᵀ = transpose (jacobian (cond_y, back_y, y, contexts ... ))
171
136
Aᵀ = factorize (Jᵀ)
172
137
end
173
138
return Aᵀ
@@ -181,24 +146,25 @@ function build_B(
181
146
suggested_backend,
182
147
kwargs... ,
183
148
) where {lazy}
184
- (; conditions, linear_solver, conditions_x_backend) = implicit
149
+ (; conditions, conditions_x_backend) = implicit
185
150
y = output (y_or_yz)
186
151
n, m = length (x), length (y)
187
152
back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
188
- cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
153
+ cond_x = ConditionsX (conditions, kwargs)
154
+ contexts = (Constant (y), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
189
155
if lazy
190
- extras = prepare_pushforward_same_point (cond_x, back_x, x, zero (x))
156
+ prep = prepare_pushforward_same_point (cond_x, back_x, x, ( zero (x),), contexts ... )
191
157
B = LinearOperator (
192
158
eltype (y),
193
159
m,
194
160
n,
195
161
false ,
196
162
false ,
197
- PushforwardOperator! (cond_x, back_x, x, extras, similar (y) ),
163
+ PushforwardOperator! (cond_x, prep, back_x, x, contexts ),
198
164
typeof (x),
199
165
)
200
166
else
201
- B = transpose (jacobian (cond_x, back_x, x))
167
+ B = transpose (jacobian (cond_x, back_x, x, contexts ... ))
202
168
end
203
169
return B
204
170
end
@@ -211,24 +177,25 @@ function build_Bᵀ(
211
177
suggested_backend,
212
178
kwargs... ,
213
179
) where {lazy}
214
- (; conditions, linear_solver, conditions_x_backend) = implicit
180
+ (; conditions, conditions_x_backend) = implicit
215
181
y = output (y_or_yz)
216
182
n, m = length (x), length (y)
217
183
back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
218
- cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
184
+ cond_x = ConditionsX (conditions, kwargs)
185
+ contexts = (Constant (y), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
219
186
if lazy
220
- extras = prepare_pullback_same_point (cond_x, back_x, x, zero (y))
187
+ prep = prepare_pullback_same_point (cond_x, back_x, x, ( zero (y),), contexts ... )
221
188
Bᵀ = LinearOperator (
222
189
eltype (y),
223
190
n,
224
191
m,
225
192
false ,
226
193
false ,
227
- PullbackOperator! (cond_x, back_x, x, extras, similar (x) ),
194
+ PullbackOperator! (cond_x, prep, back_x, x, contexts ),
228
195
typeof (x),
229
196
)
230
197
else
231
- Bᵀ = transpose (jacobian (cond_x, back_x, x))
198
+ Bᵀ = transpose (jacobian (cond_x, back_x, x, contexts ... ))
232
199
end
233
200
return Bᵀ
234
201
end
0 commit comments