Skip to content

Commit e09be28

Browse files
committed
Fix RNG movement tests
1 parent 799ba61 commit e09be28

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

test/cuda/layers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,6 @@ end
287287
m = layer(0.1; rng = MersenneTwister(123))
288288
@test_throws ErrorException gpu(m)
289289
m = layer(0.1; rng = CUDA.default_rng())
290-
@test gpu(m).rng === CUDA.default_rng()
290+
@test gpu(m).rng isa CUDA.RNG
291291
end
292292
end

test/layers/normalisation.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
5959

6060
# CPU RNGs map onto CPU ok
6161
if isempty(rng_kwargs)
62-
@test cpu(m).rng === Random.default_rng()
62+
if VERSION >= v"1.7"
63+
@test cpu(m).rng isa Random.TaskLocalRNG
64+
else
65+
@test cpu(m).rng isa Random._GLOBAL_RNG
66+
end
6367
else
6468
@test cpu(m).rng === only(values(rng_kwargs))
6569
end
@@ -101,7 +105,11 @@ end
101105

102106
# CPU RNGs map onto CPU ok
103107
if isempty(rng_kwargs)
104-
@test cpu(m).rng === Random.default_rng()
108+
if VERSION >= v"1.7"
109+
@test cpu(m).rng isa Random.TaskLocalRNG
110+
else
111+
@test cpu(m).rng isa Random._GLOBAL_RNG
112+
end
105113
else
106114
@test cpu(m).rng === only(values(rng_kwargs))
107115
end

0 commit comments

Comments
 (0)