@@ -19,8 +19,16 @@ output = rand(1, batch_size)
19
19
# output sensitivities
20
20
_do = 1.
21
21
22
- spb (nn_cpu. params, nn. model, (input, output))[2 ](_do)
23
- zpb (nn_cpu. params, nn. model, (input, output))[2 ](_do)
24
- @time spb_evaluated = spb (nn_cpu. params, nn. model, (input, output))[2 ](_do)
25
- @time zpb_evaluated = zpb (nn_cpu. params, nn. model, (input, output))[2 ](_do)[1 ]. params
26
- # @assert values(spb_evaluated) .≈ values(zpb_evaluated)
22
+ # spb(nn_cpu.params, nn.model, (input, output))[2](_do)
23
+ # zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
24
+ # @time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
25
+ # @time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
26
+ # @assert values(spb_evaluated) .≈ values(zpb_evaluated)
27
+
28
+ function timenn (pb, params, model, input, output, _do = 1. )
29
+ pb (params, model, (input, output))[2 ](_do)
30
+ @time pb (params, model, (input, output))[2 ](_do)
31
+ end
32
+
33
+ timenn (spb, nn_cpu. params, nn. model, input, output)
34
+ timenn (zpb, nn_cpu. params, nn. model, input, output)
0 commit comments