Skip to content

Commit 89f49ed

Browse files
committed
add KernelGradients
1 parent 545457a commit 89f49ed

File tree

15 files changed

+184
-2
lines changed

15 files changed

+184
-2
lines changed

.ci/develop.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ if !(VERSION < v"1.6-")
1313

1414
cudakernels = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "CUDAKernels"))
1515
Pkg.develop(cudakernels)
16+
17+
kernelgradients = Pkg.PackageSpec(path = joinpath(root_directory, "lib", "KernelGradients"))
18+
Pkg.develop(kernelgradients)
1619
end
1720
Pkg.build()
1821
Pkg.precompile()

.ci/test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pkgs = [
66
if !(VERSION < v"1.6-")
77
push!(pkgs, "ROCKernels")
88
push!(pkgs, "CUDAKernels")
9+
push!(pkgs, "KernelGradients")
910
end
1011

1112
Pkg.test(pkgs; coverage = true)

lib/CUDAKernels/test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
24
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
35
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
6+
KernelGradients = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
47
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
58
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
69
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
710
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
811
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

lib/CUDAKernels/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using KernelAbstractions
2+
using KernelGradients
3+
using Enzyme
24
using CUDA
35
using CUDAKernels
46
using Test
57

68
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
9+
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
710

811
if parse(Bool, get(ENV, "CI", "false"))
912
default = "CPU"
@@ -22,6 +25,7 @@ CUDA.versioninfo()
2225
if CUDA.functional(true)
2326
CUDA.allowscalar(false)
2427
Testsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
28+
GradientsTestsuite.testsuite(CUDADevice, backend, CUDA, CuArray, CUDA.CuDeviceArray)
2529
else
2630
error("No CUDA GPUs available!")
2731
end

lib/KernelGradients/LICENSE.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License (MIT)
2+
3+
Copyright &copy; 2021: Valentin Churavy, and other contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.

lib/KernelGradients/Project.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name = "KernelGradients"
2+
uuid = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
3+
authors = ["Valentin Churavy <v.churavy@gmail.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
8+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
9+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
11+
12+
13+
[compat]
14+
Cassette = "0.3"
15+
KernelAbstractions = "0.7"
16+
Requires = "1.1"
17+
Enzyme = "0.7"
18+
julia = "1.6"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
module KernelGradients
2+
3+
import KernelAbstractions: Kernel, CPUCTX, CPUCtx
4+
import Cassette
5+
import Enzyme
6+
using Requires
7+
8+
@inline function Cassette.overdub(::CPUCtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
9+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CPUCTX, f, args...))
10+
Enzyme.autodiff_deferred(f′, annotation, args...)
11+
end
12+
13+
@inline function Cassette.overdub(::CPUCtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
14+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CPUCTX, f, args...))
15+
Enzyme.autodiff_deferred(f′, args...)
16+
end
17+
18+
function Enzyme.autodiff(kernel::Kernel{<:Any, <:Any, <:Any, Fun}) where Fun
19+
function df(ctx, args...)
20+
Enzyme.autodiff_deferred(kernel.f, Enzyme.Const, ctx, args...)
21+
end
22+
similar(kernel, df)
23+
end
24+
25+
function __init__()
26+
@require CUDAKernels="72cfdca4-0801-4ab0-bf6a-d52aa10adc57" include("cuda.jl")
27+
@require ROCKernels="7eb9e9f0-4bd3-4c4c-8bef-26bd9629d9b9" include("roc.jl")
28+
end
29+
30+
end # module

lib/KernelGradients/src/cuda.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import .CUDAKernels: CUDACtx, CUDACTX
2+
import Cassette
3+
import Enzyme
4+
5+
@inline function Cassette.overdub(::CUDACtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
6+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CUDACTX, f, args...))
7+
Enzyme.autodiff_deferred(f′, annotation, args...)
8+
end
9+
10+
@inline function Cassette.overdub(::CUDACtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
11+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(CUDACTX, f, args...))
12+
Enzyme.autodiff_deferred(f′, args...)
13+
end

lib/KernelGradients/src/roc.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import .ROCKernels: ROCCTX, ROCCtx
2+
import Cassette
3+
import Enzyme
4+
5+
@inline function Cassette.overdub(::ROCCtx, ::typeof(Enzyme.autodiff_deferred), f, annotation::Enzyme.Annotation, args...)
6+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(ROCCTX, f, args...))
7+
Enzyme.autodiff_deferred(f′, annotation, args...)
8+
end
9+
10+
@inline function Cassette.overdub(::ROCCtx, ::typeof(Enzyme.autodiff_deferred), f, args...)
11+
f′ = (args...) -> (Base.@_inline_meta; Cassette.overdub(ROCCTX, f, args...))
12+
Enzyme.autodiff_deferred(f′, args...)
13+
end

lib/KernelGradients/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)