@@ -13,6 +13,90 @@ as an object with mirroring fields.
13
13
"""
14
14
abstract type StructuralTangent{P} <: AbstractTangent end
15
15
16
+ """
17
+ Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent
18
+
19
+ This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`.
20
+ `P` is the the corresponding primal type that this is a tangent for.
21
+
22
+ `Tangent{P}` should have fields (technically properties), that match to a subset of the
23
+ fields of the primal type; and each should be a tangent type matching to the primal
24
+ type of that field.
25
+ Fields of the P that are not present in the Tangent are treated as `Zero`.
26
+
27
+ `T` is an implementation detail representing the backing data structure.
28
+ For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
29
+ It should not be passed in by user.
30
+
31
+ For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
32
+ to for a tuple.
33
+ For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
34
+ via `tangent.fieldname`.
35
+ Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`.
36
+ To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
37
+ function is provided.
38
+ """
39
+ struct Tangent{P,T} <: StructuralTangent{P}
40
+ # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
41
+ # (but potentially a different one, as it doesn't contain tangents)
42
+ backing:: T
43
+
44
+ function Tangent {P,T} (backing) where {P,T}
45
+ if P <: Tuple
46
+ T <: Tuple || _backing_error (P, T, Tuple)
47
+ elseif P <: AbstractDict
48
+ T <: AbstractDict || _backing_error (P, T, AbstractDict)
49
+ elseif P === Any # can be anything
50
+ else # Any other struct (including NamedTuple)
51
+ T <: NamedTuple || _backing_error (P, T, NamedTuple)
52
+ end
53
+ return new (backing)
54
+ end
55
+ end
56
+
57
+ """
58
+ MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
59
+
60
+ This type represents the tangent to a mutable struct.
61
+ It itself is also mutable.
62
+
63
+ !!! warning Exprimental
64
+ MutableTangent is an experimental feature, and is part of the mutation support featureset.
65
+ While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
66
+ Exactly how it should be used (e.g. is it forward-mode only?)
67
+
68
+ !!! warning Do not directly mess with the tangent backing data
69
+ It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
70
+ However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
71
+ If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
72
+ """
73
+ struct MutableTangent{P,F} <: StructuralTangent{P}
74
+ backing:: F
75
+
76
+ function MutableTangent {P} (fieldvals) where P
77
+ backing = map (Ref, fieldvals)
78
+ return new {P, typeof(backing)} (backing)
79
+ end
80
+ function MutableTangent {P} (
81
+ any_mask:: NamedTuple{names, <:NTuple{<:Any, Bool}} , fvals:: NamedTuple{names}
82
+ ) where {names, P}
83
+
84
+ backing = map (any_mask, fvals) do isany, fval
85
+ ref = if isany
86
+ Ref{Any}
87
+ else
88
+ Ref
89
+ end
90
+ return ref (fval)
91
+ end
92
+ return new {P, typeof(backing)} (backing)
93
+ end
94
+ end
95
+
96
+ # ###################################################################
97
+ # StructuralTangent Common
98
+
99
+
16
100
function StructuralTangent {P} (nt:: NamedTuple ) where {P}
17
101
if has_mutable_tangent (P)
18
102
return MutableTangent {P} (nt)
@@ -21,6 +105,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P}
21
105
end
22
106
end
23
107
108
+
24
109
has_mutable_tangent (:: Type{P} ) where P = ismutabletype (P) && (! isabstracttype (P) && fieldcount (P) > 0 )
25
110
26
111
40
125
Base. iszero (t:: StructuralTangent ) = all (iszero, backing (t))
41
126
42
127
function Base. map (f, tangent:: StructuralTangent{P} ) where {P}
128
+ # TODO : is it even useful to support this on MutableTangents?
129
+ # TODO : we implictly assume only linear `f` are called and that it is safe to ignore noncanonical Zeros
130
+ # This feels like a fair assumption since all normal operations on tangents are linear
43
131
L = propertynames (backing (tangent))
44
132
vals = map (f, Tuple (backing (tangent)))
45
133
named_vals = NamedTuple {L,typeof(vals)} (vals)
@@ -63,7 +151,8 @@ primal types.
63
151
backing (x:: Tuple ) = x
64
152
backing (x:: NamedTuple ) = x
65
153
backing (x:: Dict ) = x
66
- backing (x:: StructuralTangent ) = getfield (x, :backing )
154
+ backing (x:: Tangent ) = getfield (x, :backing )
155
+ backing (x:: MutableTangent ) = map (getindex, getfield (x, :backing ))
67
156
68
157
# For generic structs
69
158
function backing (x:: T ):: NamedTuple where {T}
@@ -206,46 +295,8 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
206
295
return println (io)
207
296
end
208
297
209
- """
210
- Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent
211
-
212
- This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`.
213
- `P` is the the corresponding primal type that this is a tangent for.
214
-
215
- `Tangent{P}` should have fields (technically properties), that match to a subset of the
216
- fields of the primal type; and each should be a tangent type matching to the primal
217
- type of that field.
218
- Fields of the P that are not present in the Tangent are treated as `Zero`.
219
-
220
- `T` is an implementation detail representing the backing data structure.
221
- For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
222
- It should not be passed in by user.
223
-
224
- For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
225
- to for a tuple.
226
- For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
227
- via `tangent.fieldname`.
228
- Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`.
229
- To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
230
- function is provided.
231
- """
232
- struct Tangent{P,T} <: StructuralTangent{P}
233
- # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
234
- # (but potentially a different one, as it doesn't contain tangents)
235
- backing:: T
236
-
237
- function Tangent {P,T} (backing) where {P,T}
238
- if P <: Tuple
239
- T <: Tuple || _backing_error (P, T, Tuple)
240
- elseif P <: AbstractDict
241
- T <: AbstractDict || _backing_error (P, T, AbstractDict)
242
- elseif P === Any # can be anything
243
- else # Any other struct (including NamedTuple)
244
- T <: NamedTuple || _backing_error (P, T, NamedTuple)
245
- end
246
- return new (backing)
247
- end
248
- end
298
+ # ######################################
299
+ # immutable Tangent
249
300
250
301
function Tangent {P} (; kwargs... ) where {P}
251
302
backing = (; kwargs... ) # construct as NamedTuple
@@ -401,46 +452,19 @@ canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent
401
452
canonicalize (tangent:: Tangent{Any,<:Tuple} ) = tangent
402
453
canonicalize (tangent:: Tangent{Any,<:AbstractDict} ) = tangent
403
454
404
-
405
- """
406
- MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent
407
-
408
- This type represents the tangent to a mutable struct.
409
- It itself is also mutable.
410
-
411
- !!! warning Exprimental
412
- MutableTangent is an experimental feature, and is part of the mutation support featureset.
413
- While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
414
- Exactly how it should be used (e.g. is it forward-mode only?)
415
-
416
- !!! warning Do not directly mess with the tangent backing data
417
- It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values.
418
- However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is).
419
- If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this.
420
- """
421
- mutable struct MutableTangent{P} <: StructuralTangent{P}
422
- # TODO : we may want to absolutely lock the type of this down
423
- backing:: NamedTuple
424
- end
455
+ # ##################################################
456
+ # MutableTangent
425
457
426
458
MutableTangent {P} (;kwargs... ) where P = MutableTangent {P} (NamedTuple (kwargs))
427
459
428
- Base. getproperty (tangent:: MutableTangent , idx:: Symbol ) = getfield (backing (tangent), idx)
429
- Base. getproperty (tangent:: MutableTangent , idx:: Int ) = getfield (backing (tangent), idx) # break ambig
460
+ ref_backing (t:: MutableTangent ) = getfield (t, :backing )
430
461
431
- function Base. setproperty! (tangent:: MutableTangent , name:: Symbol , x)
432
- new_backing = Base. setindex (backing (tangent), x, name)
433
- setfield! (tangent, :backing , new_backing)
434
- return x
435
- end
462
+ Base. getproperty (tangent:: MutableTangent , idx:: Symbol ) = getfield (ref_backing (tangent), idx)[]
463
+ Base. getproperty (tangent:: MutableTangent , idx:: Int ) = getfield (ref_backing (tangent), idx)[] # break ambig
436
464
437
- function Base. setproperty! (tangent:: MutableTangent , idx:: Int , x)
438
- # needed due to https://github.com/JuliaLang/julia/issues/43155
439
- name = idx2sym (backing (tangent), idx)
440
- return setproperty! (tangent, name, x)
441
- end
465
+ Base. setproperty! (tangent:: MutableTangent , name:: Symbol , x) = getproperty (ref_backing (tangent), name)[] = x
466
+ Base. setproperty! (tangent:: MutableTangent , idx:: Int , x) = getproperty (ref_backing (tangent), idx)[] = x # break ambig
442
467
443
- idx2sym (:: NamedTuple{names} , idx) where names = names[idx]
444
468
445
469
Base. hash (tangent:: MutableTangent , h:: UInt64 ) = hash (backing (tangent), h)
446
470
function Base.:(== )(t1:: MutableTangent{T1} , t2:: MutableTangent{T2} ) where {T1, T2}
0 commit comments