-
Notifications
You must be signed in to change notification settings - Fork 13
fixed scalarIndexing performance issue #1215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thanks @riteshbhirud! Does this reduce allocations, though? I believe this line still allocates: Also, confusingly, the |
src/shared_utilities/utils.jl
Outdated
if isnothing(mask) | ||
num_nans = sum(@. ifelse(isnan(state), 1, 0)) | ||
else | ||
num_nans = sum(@. ifelse(isnan(state), mask, 0)) | ||
end | ||
num_nans = Int(num_nans) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can further optimize this and reduce some of the allocations. For the no mask case, we can go from 2 kernel launches to just one if we use mapreduce
or count
. I believe @kmdeck is correct, and line 597 will still allocate an intermediate. For the case with a mask, we can also get away with only one kernel launch and no allocations with mapreduce
.
That could look something like
if isnothing(mask) | |
num_nans = sum(@. ifelse(isnan(state), 1, 0)) | |
else | |
num_nans = sum(@. ifelse(isnan(state), mask, 0)) | |
end | |
num_nans = Int(num_nans) | |
if isnothing(mask) | |
num_nans = count(isnan, parent(state) | |
else | |
num_nans = mapreduce((s, m) -> m != 0 && isnan(s), Base.add_sum, parent(state), parent(mask)) | |
end |
That would also avoid the type conversion.
I'm not sure if masks are guaranteed to be boolean valued, so I'm not sure if line 599 will always be valid
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
The mask is not boolean, because we create it using this:
ClimaLand.jl/src/shared_utilities/Domains.jl
Line 1047 in 7ac4cd4
apply_threshold(field, value) = |
I think we did this because you cant create a field of bools by broadcasting over a field of floats. But we can do something like:
apply_threshold(field, value) =
field > value ? 0 : 1
and then later down in the landsea_mask function, change thishttps://github.com/CliMA/ClimaLand.jl/blob/7ac4cd452663559f20df22d2a60e8f1aaa92d90f/src/shared_utilities/Domains.jl#L1088
to
binary_mask = ClimaCore.Fields.Field(Bool, surface_space)
fill!(ClimaCore.Fields.field_values(binary_mask), 0)
@. binary_mask = apply_threshold(mask, threshold)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One other thought is that we call this function for each state variable in Y, but with one exception, if one is NaN, the entire state is NaN at that same location. so maybe we can also speed it up by calling it for just one field.
I'm not sure how to do that in a nice automated way though, without making an assumption about what is in Y
apologies for the allocations, I have now replaced the multi-step array creation and summation with single-kernel count/mapreduce operations to eliminate allocations. Can you please take a look once you have chance if this is what we expected? |
|
@riteshbhirud I am going to try running a simulation where we use the NaN check callback and see if it runs and if SYPD increases :) will report back! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, this is great!
Looking at just the callback by itself, this change makes the first call slightly slower, but then the subsequent calls are about 5x faster. @allocations
also reports no allocations now!
I'll run a benchmark comparing this to main.
Could you please run the .dev/climaformat.jl
to format these changes? It looks like the formatting test is failing.
@imreddyTeja I ran the .dev/climaformat.jl. Can you please run the CI if all good to merge? Thanks! |
fixes #1209
Fixes performance regression in the NaN counting callback.
Replaced inefficient scalar operations with proper ClimaCore field operations:
allowscalar
block entirely@. ifelse(isnan(state), ...)
instead ofparent(state)
Results: