Skip to content

Commit c20fe61

Browse files
Introduce post-operation callback
1 parent b099c3e commit c20fe61

33 files changed

+286
-42
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ withenv("GKSwstype" => "nul") do
8686
tutorial in TUTORIALS
8787
],
8888
"Examples" => "examples.md",
89+
"Debugging" => "debugging.md",
8990
"Libraries" => [
9091
joinpath("lib", "ClimaCorePlots.md"),
9192
joinpath("lib", "ClimaCoreMakie.md"),

docs/src/api.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,12 @@ InputOutput.defaultname
401401
Remapping.interpolate_array
402402
Remapping.interpolate
403403
```
404+
405+
## DebugOnly
406+
407+
```@docs
408+
DebugOnly
409+
DebugOnly.call_post_op_callback
410+
DebugOnly.post_op_callback
411+
DebugOnly.example_debug_post_op_callback
412+
```

docs/src/debugging.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Debugging
2+
3+
One of the most challenging tasks that users have is: debug a large simulation
4+
that is breaking, e.g., yielding `NaN`s somewhere. This is especially complex
5+
for large models with many terms and implicit time-stepping with all the bells
6+
and whistles that the CliMA ecosystem offers.
7+
8+
ClimaCore has a module, [`ClimaCore.DebugOnly`](@ref), which contains tools for
9+
debugging simulations for these complicated situations.
10+
11+
Because so much data (for example, the solution state, and many cached fields)
12+
is typically contained in ClimaCore data structures, we offer a hook to inspect
13+
this data after any operation that ClimaCore performs.
14+
15+
## Example
16+
17+
```@example
18+
import ClimaCore
19+
using ClimaCore: DataLayouts
20+
ClimaCore.DebugOnly.call_post_op_callback() = true
21+
function ClimaCore.DebugOnly.post_op_callback(result, args...; kwargs...)
22+
if any(isnan, parent(data))
23+
println("NaNs found!")
24+
end
25+
end
26+
27+
FT = Float64;
28+
data = DataLayouts.VIJFH{FT}(Array{FT}, zeros; Nv=5, Nij=2, Nh=2)
29+
@. data = NaN
30+
```
31+
32+
Note that, due to dispatch, `post_op_callback` will likely need a very general
33+
method signature, and using `post_op_callback
34+
(result::DataLayouts.VIJFH, args...; kwargs...)` above fails (on the CPU),
35+
because `post_op_callback` ends up getting called multiple times with different
36+
datalayouts.
37+
38+
!!! warn
39+
40+
While this debugging tool may be helpful, it's not bullet proof. NaNs can
41+
infiltrate user data any time internals are used. For example `parent
42+
(data) .= NaN` will not be caught by ClimaCore.DebugOnly, and errors can be
43+
observed later than expected.
44+
45+
!!! note
46+
47+
This method is called in many places, so this is a performance-critical code
48+
path and expensive operations performed in `post_op_callback` may
49+
significantly slow down your code.

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import CUDA
99
using CUDA
1010
using CUDA: threadIdx, blockIdx, blockDim
1111
import StaticArrays: SVector, SMatrix, SArray
12+
import ClimaCore.DebugOnly: call_post_op_callback, post_op_callback
1213
import ClimaCore.DataLayouts: mapreduce_cuda
1314
import ClimaCore.DataLayouts: ToCUDA
1415
import ClimaCore.DataLayouts: slab, column

ext/cuda/data_layouts_copyto.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ if VERSION ≥ v"1.11.0-beta"
2424
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
2525
# (including the GPU-variant related issue resolution efforts:
2626
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
27-
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
27+
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
2828
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
2929
us = DataLayouts.UniversalSize(dest)
3030
if Nv > 0 && Nh > 0
@@ -39,10 +39,11 @@ if VERSION ≥ v"1.11.0-beta"
3939
blocks_s = p.blocks,
4040
)
4141
end
42+
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
4243
return dest
4344
end
4445
else
45-
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
46+
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
4647
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
4748
us = DataLayouts.UniversalSize(dest)
4849
if Nv > 0 && Nh > 0
@@ -74,6 +75,7 @@ else
7475
)
7576
end
7677
end
78+
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
7779
return dest
7880
end
7981
end
@@ -85,7 +87,7 @@ end
8587
function Base.copyto!(
8688
dest::AbstractData,
8789
bc::Base.Broadcast.Broadcasted{Style},
88-
::ToCUDA,
90+
to::ToCUDA,
8991
) where {
9092
Style <:
9193
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
@@ -95,13 +97,14 @@ function Base.copyto!(
9597
)
9698
@inbounds bc0 = bc[]
9799
fill!(dest, bc0)
100+
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
98101
end
99102

100103
# For field-vector operations
101104
function DataLayouts.copyto_per_field!(
102105
array::AbstractArray,
103106
bc::Union{AbstractArray, Base.Broadcast.Broadcasted},
104-
::ToCUDA,
107+
to::ToCUDA,
105108
)
106109
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
107110
# All field variables are treated separately, so
@@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!(
119122
threads_s = p.threads,
120123
blocks_s = p.blocks,
121124
)
125+
call_post_op_callback() && post_op_callback(array, array, bc, to)
122126
return array
123127
end
124128
function copyto_per_field_kernel!(array, bc, N)
@@ -133,7 +137,7 @@ end
133137
function DataLayouts.copyto_per_field_scalar!(
134138
array::AbstractArray,
135139
bc::Base.Broadcast.Broadcasted{Style},
136-
::ToCUDA,
140+
to::ToCUDA,
137141
) where {
138142
Style <:
139143
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
@@ -154,12 +158,13 @@ function DataLayouts.copyto_per_field_scalar!(
154158
threads_s = p.threads,
155159
blocks_s = p.blocks,
156160
)
161+
call_post_op_callback() && post_op_callback(array, array, bc, to)
157162
return array
158163
end
159164
function DataLayouts.copyto_per_field_scalar!(
160165
array::AbstractArray,
161166
bc::Real,
162-
::ToCUDA,
167+
to::ToCUDA,
163168
)
164169
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
165170
# All field variables are treated separately, so
@@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!(
177182
threads_s = p.threads,
178183
blocks_s = p.blocks,
179184
)
185+
call_post_op_callback() && post_op_callback(array, array, bc, to)
180186
return array
181187
end
182188
function copyto_per_field_kernel_0D!(array, bc, N)

ext/cuda/data_layouts_fill.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function knl_fill_linear!(dest, val, us)
1414
return nothing
1515
end
1616

17-
function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
17+
function Base.fill!(dest::AbstractData, bc, to::ToCUDA)
1818
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1919
us = DataLayouts.UniversalSize(dest)
2020
args = (dest, bc, us)
@@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
4141
)
4242
end
4343
end
44+
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
4445
return dest
4546
end

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ function mapreduce_cuda(
1515
)
1616
pdata = parent(data)
1717
S = eltype(data)
18-
return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
18+
data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
19+
call_post_op_callback() &&
20+
post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...)
21+
return data_out
1922
end
2023

2124
function mapreduce_cuda(
@@ -101,7 +104,11 @@ function mapreduce_cuda(
101104
Val(shmemsize),
102105
)
103106
end
104-
return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
107+
data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
108+
109+
call_post_op_callback() &&
110+
post_op_callback(data_out, f, op, data; weighted_jacobian, opargs...)
111+
return data_out
105112
end
106113

107114
function mapreduce_cuda_kernel!(

ext/cuda/fields.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,46 +13,52 @@ end
1313

1414
function Base.sum(
1515
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
16-
::ClimaComms.CUDADevice,
16+
dev::ClimaComms.CUDADevice,
1717
)
1818
context = ClimaComms.context(axes(field))
1919
localsum = mapreduce_cuda(identity, +, field, weighting = true)
2020
ClimaComms.allreduce!(context, parent(localsum), +)
21+
call_post_op_callback() && post_op_callback(localsum[], field, dev)
2122
return localsum[]
2223
end
2324

24-
function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice)
25+
function Base.sum(fn, field::Field, dev::ClimaComms.CUDADevice)
2526
context = ClimaComms.context(axes(field))
2627
localsum = mapreduce_cuda(fn, +, field, weighting = true)
2728
ClimaComms.allreduce!(context, parent(localsum), +)
29+
call_post_op_callback() && post_op_callback(localsum[], fn, field, dev)
2830
return localsum[]
2931
end
3032

31-
function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice)
33+
function Base.maximum(fn, field::Field, dev::ClimaComms.CUDADevice)
3234
context = ClimaComms.context(axes(field))
3335
localmax = mapreduce_cuda(fn, max, field)
3436
ClimaComms.allreduce!(context, parent(localmax), max)
37+
call_post_op_callback() && post_op_callback(localmax[], fn, field, dev)
3538
return localmax[]
3639
end
3740

38-
function Base.maximum(field::Field, ::ClimaComms.CUDADevice)
41+
function Base.maximum(field::Field, dev::ClimaComms.CUDADevice)
3942
context = ClimaComms.context(axes(field))
4043
localmax = mapreduce_cuda(identity, max, field)
4144
ClimaComms.allreduce!(context, parent(localmax), max)
45+
call_post_op_callback() && post_op_callback(localmax[], fn, field, dev)
4246
return localmax[]
4347
end
4448

45-
function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice)
49+
function Base.minimum(fn, field::Field, dev::ClimaComms.CUDADevice)
4650
context = ClimaComms.context(axes(field))
4751
localmin = mapreduce_cuda(fn, min, field)
4852
ClimaComms.allreduce!(context, parent(localmin), min)
53+
call_post_op_callback() && post_op_callback(localmin[], fn, field, dev)
4954
return localmin[]
5055
end
5156

5257
function Base.minimum(field::Field, ::ClimaComms.CUDADevice)
5358
context = ClimaComms.context(axes(field))
5459
localmin = mapreduce_cuda(identity, min, field)
5560
ClimaComms.allreduce!(context, parent(localmin), min)
61+
call_post_op_callback() && post_op_callback(localmin[], fn, field, dev)
5662
return localmin[]
5763
end
5864

ext/cuda/limiters.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function compute_element_bounds!(
1919
limiter::QuasiMonotoneLimiter,
2020
ρq,
2121
ρ,
22-
::ClimaComms.CUDADevice,
22+
dev::ClimaComms.CUDADevice,
2323
)
2424
ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
2525
ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
@@ -33,6 +33,8 @@ function compute_element_bounds!(
3333
threads_s = nthreads,
3434
blocks_s = nblocks,
3535
)
36+
call_post_op_callback() &&
37+
post_op_callback(limiter.q_bounds, limiter, ρq, ρ, dev)
3638
return nothing
3739
end
3840

@@ -70,7 +72,7 @@ end
7072
function compute_neighbor_bounds_local!(
7173
limiter::QuasiMonotoneLimiter,
7274
ρ,
73-
::ClimaComms.CUDADevice,
75+
dev::ClimaComms.CUDADevice,
7476
)
7577
topology = Spaces.topology(axes(ρ))
7678
us = DataLayouts.UniversalSize(Fields.field_values(ρ))
@@ -88,6 +90,8 @@ function compute_neighbor_bounds_local!(
8890
threads_s = nthreads,
8991
blocks_s = nblocks,
9092
)
93+
call_post_op_callback() &&
94+
post_op_callback(limiter.q_bounds, limiter, ρ, dev)
9195
end
9296

9397
function compute_neighbor_bounds_local_kernel!(
@@ -123,7 +127,7 @@ function apply_limiter!(
123127
ρq::Fields.Field,
124128
ρ::Fields.Field,
125129
limiter::QuasiMonotoneLimiter,
126-
::ClimaComms.CUDADevice,
130+
dev::ClimaComms.CUDADevice,
127131
)
128132
ρq_data = Fields.field_values(ρq)
129133
us = DataLayouts.UniversalSize(ρq_data)
@@ -147,6 +151,7 @@ function apply_limiter!(
147151
threads_s = nthreads,
148152
blocks_s = nblocks,
149153
)
154+
call_post_op_callback() && post_op_callback(ρq, ρq, ρ, limiter, dev)
150155
return nothing
151156
end
152157

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import ClimaCore.Utilities.UnrolledFunctions: unrolled_map
1111
is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true
1212

1313
NVTX.@annotate function multiple_field_solve!(
14-
::ClimaComms.CUDADevice,
14+
dev::ClimaComms.CUDADevice,
1515
cache,
1616
x,
1717
A,
@@ -48,6 +48,7 @@ NVTX.@annotate function multiple_field_solve!(
4848
blocks_s = p.blocks,
4949
always_inline = true,
5050
)
51+
call_post_op_callback() && post_op_callback(x, dev, cache, x, A, b, x1)
5152
end
5253

5354
Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A

0 commit comments

Comments
 (0)