Skip to content

Commit 045fab2

Browse files
authored
Merge pull request #255 from JuliaGPU/vc/KernelGradients
add KernelGradients
2 parents 545457a + 5a72fb2 commit 045fab2

File tree

15 files changed

+206
-8
lines changed

15 files changed

+206
-8
lines changed

.ci/develop.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,25 @@ root_directory = dirname(@__DIR__)
66

77
kernelabstractions = Pkg.PackageSpec(path = root_directory)
88

9+
BACKEND = get(ENV, "KERNELABSTRACTIONS_TEST_BACKEND", "all")
10+
BUILDKITE = parse(Bool, get(ENV, "BUILDKITE", "false"))
11+
12+
@info "Develop..." BUILDKITE
13+
914
Pkg.develop(kernelabstractions)
1015
if !(VERSION < v"1.6-")
11-
rockernels = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "ROCKernels"))
12-
Pkg.develop(rockernels)
16+
if !BUILDKITE || BACKEND == "ROCM"
17+
rockernels = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "ROCKernels"))
18+
Pkg.develop(rockernels)
19+
end
20+
21+
if !BUILDKITE || BACKEND == "CUDA"
22+
cudakernels = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "CUDAKernels"))
23+
Pkg.develop(cudakernels)
24+
end
1325

14-
cudakernels = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "CUDAKernels"))
15-
Pkg.develop(cudakernels)
26+
kernelgradients = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "KernelGradients"))
27+
Pkg.develop(kernelgradients)
1628
end
1729
Pkg.build()
1830
Pkg.precompile()

.ci/test.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import Pkg
22

3+
BACKEND = get(ENV, "KERNELABSTRACTIONS_TEST_BACKEND", "all")
4+
BUILDKITE = parse(Bool, get(ENV, "BUILDKITE", "false"))
5+
36
pkgs = [
47
"KernelAbstractions",
58
]
69
if !(VERSION < v"1.6-")
7-
push!(pkgs, "ROCKernels")
8-
push!(pkgs, "CUDAKernels")
10+
if !BUILDKITE || BACKEND == "ROCM"
11+
push!(pkgs, "ROCKernels")
12+
end
13+
if !BUILDKITE || BACKEND == "CUDA"
14+
push!(pkgs, "CUDAKernels")
15+
end
16+
push!(pkgs, "KernelGradients")
917
end
1018

1119
Pkg.test(pkgs; coverage = true)

lib/CUDAKernels/test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
24
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
35
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
6+
KernelGradients = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
47
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
58
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
69
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
710
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
811
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

lib/CUDAKernels/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using KernelAbstractions
2+
using KernelGradients
3+
using Enzyme
24
using CUDA
35
using CUDAKernels
46
using Test
57

68
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
9+
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
710

811
if parse(Bool, get(ENV, "CI", "false"))
912
default = "CPU"
@@ -22,6 +25,7 @@ CUDA.versioninfo()
2225
if CUDA.functional(true)
2326
CUDA.allowscalar(false)
2427
Testsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
28+
GradientsTestsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
2529
else
2630
error("No CUDA GPUs available!")
2731
end

lib/KernelGradients/LICENSE.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License (MIT)
2+
3+
Copyright &copy; 2021: Valentin Churavy, and other contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.

lib/KernelGradients/Project.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name = "KernelGradients"
2+
uuid = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
3+
authors = ["Valentin Churavy <v.churavy@gmail.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
8+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
9+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
11+
12+
13+
[compat]
14+
Cassette = "0.3"
15+
KernelAbstractions = "0.7"
16+
Requires = "1.1"
17+
Enzyme = "0.7"
18+
julia = "1.6"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
module KernelGradients
2+
3+
import KernelAbstractions: Kernel, CPUCTX, CPUCtx
4+
import Cassette
5+
import Enzyme
6+
using Requires
7+
8+
@inline function Cassette.overdub(::CPUCtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
9+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CPUCTX, f, args...))
10+
Enzyme.autodiff_deferred(f′, annotation, args...)
11+
end
12+
13+
@inline function Cassette.overdub(::CPUCtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
14+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CPUCTX, f, args...))
15+
Enzyme.autodiff_deferred(f′, args...)
16+
end
17+
18+
function Enzyme.autodiff(kernel::Kernel{<:Any, <:Any, <:Any, Fun}) where Fun
19+
function df(ctx, args...)
20+
Enzyme.autodiff_deferred(kernel.f, Enzyme.Const, ctx, args...)
21+
end
22+
similar(kernel, df)
23+
end
24+
25+
function __init__()
26+
@require CUDAKernels="72cfdca4-0801-4ab0-bf6a-d52aa10adc57" include("cuda.jl")
27+
@require ROCKernels="7eb9e9f0-4bd3-4c4c-8bef-26bd9629d9b9" include("roc.jl")
28+
end
29+
30+
end # module

lib/KernelGradients/src/cuda.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import .CUDAKernels: CUDACtx, CUDACTX
2+
import Cassette
3+
import Enzyme
4+
5+
@inline function Cassette.overdub(::CUDACtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
6+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CUDACTX, f, args...))
7+
Enzyme.autodiff_deferred(f′, annotation, args...)
8+
end
9+
10+
@inline function Cassette.overdub(::CUDACtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
11+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CUDACTX, f, args...))
12+
Enzyme.autodiff_deferred(f′, args...)
13+
end

lib/KernelGradients/src/roc.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import .ROCKernels: ROCCTX, ROCCtx
2+
import Cassette
3+
import Enzyme
4+
5+
@inline function Cassette.overdub(::ROCCtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
6+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(ROCCTX, f, args...))
7+
Enzyme.autodiff_deferred(f′, annotation, args...)
8+
end
9+
10+
@inline function Cassette.overdub(::ROCCtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
11+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(ROCCTX, f, args...))
12+
Enzyme.autodiff_deferred(f′, args...)
13+
end

lib/KernelGradients/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)