Skip to content

Commit 6e72a04

Browse files
committed
Extend pullback comparison script.
1 parent 0fcce26 commit 6e72a04

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

scripts/pullback_comparison.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,16 @@ output = rand(1, batch_size)
1919
# output sensitivities
2020
_do = 1.
2121

22-
spb(nn_cpu.params, nn.model, (input, output))[2](_do)
23-
zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
24-
@time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
25-
@time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
26-
# @assert values(spb_evaluated) .≈ values(zpb_evaluated)
22+
# spb(nn_cpu.params, nn.model, (input, output))[2](_do)
23+
# zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
24+
# @time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
25+
# @time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
26+
# @assert values(spb_evaluated) .≈ values(zpb_evaluated)
27+
28+
function timenn(pb, params, model, input, output, _do = 1.)
29+
pb(params, model, (input, output))[2](_do)
30+
@time pb(params, model, (input, output))[2](_do)
31+
end
32+
33+
timenn(spb, nn_cpu.params, nn.model, input, output)
34+
timenn(zpb, nn_cpu.params, nn.model, input, output)

0 commit comments

Comments
 (0)