Skip to content

Commit 101f643

Browse files
add test for generic f-map
1 parent 3354cfb commit 101f643

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

test/update.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@testset "Generalized fmap over equivalent functors" begin
2+
struct M{F,T,S}
3+
σ::F
4+
W::T
5+
b::S
6+
end
7+
8+
@functor M
9+
10+
(m::M)(x) = m.σ.(m.W * x .+ m.b)
11+
12+
m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
13+
x = ones(Float32, 4, 2)
14+
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
15+
= Functors.fmap(m, m̄) do x, y
16+
isnothing(x) && return y
17+
isnothing(y) && return x
18+
x .- 0.1f0 .* y
19+
end
20+
21+
@test.W fill(0.8f0, size(m.W))
22+
@test.b fill(-0.2f0, size(m.b))
23+
end

0 commit comments

Comments
 (0)