Skip to content

Commit 609196a

Browse files
committed
add some GPU tests
1 parent ccbe561 commit 609196a

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

test/rulesets/Base/broadcast.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
8080

8181
@testset "fused rules" begin
8282
@testset "arithmetic" begin
83-
test_rrule(copybroadcasted, +, rand(3), rand(3))
84-
test_rrule(copybroadcasted, +, rand(3), rand(4)')
85-
test_rrule(copybroadcasted, +, rand(3), rand(1), rand())
86-
test_rrule(copybroadcasted, +, rand(3), 1.0*im)
87-
test_rrule(copybroadcasted, +, rand(3), true)
88-
test_rrule(copybroadcasted, +, rand(3), Tuple(rand(3)))
83+
@gpu test_rrule(copybroadcasted, +, rand(3), rand(3))
84+
@gpu test_rrule(copybroadcasted, +, rand(3), rand(4)')
85+
@gpu test_rrule(copybroadcasted, +, rand(3), rand(1), rand())
86+
@gpu test_rrule(copybroadcasted, +, rand(3), 1.0*im)
87+
@gpu test_rrule(copybroadcasted, +, rand(3), true)
88+
@gpu_broken test_rrule(copybroadcasted, +, rand(3), Tuple(rand(3)))
8989

90-
test_rrule(copybroadcasted, -, rand(3), rand(3))
91-
test_rrule(copybroadcasted, -, rand(3), rand(4)')
92-
test_rrule(copybroadcasted, -, rand(3))
90+
@gpu test_rrule(copybroadcasted, -, rand(3), rand(3))
91+
@gpu test_rrule(copybroadcasted, -, rand(3), rand(4)')
92+
@gpu test_rrule(copybroadcasted, -, rand(3))
9393
test_rrule(copybroadcasted, -, Tuple(rand(3)))
9494

95-
test_rrule(copybroadcasted, *, rand(3), rand(3))
96-
test_rrule(copybroadcasted, *, rand(3), rand())
97-
test_rrule(copybroadcasted, *, rand(), rand(3))
95+
@gpu test_rrule(copybroadcasted, *, rand(3), rand(3))
96+
@gpu test_rrule(copybroadcasted, *, rand(3), rand())
97+
@gpu test_rrule(copybroadcasted, *, rand(), rand(3))
9898

9999
test_rrule(copybroadcasted, *, rand(3) .+ im, rand(3) .+ 2im)
100100
test_rrule(copybroadcasted, *, rand(3) .+ im, rand() + 3im)
@@ -107,14 +107,15 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
107107
@test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7]
108108

109109
# These two test vararg rrule * rule:
110-
test_rrule(copybroadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3))
111-
test_rrule(copybroadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)')
110+
@gpu test_rrule(copybroadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3))
111+
@gpu_broken test_rrule(copybroadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)')
112+
# GPU error from dot(x::JLArray{Float32, 1}, y::JLArray{ComplexF32, 2})
112113

113-
test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3), Val(2))
114-
test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2))
114+
@gpu test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3), Val(2))
115+
@gpu test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2))
115116

116-
test_rrule(copybroadcasted, /, rand(3), rand())
117-
test_rrule(copybroadcasted, /, rand(3) .+ im, rand() + 3im)
117+
@gpu test_rrule(copybroadcasted, /, rand(3), rand())
118+
@gpu test_rrule(copybroadcasted, /, rand(3) .+ im, rand() + 3im)
118119
end
119120
@testset "identity etc" begin
120121
test_rrule(copybroadcasted, identity, rand(3))

test/unzipped.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,19 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
7979
@test dx5[2] isa Tangent{<:Tuple}
8080
@test Tuple(dx5[2][2]) == (10, ZeroTangent())
8181
end
82+
83+
@testset "JLArray tests" begin # fake GPU testing
84+
(y1, y2), bk = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0])
85+
(y1jl, y2jl), bk_jl = rrule(CFG, unzip_broadcast, tuple, jl([1,2,3.0]), jl([4 5.0]))
86+
@test y1 == Array(y1jl)
87+
# TODO invent some tests of this rrule's pullback function
88+
89+
@test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6]))
90+
91+
@test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6])
92+
@test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray
93+
94+
@test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5])
95+
@test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray
96+
end
8297
end

0 commit comments

Comments
 (0)