From 23cf805ef74bd12063e9e39b6ca1c9d556673b4c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 28 May 2024 09:33:13 -0400 Subject: [PATCH 1/7] Add EnzymeCore extension --- lib/GPUArraysCore/Project.toml | 7 +++++++ lib/GPUArraysCore/ext/EnzymeCoreExt.jl | 27 ++++++++++++++++++++++++++ test/Project.toml | 1 + test/gpuarrayscore.jl | 9 +++++++++ test/runtests.jl | 4 ++++ 5 files changed, 48 insertions(+) create mode 100644 lib/GPUArraysCore/ext/EnzymeCoreExt.jl create mode 100644 test/gpuarrayscore.jl diff --git a/lib/GPUArraysCore/Project.toml b/lib/GPUArraysCore/Project.toml index c842d718..39c3624d 100644 --- a/lib/GPUArraysCore/Project.toml +++ b/lib/GPUArraysCore/Project.toml @@ -6,6 +6,13 @@ version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +[weakdeps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[extensions] +EnzymeCoreExt = "EnzymeCore" + [compat] Adapt = "4.0" julia = "1.6" +EnzymeCore = "0.6, 0.7" diff --git a/lib/GPUArraysCore/ext/EnzymeCoreExt.jl b/lib/GPUArraysCore/ext/EnzymeCoreExt.jl new file mode 100644 index 00000000..ff02ddb4 --- /dev/null +++ b/lib/GPUArraysCore/ext/EnzymeCoreExt.jl @@ -0,0 +1,27 @@ +# compatibility with EnzymeCore + +module EnzymeCoreExt + +using GPUArraysCore + +if isdefined(Base, :get_extension) + using EnzymeCore + using EnzymeCore.EnzymeRules +else + using ..EnzymeCore + using ..EnzymeCore.EnzymeRules +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.default_scalar_indexing), args...) + return nothing +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.assertscalar), args...) + return nothing +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.allowscalar), args...) + return nothing +end + +end # module diff --git a/test/Project.toml b/test/Project.toml index 76e1e22a..fe9429b9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/gpuarrayscore.jl b/test/gpuarrayscore.jl new file mode 100644 index 00000000..88f1d62b --- /dev/null +++ b/test/gpuarrayscore.jl @@ -0,0 +1,9 @@ +using Test, GPUArraysCore, EnzymeCore + +@testset "EnzymeCore" begin + @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.assertscalar) + + @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.default_scalar_indexing) + + @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.allowscalar, identity) +end diff --git a/test/runtests.jl b/test/runtests.jl index 4df72b2b..d5ef5f62 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using GPUArrays, Test, Pkg +@tests "GPUArraysCore" begin + include("gpuarrayscore.jl") +end + include("testsuite.jl") @testset "JLArray" begin From 110f32b1154d243bcba55fdd38ebc1b6506a5110 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 8 Jun 2024 12:19:48 -0400 Subject: [PATCH 2/7] Update Project.toml --- lib/GPUArraysCore/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/GPUArraysCore/Project.toml b/lib/GPUArraysCore/Project.toml index 39c3624d..2b160c6d 100644 --- a/lib/GPUArraysCore/Project.toml +++ b/lib/GPUArraysCore/Project.toml @@ -16,3 +16,6 @@ EnzymeCoreExt = "EnzymeCore" Adapt = "4.0" julia = "1.6" EnzymeCore = "0.6, 0.7" + +[extras] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" From 98b6641b62315fe0c396912ff3a11e73203eb201 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 13:04:30 -0700 Subject: [PATCH 3/7] Update runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d5ef5f62..81b9c78e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using GPUArrays, Test, Pkg -@tests "GPUArraysCore" begin +@testset "GPUArraysCore" begin include("gpuarrayscore.jl") end From 42cc0ee08f775403bcd85771cc65e515db37454d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 10 Jun 2024 13:14:49 -0700 Subject: [PATCH 4/7] Just do enzyme integration testing --- test/Project.toml | 1 + test/gpuarrayscore.jl | 9 --------- test/runtests.jl | 4 ---- test/testsuite.jl | 1 + test/testsuite/enzyme.jl | 13 +++++++++++++ 5 files changed, 15 insertions(+), 13 deletions(-) delete mode 100644 test/gpuarrayscore.jl create mode 100644 test/testsuite/enzyme.jl diff --git a/test/Project.toml b/test/Project.toml index fe9429b9..ff2f783a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/gpuarrayscore.jl b/test/gpuarrayscore.jl deleted file mode 100644 index 88f1d62b..00000000 --- a/test/gpuarrayscore.jl +++ /dev/null @@ -1,9 +0,0 @@ -using Test, GPUArraysCore, EnzymeCore - -@testset "EnzymeCore" begin - @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.assertscalar) - - @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.default_scalar_indexing) - - @test nothing == EnzymeCore.EnzymeRules.inactive_noinl(GPUArraysCore.allowscalar, identity) -end diff --git a/test/runtests.jl b/test/runtests.jl index 81b9c78e..4df72b2b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,5 @@ using GPUArrays, Test, Pkg -@testset "GPUArraysCore" begin - include("gpuarrayscore.jl") -end - include("testsuite.jl") @testset "JLArray" begin diff --git a/test/testsuite.jl b/test/testsuite.jl index e7c14646..8f367fad 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -96,6 +96,7 @@ include("testsuite/math.jl") include("testsuite/random.jl") include("testsuite/uniformscaling.jl") include("testsuite/statistics.jl") +include("testsuite/enzyme.jl") """ Runs the entire GPUArrays test suite on array type `AT` diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl new file mode 100644 index 00000000..962987df --- /dev/null +++ b/test/testsuite/enzyme.jl @@ -0,0 +1,13 @@ +using Enzyme + +@testsuite "Enzyme" (AT, eltypes)->begin + for ET in eltypes + T = AT{ET} + @testset "Forward $ET" begin + x = T(ones(3)) + dx = T(3*ones(3)) + res = autodiff(Forward, scalarfirst, Duplicated(x, dx)) + @test approx(res, 3) + end + end +end From 34c92c27b076d6a4f1a4d09a0647ce9419aefcc9 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 10 Jun 2024 13:15:54 -0700 Subject: [PATCH 5/7] add missing fn --- test/testsuite/enzyme.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 962987df..49b038f3 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -1,5 +1,9 @@ using Enzyme +function scalarfirst(x) + @allowscalar x[1] +end + @testsuite "Enzyme" (AT, eltypes)->begin for ET in eltypes T = AT{ET} From 2ff2623dd44a3a935b1d8c4693536d747639ab61 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 09:44:00 -0700 Subject: [PATCH 6/7] Add jlarrays extension --- Project.toml | 1 + lib/JLArrays/Project.toml | 11 ++++++-- lib/JLArrays/ext/EnzymeExt.jl | 51 +++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 lib/JLArrays/ext/EnzymeExt.jl diff --git a/Project.toml b/Project.toml index 82f8ccc9..13a1253d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "10.1.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml index ce8959b7..5dac4485 100644 --- a/lib/JLArrays/Project.toml +++ b/lib/JLArrays/Project.toml @@ -1,15 +1,22 @@ name = "JLArrays" uuid = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" authors = ["Tim Besard "] -version = "0.1.4" +version = "0.1.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[extensions] +EnzymeExt = "Enzyme" + [compat] Adapt = "2.0, 3.0, 4.0" +Enzyme = "0.12" GPUArrays = "10" julia = "1.8" -Random = "1" +Random = "1" \ No newline at end of file diff --git a/lib/JLArrays/ext/EnzymeExt.jl b/lib/JLArrays/ext/EnzymeExt.jl new file mode 100644 index 00000000..88a59c71 --- /dev/null +++ b/lib/JLArrays/ext/EnzymeExt.jl @@ -0,0 +1,51 @@ +module EnzymeExt + +using JLArrays + +if isdefined(Base, :get_extension) + using Enzyme +else + using ..Enzyme +end + +# Override default type tree. This is because JLArray stores data as Vector{UInt8}, causing issues for +# type analysis not determining the proper element type (instead determining the memory is of type UInt8). +function Enzyme.typetree_inner(::Type{JLT}, ctx, dl, seen::Enzyme.TypeTreeTable) where {JLT<:JLArray} + if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) + return TypeTree() + end + + if !Base.isconcretetype(T) + return Enzyme.TypeTree(Enzyme.API.DT_Pointer, -1, ctx) + end + + elT = eltype(JLT) + + fieldTypes = [DataRef{Vector{elT}}, Int, Dims{length(size(JLT))}] + + tt = Enzyme.TypeTree() + for f in 1:fieldcount(T) + offset = fieldoffset(T, f) + subT = fieldTypes[f] + subtree = copy(Enzyme.typetree(subT, ctx, dl, seen)) + + if subT isa UnionAll || subT isa Union || subT == Union{} + # FIXME: Handle union + continue + end + + # Allocated inline so adjust first path + if allocatedinline(subT) + Enzyme.shift!(subtree, dl, 0, sizeof(subT), offset) + else + Enzyme.merge!(subtree, TypeTree(API.DT_Pointer, ctx)) + Enzyme.only!(subtree, offset) + end + + Enzyme.merge!(tt, subtree) + end + Enzyme.canonicalize!(tt, sizeof(T), dl) + return tt +end + +end # module From 33633821527ff37f5fff155093ed68b478ac0fad Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 10:02:44 -0700 Subject: [PATCH 7/7] Fix jlarrays enzyme ext --- lib/JLArrays/ext/EnzymeExt.jl | 21 ++++++++++++--------- test/Project.toml | 1 + test/runtests.jl | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/lib/JLArrays/ext/EnzymeExt.jl b/lib/JLArrays/ext/EnzymeExt.jl index 88a59c71..9789a300 100644 --- a/lib/JLArrays/ext/EnzymeExt.jl +++ b/lib/JLArrays/ext/EnzymeExt.jl @@ -2,30 +2,33 @@ module EnzymeExt using JLArrays +using GPUArrays + if isdefined(Base, :get_extension) using Enzyme else using ..Enzyme end + # Override default type tree. This is because JLArray stores data as Vector{UInt8}, causing issues for # type analysis not determining the proper element type (instead determining the memory is of type UInt8). function Enzyme.typetree_inner(::Type{JLT}, ctx, dl, seen::Enzyme.TypeTreeTable) where {JLT<:JLArray} - if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) - return TypeTree() + if JLT isa UnionAll || JLT isa Union || JLT == Union{} || Base.isabstracttype(JLT) + return Enzyme.TypeTree() end - if !Base.isconcretetype(T) + if !Base.isconcretetype(JLT) return Enzyme.TypeTree(Enzyme.API.DT_Pointer, -1, ctx) end elT = eltype(JLT) - fieldTypes = [DataRef{Vector{elT}}, Int, Dims{length(size(JLT))}] + fieldTypes = [DataRef{Vector{elT}}, Int, Dims{ndims(JLT)}] tt = Enzyme.TypeTree() - for f in 1:fieldcount(T) - offset = fieldoffset(T, f) + for f in 1:fieldcount(JLT) + offset = fieldoffset(JLT, f) subT = fieldTypes[f] subtree = copy(Enzyme.typetree(subT, ctx, dl, seen)) @@ -35,16 +38,16 @@ function Enzyme.typetree_inner(::Type{JLT}, ctx, dl, seen::Enzyme.TypeTreeTable) end # Allocated inline so adjust first path - if allocatedinline(subT) + if Enzyme.allocatedinline(subT) Enzyme.shift!(subtree, dl, 0, sizeof(subT), offset) else - Enzyme.merge!(subtree, TypeTree(API.DT_Pointer, ctx)) + Enzyme.merge!(subtree, Enzyme.TypeTree(Enzyme.API.DT_Pointer, ctx)) Enzyme.only!(subtree, offset) end Enzyme.merge!(tt, subtree) end - Enzyme.canonicalize!(tt, sizeof(T), dl) + Enzyme.canonicalize!(tt, sizeof(JLT), dl) return tt end diff --git a/test/Project.toml b/test/Project.toml index ff2f783a..23060348 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/runtests.jl b/test/runtests.jl index 4df72b2b..a9d5ae81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,25 @@ using GPUArrays, Test, Pkg + +@testset "Enzyme JLArray: TypeTree" begin + + using Enzyme + using JLArrays + using LLVM + + import Enzyme: typetree, TypeTree, API, make_zero + + ctx = LLVM.Context() + dl = string(LLVM.DataLayout(LLVM.JITTargetMachine())) + + tt(T) = string(typetree(T, ctx, dl)) + @test tt(JLArray{Float64, 1}) == "{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Pointer, [0,0,0,0]:Pointer, [0,0,0,0,-1]:Float@double, [0,0,0,8]:Integer, [0,0,0,9]:Integer, [0,0,0,10]:Integer, [0,0,0,11]:Integer, [0,0,0,12]:Integer, [0,0,0,13]:Integer, [0,0,0,14]:Integer, [0,0,0,15]:Integer, [0,0,0,16]:Integer, [0,0,0,17]:Integer, [0,0,0,18]:Integer, [0,0,0,19]:Integer, [0,0,0,20]:Integer, [0,0,0,21]:Integer, [0,0,0,22]:Integer, [0,0,0,23]:Integer, [0,0,0,24]:Integer, [0,0,0,25]:Integer, [0,0,0,26]:Integer, [0,0,0,27]:Integer, [0,0,0,28]:Integer, [0,0,0,29]:Integer, [0,0,0,30]:Integer, [0,0,0,31]:Integer, [0,0,0,32]:Integer, [0,0,0,33]:Integer, [0,0,0,34]:Integer, [0,0,0,35]:Integer, [0,0,0,36]:Integer, [0,0,0,37]:Integer, [0,0,0,38]:Integer, [0,0,0,39]:Integer, [0,0,16,-1]:Integer, [0,8]:Integer, [8]:Integer, [9]:Integer, [10]:Integer, [11]:Integer, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer}" +end + + include("testsuite.jl") + @testset "JLArray" begin using JLArrays