Skip to content

Test differentiation of Oceananigans broadcast kernel with Enzyme #3598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
4439078
Added test of tracer advection-diffusion
jlk9 Feb 14, 2024
1ce8978
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
bc988a0
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
2e061d8
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
31cc3c2
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
aea14b8
Added elaboration to comment on sum
jlk9 Feb 16, 2024
4331906
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
e308701
Added include statement and changed @test assertion with typo
jlk9 Feb 16, 2024
93b0ff7
Updated compat to Enzyme 0.11.16
jlk9 Feb 16, 2024
a887706
Made some more suggested edits
jlk9 Feb 16, 2024
e45fd40
Removed call to minimum_spacing, keeping the spate and time steps con…
jlk9 Feb 16, 2024
851d1a2
Removed runTimeActivity and compute_sparams (still need LooseTypeAnal…
jlk9 Feb 16, 2024
cd04557
remove Enzyme and Test from core deps
navidcy Feb 19, 2024
498efac
revert changes in Manifest + resolve deps
navidcy Feb 19, 2024
67dc67a
Merge branch 'main' into jlk9/enzyme-test
navidcy Feb 19, 2024
b452207
Added updated @test to verify the AD and FD generated derivatives are…
jlk9 Feb 19, 2024
fc16f93
Trying Vararg argument for differentiability without catastrophic per…
jlk9 Feb 19, 2024
7421aa2
Reverted back to splat operators for the args. This is not intended t…
jlk9 Feb 19, 2024
e13abd9
update KernelAbstractions to v0.9.16
navidcy Feb 20, 2024
63995ae
Replaced tuple splats in function declarations in compute_hydrostatic…
jlk9 Feb 21, 2024
cb07dd4
Reinstated try-catch in set.jl
jlk9 Feb 21, 2024
44ebc7a
Added splats to compuation of momentum tendencies
jlk9 Feb 26, 2024
a3d9f59
Let's try CI again
jlk9 Feb 27, 2024
bda4803
Merge branch 'main' into jlk9/enzyme-test
jlk9 Feb 27, 2024
db8de01
An excuse to test CI again :)
jlk9 Mar 5, 2024
f0df4da
Updated dependencies
jlk9 Mar 5, 2024
60e7b21
Current versions of packages allow this PR to pass Enzyme tests local…
jlk9 Mar 5, 2024
9027b28
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 4, 2024
5619cc9
Fixed inconsistency between active_cells_map and only_active_cells
jlk9 Apr 4, 2024
6ce27a8
Update Enzyme dependency
jlk9 Apr 5, 2024
756102f
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 7, 2024
5f0299d
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 8, 2024
ace664b
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 12, 2024
da43d3f
Adding test to check for Oceananigans' broadcasted arrays (via KA) br…
jlk9 May 10, 2024
3ea07ad
Merge branch 'main' into jlk9/enzyme-test
jlk9 May 10, 2024
0596636
Update test_enzyme.jl
wsmoses May 11, 2024
d4a7a3f
Update test_enzyme.jl
wsmoses May 12, 2024
d46d75e
Update test_enzyme.jl
wsmoses May 13, 2024
4cc9796
Update test_enzyme.jl
wsmoses May 13, 2024
246749d
Update Project.toml
wsmoses May 13, 2024
fb1bd38
Up KernelAbstractions
glwagner May 15, 2024
d5b9351
Merge branch 'main' into jlk9/enzyme-constructor-any-test
glwagner May 16, 2024
112b10f
Update test_enzyme.jl
wsmoses May 16, 2024
2661e49
Merge branch 'main' into jlk9/enzyme-test
jlk9 May 20, 2024
9055b5a
Made a few problematic objects inactive in OceananigansEnzymeExt
jlk9 May 21, 2024
383cc48
Accidentally pushed some bug reductions, undoing those
jlk9 May 21, 2024
4ff3db6
Merge branch 'jlk9/enzyme-test' into jlk9/enzyme-constructor-any-test
jlk9 May 21, 2024
a6fcd38
Merge branch 'main' into jlk9/enzyme-constructor-any-test
glwagner May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.2"
julia_version = "1.10.0-rc1"
manifest_format = "2.0"
project_hash = "04d395caf937b0921325a77873167e8baa293a99"
project_hash = "aeaf0f67a467f08d2890f46e1746ccd3879587cb"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -164,7 +164,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"
version = "1.0.5+1"

[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -419,9 +419,9 @@ version = "0.2.1+0"

[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49"
git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.18"
version = "0.9.19"

[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
Expand Down Expand Up @@ -648,7 +648,7 @@ weakdeps = ["Adapt"]
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"
version = "0.3.23+2"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -677,7 +677,6 @@ version = "0.5.5+0"
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3"

[[deps.P11Kit_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "2cd396108e178f3ae8dedbd8e938a18726ab2fbf"
Expand Down Expand Up @@ -968,7 +967,7 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"

Expand Down Expand Up @@ -1086,7 +1085,6 @@ version = "1.1.2+0"
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"

[[deps.libzip_jll]]
deps = ["Artifacts", "Bzip2_jll", "GnuTLS_jll", "JLLWrappers", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "3282b7d16ae7ac3e57ec2f3fa8fafb564d8f9f7f"
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
Expand Down Expand Up @@ -49,14 +50,14 @@ CubedSphere = "0.1, 0.2"
Dates = "1.9"
Distances = "0.10"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.11.14"
Enzyme = "0.12.5"
FFTW = "1"
Glob = "1.3"
IncompleteLU = "0.2"
InteractiveUtils = "1.9"
IterativeSolvers = "0.9"
JLD2 = "0.4"
KernelAbstractions = "0.9"
KernelAbstractions = "0.9.19"
LinearAlgebra = "1.9"
Logging = "1.9"
MPI = "0.16, 0.17, 0.18, 0.19, 0.20"
Expand Down
1 change: 1 addition & 0 deletions ext/OceananigansEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using Enzyme.EnzymeCore: Active, Const, Duplicated

EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.Utils.flatten_reduced_dimensions), x...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(Oceananigans.Grids.total_size), x...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(Oceananigans.BoundaryConditions.parent_size_and_offset), x...) = nothing
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{Oceananigans.Utils.KernelParameters}) = true

@inline batch(::Val{1}, ::Type{T}) where T = T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function compute_hydrostatic_free_surface_tendency_contributions!(model, kernel_
c_tendency,
grid,
active_cells_map,
args;
args...;
active_cells_map)
end
end
Expand All @@ -134,7 +134,7 @@ end
##### Boundary condributions to hydrostatic free surface model
#####

function apply_flux_bcs!(Gcⁿ, c, arch, args)
function apply_flux_bcs!(Gcⁿ, c, arch, args::Vararg{Any, N}) where {N}
apply_x_bcs!(Gcⁿ, c, arch, args...)
apply_y_bcs!(Gcⁿ, c, arch, args...)
apply_z_bcs!(Gcⁿ, c, arch, args...)
Expand All @@ -154,7 +154,7 @@ function compute_free_surface_tendency!(grid, model, kernel_parameters)

launch!(arch, grid, kernel_parameters,
compute_hydrostatic_free_surface_Gη!, model.timestepper.Gⁿ.η,
grid, args)
grid, args...)

return nothing
end
Expand Down Expand Up @@ -209,15 +209,15 @@ function compute_hydrostatic_boundary_tendency_contributions!(Gⁿ, arch, veloci

# Velocity fields
for i in (:u, :v)
apply_flux_bcs!(Gⁿ[i], velocities[i], arch, args)
apply_flux_bcs!(Gⁿ[i], velocities[i], arch, args...)
end

# Free surface
apply_flux_bcs!(Gⁿ.η, displacement(free_surface), arch, args)
apply_flux_bcs!(Gⁿ.η, displacement(free_surface), arch, args...)

# Tracer fields
for i in propertynames(tracers)
apply_flux_bcs!(Gⁿ[i], tracers[i], arch, args)
apply_flux_bcs!(Gⁿ[i], tracers[i], arch, args...)
end

return nothing
Expand All @@ -228,24 +228,24 @@ end
#####

""" Calculate the right-hand-side of the u-velocity equation. """
@kernel function compute_hydrostatic_free_surface_Gu!(Gu, grid, map, args)
@kernel function compute_hydrostatic_free_surface_Gu!(Gu, grid, map, args::Vararg{Any, N}) where {N}
i, j, k = @index(Global, NTuple)
@inbounds Gu[i, j, k] = hydrostatic_free_surface_u_velocity_tendency(i, j, k, grid, args...)
end

@kernel function compute_hydrostatic_free_surface_Gu!(Gu, grid::ActiveCellsIBG, map, args)
@kernel function compute_hydrostatic_free_surface_Gu!(Gu, grid::ActiveCellsIBG, map, args::Vararg{Any, N}) where {N}
idx = @index(Global, Linear)
i, j, k = active_linear_index_to_tuple(idx, map, grid)
@inbounds Gu[i, j, k] = hydrostatic_free_surface_u_velocity_tendency(i, j, k, grid, args...)
end

""" Calculate the right-hand-side of the v-velocity equation. """
@kernel function compute_hydrostatic_free_surface_Gv!(Gv, grid, map, args)
@kernel function compute_hydrostatic_free_surface_Gv!(Gv, grid, map, args::Vararg{Any, N}) where {N}
i, j, k = @index(Global, NTuple)
@inbounds Gv[i, j, k] = hydrostatic_free_surface_v_velocity_tendency(i, j, k, grid, args...)
end

@kernel function compute_hydrostatic_free_surface_Gv!(Gv, grid::ActiveCellsIBG, map, args)
@kernel function compute_hydrostatic_free_surface_Gv!(Gv, grid::ActiveCellsIBG, map, args::Vararg{Any, N}) where {N}
idx = @index(Global, Linear)
i, j, k = active_linear_index_to_tuple(idx, map, grid)
@inbounds Gv[i, j, k] = hydrostatic_free_surface_v_velocity_tendency(i, j, k, grid, args...)
Expand All @@ -256,24 +256,24 @@ end
#####

""" Calculate the right-hand-side of the tracer advection-diffusion equation. """
@kernel function compute_hydrostatic_free_surface_Gc!(Gc, grid, map, args)
@kernel function compute_hydrostatic_free_surface_Gc!(Gc, grid, map, args::Vararg{Any, N}) where {N}
i, j, k = @index(Global, NTuple)
@inbounds Gc[i, j, k] = hydrostatic_free_surface_tracer_tendency(i, j, k, grid, args...)
end

@kernel function compute_hydrostatic_free_surface_Gc!(Gc, grid::ActiveCellsIBG, map, args)
@kernel function compute_hydrostatic_free_surface_Gc!(Gc, grid::ActiveCellsIBG, map, args::Vararg{Any, N}) where {N}
idx = @index(Global, Linear)
i, j, k = active_linear_index_to_tuple(idx, map, grid)
@inbounds Gc[i, j, k] = hydrostatic_free_surface_tracer_tendency(i, j, k, grid, args...)
end

""" Calculate the right-hand-side of the subgrid scale energy equation. """
@kernel function compute_hydrostatic_free_surface_Ge!(Ge, grid, map, args)
@kernel function compute_hydrostatic_free_surface_Ge!(Ge, grid, map, args::Vararg{Any, N}) where {N}
i, j, k = @index(Global, NTuple)
@inbounds Ge[i, j, k] = hydrostatic_turbulent_kinetic_energy_tendency(i, j, k, grid, args...)
end

@kernel function compute_hydrostatic_free_surface_Ge!(Ge, grid::ActiveCellsIBG, map, args)
@kernel function compute_hydrostatic_free_surface_Ge!(Ge, grid::ActiveCellsIBG, map, args::Vararg{Any, N}) where {N}
idx = @index(Global, Linear)
i, j, k = active_linear_index_to_tuple(idx, map, grid)
@inbounds Ge[i, j, k] = hydrostatic_turbulent_kinetic_energy_tendency(i, j, k, grid, args...)
Expand All @@ -284,7 +284,7 @@ end
#####

""" Calculate the right-hand-side of the free surface displacement (``η``) equation. """
@kernel function compute_hydrostatic_free_surface_Gη!(Gη, grid, args)
@kernel function compute_hydrostatic_free_surface_Gη!(Gη, grid, args::Vararg{Any, N}) where {N}
i, j = @index(Global, NTuple)
@inbounds Gη[i, j, grid.Nz+1] = free_surface_tendency(i, j, grid, args...)
end
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ model.velocities.u
```
"""
@inline function set!(model::HydrostaticFreeSurfaceModel; kwargs...)

for (fldname, value) in kwargs
if fldname ∈ propertynames(model.velocities)
ϕ = getproperty(model.velocities, fldname)
Expand Down
4 changes: 2 additions & 2 deletions src/Utils/kernel_launching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ end
Launches `kernel!`, with arguments `args` and keyword arguments `kwargs`,
over the `dims` of `grid` on the architecture `arch`. kernels run on the default stream
"""
function launch!(arch, grid, workspec, kernel!, kernel_args...;
@inline function launch!(arch, grid, workspec, kernel!, kernel_args::Vararg{Any, N};
include_right_boundaries = false,
reduced_dimensions = (),
location = nothing,
active_cells_map = nothing,
kwargs...)
kwargs...) where {N}

loop! = configured_kernel(arch, grid, workspec, kernel!;
include_right_boundaries,
Expand Down
Loading