Skip to content

Commit bb18173

Browse files
committed
Make params non_differentiable
1 parent 276e372 commit bb18173

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/functor.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ function params(m...)
8888
return ps
8989
end
9090

91+
# Allows caching of the parameters when params is called within gradient()
92+
@non_differentiable params(m...)
93+
9194
struct FluxCUDAAdaptor end
9295
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
9396
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))

0 commit comments

Comments
 (0)