Skip to content

Commit 7ff012f

Browse files
authored
Enzyme: Mark launch_configuration as inactive (#2563)
[only downstream]
1 parent e02786c commit 7ff012f

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

ext/EnzymeCoreExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ end
2525
function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(CUDA.is_pinned), args...)
2626
return nothing
2727
end
28+
function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(CUDA.launch_configuration), args...; kwargs...)
29+
return nothing
30+
end
2831

2932
function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type))
3033
mi = GPUCompiler.methodinstance(F, TT)

test/extensions/enzyme.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,34 @@ end
117117
@test all(dx .≈ 1.0)
118118
end
119119

120+
121+
function setadd(out, x, y)
122+
out .= x .+ y
123+
nothing
124+
end
125+
126+
@testset "Forward setadd" begin
127+
out = CuArray([0.0, 0.0, 0.0, 0.0])
128+
dout = CuArray([0.0, 0.0, 0.0, 0.0])
129+
x = CuArray([1.0, 2.0, 3.0, 4.0])
130+
dx = CuArray([100., 300.0, 500.0, 700.0])
131+
y = CuArray([5.0, 6.0, 7.0, 8.0])
132+
dy = CuArray([500., 600.0, 700.0, 800.0])
133+
res = Enzyme.autodiff(Forward, setadd, Duplicated(out, dout), Duplicated(x, dx), Duplicated(y, dy))
134+
@test all(dout .≈ dx .+ dy)
135+
end
136+
137+
@testset "setadd sum" begin
138+
out = CuArray([0.0, 0.0, 0.0, 0.0])
139+
dout = CuArray([1.0, 1.0, 1.0, 1.0])
140+
x = CuArray([1.0, 2.0, 3.0, 4.0])
141+
dx = CuArray([0., 0.0, 0.0, 0.0])
142+
y = CuArray([5.0, 6.0, 7.0, 8.0])
143+
dy = CuArray([0., 0.0, 0.0, 0.0])
144+
res = Enzyme.autodiff(Reverse, setadd, Duplicated(out, dout), Duplicated(x, dx), Duplicated(y, dy))
145+
@test all(dx .≈ 1)
146+
@test all(dy .≈ 1)
147+
end
120148
# TODO once reverse kernels are in
121149
# function togpu(x)
122150
# x = CuArray(x)

0 commit comments

Comments
 (0)