Skip to content

Commit 40c547c

Browse files
authored
Merge pull request #1 from FluxML/dg/grad
Allow gradients in fmap
2 parents d58d273 + b5b872b commit 40c547c

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ julia = "1"
88

99
[extras]
1010
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
11+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1112

1213
[targets]
13-
test = ["Test"]
14+
test = ["Test", "Zygote"]

src/functor.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,29 @@ Equivalent to `functor(x)[1]`.
7575
"""
7676
children(x) = functor(x)[1]
7777

78+
function functor_tuple(f, x::Tuple, dx::Tuple)
79+
map(x, dx) do x, x̄
80+
_default_walk(f, x, x̄)
81+
end
82+
end
83+
functor_tuple(f, x, dx) = f(x, dx)
84+
functor_tuple(f, x, ::Nothing) = x
85+
86+
# @functor Chain
87+
# Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
88+
function _default_walk(f, x, dx)
89+
func, re = functor(x)
90+
map(func, dx) do x, x̄
91+
# functor_tuple(f, x, x̄)
92+
f(x, x̄)
93+
end |> re
94+
end
95+
7896
function _default_walk(f, x)
7997
func, re = functor(x)
8098
re(map(f, func))
8199
end
100+
_default_walk(f, ::Nothing, ::Nothing) = nothing
82101

83102
"""
84103
fmap(f, x; exclude = isleaf, walk = Functors._default_walk)
@@ -205,3 +224,11 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
205224
end
206225
return output
207226
end
227+
228+
# Allow gradients and other constructs that match the structure of the functor
229+
# to allow for `map` style computations and return a modified version of the struct.
230+
# This way we can use `fmap` to update the params with their gradients
231+
function fmap(f, x, dx...; cache = IdDict())
232+
haskey(cache, x) && return cache[x]
233+
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
234+
end

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Functors, Test
2+
using Zygote
23

34
@testset "Functors.jl" begin
45

5-
include("basics.jl")
6-
include("base.jl")
7-
6+
include("basics.jl")
7+
include("update.jl")
88
end

test/update.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@testset "Generalized fmap over equivalent functors" begin
2+
struct M{F,T,S}
3+
σ::F
4+
W::T
5+
b::S
6+
end
7+
8+
@functor M
9+
10+
(m::M)(x) = m.σ.(m.W * x .+ m.b)
11+
12+
m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
13+
x = ones(Float32, 4, 2)
14+
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
15+
= Functors.fmap(m, m̄) do x, y
16+
isnothing(x) && return y
17+
isnothing(y) && return x
18+
x .- 0.1f0 .* y
19+
end
20+
21+
@test.W fill(0.8f0, size(m.W))
22+
@test.b fill(-0.2f0, size(m.b))
23+
end

0 commit comments

Comments
 (0)