Skip to content

Commit 91abecc

Browse files
bors[bot]Gregliestlogankilpatrick
authored
Merge #1758
1758: Add custom model example to docs. r=ToucheSir a=Gregliest This PR adds a simple custom model example to the docs. As discussed in Slack, this example would have helped me understand how all the pieces fit together. But, happy to edit/move/delete as desired. cc `@ToucheSir` ### PR Checklist - [ ] Tests are added - [ ] Entry in NEWS.md - [x] Documentation, if applicable - [ ] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: Greg <gregliest@gmail.com> Co-authored-by: Logan Kilpatrick <23kilpatrick23@gmail.com>
2 parents 9c789dc + a056a95 commit 91abecc

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

docs/src/models/advanced.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,36 @@
22

33
Here we will try and describe usage of some more advanced features that Flux provides to give more control over model building.
44

5+
## Custom Model Example
6+
7+
Here is a basic example of a custom model. It simply adds the input to the result from the neural network.
8+
9+
```julia
10+
struct CustomModel
11+
chain::Chain
12+
end
13+
14+
function (m::CustomModel)(x)
15+
# Arbitrary code can go here, but note that everything will be differentiated.
16+
# Zygote does not allow some operations, like mutating arrays.
17+
18+
return m.chain(x) + x
19+
end
20+
21+
# Call @functor to allow for training. Described below in more detail.
22+
Flux.@functor CustomModel
23+
```
24+
25+
You can then use the model like:
26+
27+
```julia
28+
chain = Chain(Dense(10, 10))
29+
model = CustomModel(chain)
30+
model(rand(10))
31+
```
32+
33+
For an intro to Flux and automatic differentiation, see this [tutorial](https://fluxml.ai/tutorials/2020/09/15/deep-learning-flux.html).
34+
535
## Customising Parameter Collection for a Model
636

737
Taking reference from our example `Affine` layer from the [basics](basics.md#Building-Layers-1).
@@ -68,7 +98,7 @@ by simply deleting it from `ps`:
6898

6999
```julia
70100
ps = params(m)
71-
delete!(ps, m[2].bias)
101+
delete!(ps, m[2].bias)
72102
```
73103

74104
## Custom multiple input or output layer

0 commit comments

Comments
 (0)