Skip to content

Commit 6b82b3e

Browse files
vchuravymaleadt
andauthored
Add EnzymeCore extension for parent_job (#2281)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent cc25b24 commit 6b82b3e

File tree

5 files changed

+32
-0
lines changed

5 files changed

+32
-0
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f"
1212
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
1313
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
1414
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
15+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1516
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
1617
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1718
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
@@ -37,10 +38,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3738

3839
[weakdeps]
3940
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
41+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4042
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4143

4244
[extensions]
4345
ChainRulesCoreExt = "ChainRulesCore"
46+
EnzymeCoreExt = "EnzymeCore"
4447
SpecialFunctionsExt = "SpecialFunctions"
4548

4649
[compat]
@@ -55,6 +58,7 @@ ChainRulesCore = "1"
5558
Crayons = "4"
5659
DataFrames = "1"
5760
ExprTools = "0.1"
61+
EnzymeCore = "0.7.1"
5862
GPUArrays = "10.0.1"
5963
GPUCompiler = "0.24, 0.25, 0.26"
6064
KernelAbstractions = "0.9.2"
@@ -81,4 +85,5 @@ julia = "1.8"
8185

8286
[extras]
8387
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
8489
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

ext/EnzymeCoreExt.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# compatibility with EnzymeCore
2+
3+
module EnzymeCoreExt
4+
5+
using CUDA
6+
import CUDA: GPUCompiler, CUDABackend
7+
8+
isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore)
9+
10+
function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type))
11+
mi = GPUCompiler.methodinstance(F, TT)
12+
return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device()))
13+
end
14+
15+
end # module
16+

src/initialization.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ function __init__()
155155
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
156156
include("../ext/SpecialFunctionsExt.jl")
157157
end
158+
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" begin
159+
include("../ext/EnzymeCoreExt.jl")
160+
end
158161
end
159162

160163
# ensure that operations executed by the REPL back-end finish before returning,

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
88
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
99
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
10+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1011
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1112
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
13+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
1214
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1315
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
1416
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

test/libraries/enzyme.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using EnzymeCore
2+
using GPUCompiler
3+
4+
@testset "compiler_job_from_backend" begin
5+
@test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob
6+
end

0 commit comments

Comments
 (0)