We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 0cbab9e + a0f0611 commit dffaef0Copy full SHA for dffaef0
src/utils.jl
@@ -46,6 +46,8 @@ The current defaults are:
46
rng_from_array(::AbstractArray) = default_rng_value()
47
rng_from_array(::CuArray) = CUDA.default_rng()
48
49
+@non_differentiable rng_from_array(::Any)
50
+
51
if VERSION >= v"1.7"
52
@doc """
53
default_rng_value()
test/utils.jl
@@ -799,3 +799,9 @@ end
799
end
800
801
802
803
+# make sure rng_from_array is non_differentiable
804
+@testset "rng_from_array" begin
805
+ m(x) = (rand(rng_from_array(x)) * x)[1]
806
+ gradient(m, ones(2))
807
+end
0 commit comments