Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
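For context, here is a minimal sketch of the pattern being described, assuming the `Affine` struct from earlier on this page (the constructor and field names here are illustrative, and `trainable` is assumed to follow the Optimisers.jl convention of returning a `NamedTuple` subset of the fields):

```julia
using Flux

struct Affine
  W
  b
end

Affine(in::Int, out::Int) = Affine(randn(Float32, out, in), zeros(Float32, out))

(m::Affine)(x) = m.W * x .+ m.b

# Report only `W` as trainable. `b` remains part of the model,
# but optimisers will not update it.
Flux.trainable(a::Affine) = (; W = a.W)
```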
The exact same method of `trainable` can also be defined using the macro, for convenience:
```julia
Flux.@layer Affine trainable=(W,)
```
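As a quick check (a hypothetical fragment, using the `Affine` sketch above), only `W` should now be collected:

```julia
a = Affine(3, 2)
ps = Flux.params(a)  # contains W but not b
length(ps)           # 1
```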
There is a second, more severe, kind of restriction possible:
```julia
Flux.@layer Affine children=(W,)
```
This is equivalent to `Functors.@functor Affine (W,)`. It means that no exploration of the model will ever visit the other fields: they will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This is not usually recommended.
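To make this concrete, here is a hedged sketch of the consequence, assuming the `Affine` example above has been marked with `children=(W,)`: structural transformations like `f32` rebuild `W` but pass `b` through untouched.

```julia
a = Affine(rand(Float64, 2, 3), rand(Float64, 2))
a32 = f32(a)

eltype(a32.W)  # Float32 -- W is a child, so its precision is converted
eltype(a32.b)  # Float64 -- b is invisible to Functors, so it is left as-is
```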
Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
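For reference, the struct in question looks something like this (a sketch matching the definition given earlier in this guide):

```julia
# Custom Join layer: applies one path per input, then merges the results.
struct Join{T, F}
  combine::F
  paths::T
end

# Allow Join(op, m1, m2, ...) as well as Join(op, (m1, m2, ...)).
Join(combine, paths...) = Join(combine, paths)
```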
The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
```julia
Flux.@layer Join
```
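As a sanity check (a hypothetical example with two `Dense` paths), `params` should now find the weights on every path:

```julia
j = Join(vcat, Dense(2 => 3), Dense(2 => 3))
ps = Flux.params(j)  # weight and bias of each Dense layer
length(ps)           # 4
```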
Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results.
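That forward pass can be written as a pair of methods, along these lines (a sketch matching the guide's definition):

```julia
# Apply the i-th path to the i-th input, then merge with `combine`.
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)
```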
Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.
We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass.
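Putting those steps together, here is a sketch of the `Split` layer, mirroring the pattern used for `Join`:

```julia
# Custom Split layer: applies every path to the same input.
struct Split{T}
  paths::T
end

Split(paths...) = Split(paths)

Flux.@layer Split

# Produce a tuple of outputs, one per path.
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
```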