Skip to content

Commit dffaef0

Browse files
authored
Merge pull request #2065 from Saransh-cpp/non-diff
Make `rng_from_array` non differentiable
2 parents 0cbab9e + a0f0611 commit dffaef0

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ The current defaults are:
4646
rng_from_array(::AbstractArray) = default_rng_value()
4747
rng_from_array(::CuArray) = CUDA.default_rng()
4848

49+
@non_differentiable rng_from_array(::Any)
50+
4951
if VERSION >= v"1.7"
5052
@doc """
5153
default_rng_value()

test/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,3 +799,9 @@ end
799799
end
800800
end
801801
end
802+
803+
# make sure rng_from_array is non_differentiable
804+
@testset "rng_from_array" begin
805+
m(x) = (rand(rng_from_array(x)) * x)[1]
806+
gradient(m, ones(2))
807+
end

0 commit comments

Comments
 (0)