Skip to content

Commit 8ca6ce0

Browse files
add path option to trainables (#174)
* add path=true * fix * fix * fix docs * fix docs * update doc workflow
1 parent a87ffd5 commit 8ca6ce0

File tree

14 files changed

+268
-137
lines changed

14 files changed

+268
-137
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,3 @@ jobs:
5353
file: lcov.info
5454
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
5555

56-
docs:
57-
name: Documentation
58-
runs-on: ubuntu-latest
59-
steps:
60-
- uses: actions/checkout@v3
61-
- uses: julia-actions/setup-julia@v1
62-
with:
63-
version: '1.6'
64-
- run: |
65-
julia --project=docs -e '
66-
using Pkg
67-
Pkg.develop(PackageSpec(path=pwd()))
68-
Pkg.instantiate()'
69-
- run: |
70-
julia --color=yes --project=docs/ -e '
71-
using Optimisers
72-
using Documenter
73-
using Documenter: doctest
74-
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
75-
doctest(Optimisers)'
76-
- run: julia --project=docs docs/make.jl
77-
env:
78-
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
79-
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

.github/workflows/dependabot.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
2+
version: 2
3+
updates:
4+
- package-ecosystem: "github-actions"
5+
directory: "/" # Location of package manifests
6+
schedule:
7+
interval: "weekly"

.github/workflows/documentation.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: Documentation
2+
3+
on:
4+
push:
5+
branches:
6+
- master # update to match your development branch (master, main, dev, trunk, ...)
7+
tags: '*'
8+
pull_request:
9+
10+
jobs:
11+
build:
12+
permissions:
13+
contents: write
14+
statuses: write
15+
runs-on: ubuntu-latest
16+
steps:
17+
- uses: actions/checkout@v4
18+
- uses: julia-actions/setup-julia@v1
19+
with:
20+
version: '1.10'
21+
- uses: julia-actions/cache@v1
22+
- name: Install dependencies
23+
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
24+
- name: Build and deploy
25+
env:
26+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token
27+
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key
28+
run: julia --project=docs/ docs/make.jl

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
1414
ChainRulesCore = "1"
15-
Functors = "0.4"
15+
Functors = "0.4.9"
1616
Statistics = "1"
1717
Zygote = "0.6.40"
1818
julia = "1.6"

docs/.DS_Store

0 Bytes
Binary file not shown.

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
4+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
45
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
56
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

docs/make.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using Documenter, Optimisers, Zygote, StaticArrays, Functors
22

3-
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
3+
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers, Functors); recursive = true)
4+
DocMeta.setdocmeta!(Functors, :DocTestSetup, :(using Functors); recursive = true)
45

5-
makedocs(modules = [Optimisers],
6+
makedocs(modules = [Optimisers, Functors],
67
doctest = false,
78
sitename = "Optimisers.jl",
89
pages = ["Home" => "index.md",
@@ -13,6 +14,7 @@ makedocs(modules = [Optimisers],
1314
assets = ["assets/flux.css"],
1415
prettyurls = get(ENV, "CI", nothing) == "true"
1516
),
17+
checkdocs = :none, # don't check that Functors' docstrings are all reported here
1618
)
1719

1820
deploydocs(

docs/src/api.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
14

25
## Optimisation Rules
36

@@ -72,3 +75,14 @@ Optimisers.@lazy
7275
Optimisers.adjust(::AbstractRule, ::Real)
7376
Optimisers.@def
7477
```
78+
79+
## KeyPath
80+
81+
A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure.
82+
It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.
83+
84+
```@docs
85+
Functors.KeyPath
86+
Functors.haskeypath
87+
Functors.getkeypath
88+
```

src/Optimisers.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
module Optimisers
22

3-
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
3+
using Functors: functor, fmap, fmap_with_path,
4+
KeyPath, haskeypath, getkeypath,
5+
isleaf, @functor, fmapstructure, children, AbstractWalk
46
using LinearAlgebra
57

68
include("interface.jl")
79
export AbstractRule
810

11+
include("utils.jl")
12+
913
include("adjust.jl")
1014

1115
include("destructure.jl")
1216
export destructure
1317

1418
include("trainables.jl")
1519
export trainables
20+
export KeyPath, haskeypath, getkeypath # from Functors.jl
1621

1722
include("rules.jl")
1823
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,

src/destructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
struct TrainableStructWalk <: AbstractWalk end
8080

81-
(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
81+
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))
8282

8383
_vec(x::Number) = LinRange(x,x,1)
8484
_vec(x::AbstractArray) = vec(x)

0 commit comments

Comments
 (0)