Skip to content

Test differentiation of Oceananigans broadcast kernel with Enzyme #3598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 48 commits into from
Closed
Changes from 5 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
4439078
Added test of tracer advection-diffusion
jlk9 Feb 14, 2024
1ce8978
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
bc988a0
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
2e061d8
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
31cc3c2
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
aea14b8
Added elaboration to comment on sum
jlk9 Feb 16, 2024
4331906
Update test/test_enzyme.jl
jlk9 Feb 16, 2024
e308701
Added include statement and changed @test assertion with typo
jlk9 Feb 16, 2024
93b0ff7
Updated compat to Enzyme 0.11.16
jlk9 Feb 16, 2024
a887706
Made some more suggested edits
jlk9 Feb 16, 2024
e45fd40
Removed call to minimum_spacing, keeping the spate and time steps con…
jlk9 Feb 16, 2024
851d1a2
Removed runTimeActivity and compute_sparams (still need LooseTypeAnal…
jlk9 Feb 16, 2024
cd04557
remove Enzyme and Test from core deps
navidcy Feb 19, 2024
498efac
revert changes in Manifest + resolve deps
navidcy Feb 19, 2024
67dc67a
Merge branch 'main' into jlk9/enzyme-test
navidcy Feb 19, 2024
b452207
Added updated @test to verify the AD and FD generated derivatives are…
jlk9 Feb 19, 2024
fc16f93
Trying Vararg argument for differentiability without catastrophic per…
jlk9 Feb 19, 2024
7421aa2
Reverted back to splat operators for the args. This is not intended t…
jlk9 Feb 19, 2024
e13abd9
update KernelAbstractions to v0.9.16
navidcy Feb 20, 2024
63995ae
Replaced tuple splats in function declarations in compute_hydrostatic…
jlk9 Feb 21, 2024
cb07dd4
Reinstated try-catch in set.jl
jlk9 Feb 21, 2024
44ebc7a
Added splats to compuation of momentum tendencies
jlk9 Feb 26, 2024
a3d9f59
Let's try CI again
jlk9 Feb 27, 2024
bda4803
Merge branch 'main' into jlk9/enzyme-test
jlk9 Feb 27, 2024
db8de01
An excuse to test CI again :)
jlk9 Mar 5, 2024
f0df4da
Updated dependencies
jlk9 Mar 5, 2024
60e7b21
Current versions of packages allow this PR to pass Enzyme tests local…
jlk9 Mar 5, 2024
9027b28
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 4, 2024
5619cc9
Fixed inconsistency between active_cells_map and only_active_cells
jlk9 Apr 4, 2024
6ce27a8
Update Enzyme dependency
jlk9 Apr 5, 2024
756102f
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 7, 2024
5f0299d
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 8, 2024
ace664b
Merge branch 'main' into jlk9/enzyme-test
jlk9 Apr 12, 2024
da43d3f
Adding test to check for Oceananigans' broadcasted arrays (via KA) br…
jlk9 May 10, 2024
3ea07ad
Merge branch 'main' into jlk9/enzyme-test
jlk9 May 10, 2024
0596636
Update test_enzyme.jl
wsmoses May 11, 2024
d4a7a3f
Update test_enzyme.jl
wsmoses May 12, 2024
d46d75e
Update test_enzyme.jl
wsmoses May 13, 2024
4cc9796
Update test_enzyme.jl
wsmoses May 13, 2024
246749d
Update Project.toml
wsmoses May 13, 2024
fb1bd38
Up KernelAbstractions
glwagner May 15, 2024
d5b9351
Merge branch 'main' into jlk9/enzyme-constructor-any-test
glwagner May 16, 2024
112b10f
Update test_enzyme.jl
wsmoses May 16, 2024
2661e49
Merge branch 'main' into jlk9/enzyme-test
jlk9 May 20, 2024
9055b5a
Made a few problematic objects inactive in OceananigansEnzymeExt
jlk9 May 21, 2024
383cc48
Accidentally pushed some bug reductions, undoing those
jlk9 May 21, 2024
4ff3db6
Merge branch 'jlk9/enzyme-test' into jlk9/enzyme-constructor-any-test
jlk9 May 21, 2024
a6fcd38
Merge branch 'main' into jlk9/enzyme-constructor-any-test
glwagner May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/test_enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
using Oceananigans
using Enzyme
using Oceananigans.Fields: FunctionField
using Oceananigans: architecture
using KernelAbstractions

# Required presently
Enzyme.API.runtimeActivity!(true)

EnzymeRules.inactive_type(::Type{<:Oceananigans.Grids.AbstractGrid}) = true
EnzymeRules.inactive_type(::Type{<:Oceananigans.Clock}) = true

f(grid) = CenterField(grid)

Expand All @@ -23,3 +27,91 @@ f(grid) = CenterField(grid)

@test size(primal) == size(shadow)
end

function set_initial_condition_via_launch!(model_tracer, amplitude)
# Set initial condition
amplitude = Ref(amplitude)

# This has a "width" of 0.1
cᵢ(x, y, z) = amplitude[]
temp = Base.broadcasted(Base.identity, FunctionField((Center, Center, Center), cᵢ, model_tracer.grid))

temp = convert(Base.Broadcast.Broadcasted{Nothing}, temp)
grid = model_tracer.grid
arch = architecture(model_tracer)

param = Oceananigans.Utils.KernelParameters(size(model_tracer), map(Oceananigans.Fields.offset_index, model_tracer.indices))
Oceananigans.Utils.launch!(arch, grid, param, Oceananigans.Fields._broadcast_kernel!, model_tracer, temp)

return nothing
end

function set_initial_condition!(model, amplitude)
amplitude = Ref(amplitude)

# This has a "width" of 0.1
cᵢ(x, y, z) = amplitude[] * exp(-z^2 / 0.02 - (x^2 + y^2) / 0.05)
set!(model, c=cᵢ)

return nothing
end

@testset "Enzyme + Oceananigans Initialization Broadcast Kernel" begin

Enzyme.API.looseTypeAnalysis!(true)

Nx = Ny = 64
Nz = 8

x = y = (-π, π)
z = (-0.5, 0.5)
topology = (Periodic, Periodic, Bounded)

grid = RectilinearGrid(size=(Nx, Ny, Nz); x, y, z, topology)

model = HydrostaticFreeSurfaceModel(; grid,
tracers = :c,
buoyancy = nothing)

model_tracer = model.tracers.c

amplitude = 1.0
amplitude = Ref(amplitude)

# This has a "width" of 0.1
cᵢ(x, y, z) = amplitude[]
temp = Base.broadcasted(Base.identity, FunctionField((Center, Center, Center), cᵢ, model_tracer.grid))

temp = convert(Base.Broadcast.Broadcasted{Nothing}, temp)
grid = model_tracer.grid
arch = architecture(model_tracer)

param = Oceananigans.Utils.KernelParameters(size(model_tracer), map(Oceananigans.Fields.offset_index, model_tracer.indices))

dmodel_tracer = Enzyme.make_zero(model_tracer)

# Test the individual kernel launch
autodiff(Enzyme.Reverse,
Oceananigans.Utils.launch!,
Const(arch),
Const(grid),
Const(param),
Const(Oceananigans.Fields._broadcast_kernel!),
Duplicated(model_tracer, dmodel_tracer),
Const(temp))

# Test out differentiation of the broadcast infrastructure
autodiff(Enzyme.Reverse,
set_initial_condition_via_launch!,
Duplicated(model_tracer, dmodel_tracer),
Active(1.0))

# Test differentiation of the high-level set interface
dmodel = Enzyme.make_zero(model)
autodiff(Enzyme.Reverse,
set_initial_condition!,
Duplicated(model, dmodel),
Active(1.0))

Enzyme.API.looseTypeAnalysis!(false)
end