#=
julia --project=.buildkite
using Revise; include(joinpath("benchmarks", "scripts", "linear_vs_cartesian_indexing.jl"))

# Info:
Linear indexing, when possible, has performance advantages
over using Cartesian indexing. Julia Base's Broadcast only
supports Cartesian indexing as it provides more general support
for "extruded"-style broadcasting, where shapes of input/output
arrays can change.

This script (re-)defines some broadcast machinery and tests
the performance of vector vs array operations in a broadcast
setting where linear indexing is allowed.

# References:
 - https://github.com/CliMA/ClimaCore.jl/issues/1889
 - https://github.com/JuliaLang/julia/issues/28126
 - https://github.com/JuliaLang/julia/issues/32051

# Benchmark results:

Local Apple M1 Mac (CPU):
```
at_dot_call!($X_array, $Y_array):
    146 milliseconds, 558 microseconds
at_dot_call!($X_vector, $Y_vector):
    65 milliseconds, 531 microseconds
custom_kernel_bc!($X_vector, $Y_vector, $(Val(length(X_vector.x1))); printtb = false):
    66 milliseconds, 735 microseconds
custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1))); printtb = false, use_pw = false):
    145 milliseconds, 957 microseconds
custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1))); printtb = false, use_pw = true):
    66 milliseconds, 320 microseconds
```

Clima A100
```
at_dot_call!($X_vector, $Y_vector):
    2 milliseconds, 848 microseconds
custom_kernel_bc!($X_vector, $Y_vector, $(Val(length(X_vector.x1))); printtb = false):
    2 milliseconds, 537 microseconds
custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1))); printtb = false, use_pw = false):
    8 milliseconds, 804 microseconds
custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1))); printtb = false, use_pw = true):
    2 milliseconds, 545 microseconds
```
=#
| 49 | + |
#! format: off
import CUDA
using BenchmarkTools, Dates
using LazyBroadcast: @lazy
# Array backend selector: `CUDA.CuArray` benchmarks on GPU;
# switch to `identity` (commented below) to keep plain CPU `Array`s.
ArrayType = CUDA.CuArray;
# ArrayType = identity;
| 56 | + |
# ============================================================ Non-extruded broadcast (start)
import Base.Broadcast: BroadcastStyle

"""
    PointWiseBC(style, f, args, [axes])

A stripped-down clone of `Base.Broadcast.Broadcasted` whose indexing is
point-wise: arguments are indexed directly (no `newindex`/extrusion
remapping), which makes plain linear indexing legal when all arguments
share the same shape.
"""
struct PointWiseBC{
    Style <: Union{Nothing, BroadcastStyle},
    Axes,
    F,
    Args <: Tuple,
} <: Base.AbstractBroadcasted
    style::Style
    f::F
    args::Args
    axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `PointWiseBC`)

    PointWiseBC(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) =
        error() # disambiguation: tuple is not callable
    function PointWiseBC(
        style::Union{Nothing, BroadcastStyle},
        f::F,
        args::Tuple,
        axes = nothing,
    ) where {F}
        # using Core.Typeof rather than F preserves inferrability when f is a type
        return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(
            style,
            f,
            args,
            axes,
        )
    end
    function PointWiseBC(f::F, args::Tuple, axes = nothing) where {F}
        # Fix: `combine_styles` lives in `Base.Broadcast` and is not
        # imported above (only `BroadcastStyle` is); the unqualified name
        # would throw `UndefVarError` when this constructor is called.
        PointWiseBC(
            Base.Broadcast.combine_styles(args...)::BroadcastStyle,
            f,
            args,
            axes,
        )
    end
    function PointWiseBC{Style}(f::F, args, axes = nothing) where {Style, F}
        return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(
            Style()::Style,
            f,
            args,
            axes,
        )
    end
    function PointWiseBC{Style, Axes, F, Args}(
        f,
        args,
        axes,
    ) where {Style, Axes, F, Args}
        return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
    end
end
| 105 | + |
import Adapt
import CUDA
# GPU compatibility: when a `PointWiseBC` is passed as a kernel argument,
# recursively adapt its function, arguments and axes for the device while
# preserving the broadcast `Style` type parameter.
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    bc::PointWiseBC{Style},
) where {Style}
    PointWiseBC{Style}(
        Adapt.adapt(to, bc.f),
        Adapt.adapt(to, bc.args),
        Adapt.adapt(to, bc.axes),
    )
end
| 118 | + |
# Convert a fused `Broadcasted` expression into its point-wise flavor;
# anything that is not a broadcast expression passes through unchanged.
@inline function to_pointwise_bc(bc::Base.Broadcast.Broadcasted)
    return PointWiseBC(bc.style, bc.f, bc.args, bc.axes)
end
@inline to_pointwise_bc(x) = x
# Constructor-style entry point mirroring `Broadcasted`'s interface.
PointWiseBC(bc::Base.Broadcast.Broadcasted) = to_pointwise_bc(bc)
| 123 | + |
# Apply `to_pointwise_bc` across a (possibly empty) argument tuple,
# forwarding any extra index/axes information. Tuple `map` is inlined and
# type-stable, and covers the empty and singleton cases uniformly.
@inline to_pointwise_bc_args(args::Tuple, inds...) =
    map(arg -> to_pointwise_bc(arg, inds...), args)
| 131 | + |
# Rebuild a `Broadcasted` node with recursively-processed arguments,
# pinning the result to the supplied `axes`; non-broadcast objects are
# returned untouched.
@inline function to_pointwise_bc(bc::Base.Broadcast.Broadcasted, symb, axes)
    converted_args = to_pointwise_bc_args(bc.args, symb, axes)
    return Base.Broadcast.Broadcasted(bc.f, converted_args, axes)
end
@inline to_pointwise_bc(x, symb, axes) = x
| 140 | + |
# `getindex` for `PointWiseBC`: unlike `Broadcasted`, a plain `Integer`
# (linear) index is accepted in addition to a `CartesianIndex` — this is
# the point of the whole exercise. Bounds are checked once here so the
# recursive `_broadcast_getindex` walk can run `@inbounds`.
@inline function Base.getindex(
    bc::PointWiseBC,
    I::Union{Integer, CartesianIndex},
)
    @boundscheck Base.checkbounds(bc, I) # is this really the only issue?
    @inbounds _broadcast_getindex(bc, I)
end
# Point-wise analogues of `Base.Broadcast._broadcast_getindex`. Unlike
# Base's versions, the index is never remapped via `newindex`/extrusion:
# every argument is indexed directly with the same index `I`.
Base.@propagate_inbounds _broadcast_getindex(
    A::Union{Ref, AbstractArray{<:Any, 0}, Number},
    I::Integer,
) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(
    ::Ref{Type{T}},
    I::Integer,
) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I::Integer) = A[1]
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I::Integer) = A[I[1]]
# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)]
Base.@propagate_inbounds _broadcast_getindex(A, I::Integer) = A[I]
# Nested broadcast node: index each argument, then apply `bc.f`.
Base.@propagate_inbounds function _broadcast_getindex(
    bc::PointWiseBC{<:Any, <:Any, <:Any, <:Any},
    I::Integer,
)
    args = _getindex(bc.args, I)
    return _broadcast_getindex_evalf(bc.f, args...)
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
    f(args...) # not propagate_inbounds
# Recursively index every element of the argument tuple.
Base.@propagate_inbounds _getindex(args::Tuple, I) =
    (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
    (_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
| 176 | + |
# Axes of a `PointWiseBC`: use precomputed `bc.axes` when present,
# otherwise combine the arguments' axes (mirroring `Base.Broadcast`).
@inline Base.axes(bc::PointWiseBC) = _axes(bc, bc.axes)
_axes(::PointWiseBC, axes::Tuple) = axes
@inline _axes(bc::PointWiseBC, ::Nothing) =
    Base.Broadcast.combine_axes(bc.args...)
_axes(bc::PointWiseBC{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = ()
# Fix: `OneTo` is not imported into this script (this definition was
# lifted from `Base.Broadcast`, where `OneTo` is in scope); the
# unqualified name would throw `UndefVarError` whenever `d > N`.
@inline Base.axes(bc::PointWiseBC{<:Any, <:NTuple{N}}, d::Integer) where {N} =
    d <= N ? axes(bc)[d] : Base.OneTo(1)
# One-argument tuples imply a one-dimensional object => linear indexing.
Base.IndexStyle(::Type{<:PointWiseBC{<:Any, <:Tuple{Any}}}) = IndexLinear()
| 185 | +# ============================================================ Non-extruded broadcast (end) |
| 186 | + |
# Benchmark-printing helper macros. The CPU branch (`ArrayType === identity`)
# and the GPU branch differ only in wrapping the benchmarked expression in
# `CUDA.@sync`, so that asynchronously-launched GPU work is included in the
# measured time.
# NOTE(review): `esc($expr)` is spliced *inside* the quoted call to the
# BenchmarkTools macros rather than the conventional `$(esc(expr))` form;
# the benchmarked expression does still run (the header's results were
# produced this way), but the hygiene is unusual — confirm against
# BenchmarkTools' interpolation rules before touching.
if ArrayType === identity
    macro pretty_belapsed(expr)
        return quote
            println($(string(expr)), ":")
            print("    ")
            print_time_and_units(BenchmarkTools.@belapsed(esc($expr)))
        end
    end
    macro pretty_elapsed(expr)
        return quote
            println($(string(expr)), ":")
            print("    ")
            print_time_and_units(BenchmarkTools.@elapsed(esc($expr)))
        end
    end
else
    macro pretty_belapsed(expr)
        return quote
            println($(string(expr)), ":")
            print("    ")
            print_time_and_units(
                BenchmarkTools.@belapsed(CUDA.@sync((esc($expr))))
            )
        end
    end
    macro pretty_elapsed(expr)
        return quote
            println($(string(expr)), ":")
            print("    ")
            print_time_and_units(
                BenchmarkTools.@elapsed(CUDA.@sync((esc($expr))))
            )
        end
    end
end
# Print a human-readable elapsed-time string, e.g. "2 milliseconds, 848 microseconds".
print_time_and_units(x) = println(time_and_units_str(x))

# Format a duration `x` (in seconds) as a compound-period string,
# truncated to its two most significant components.
time_and_units_str(x::Real) =
    trunc_time(string(compound_period(x, Dates.Second)))

# Convert `x` units of period type `T` into a canonicalized compound
# period (e.g. 1.5e9 nanoseconds -> "1 second, 500 milliseconds").
function compound_period(x::Real, ::Type{T}) where {T <: Dates.Period}
    ns_per_unit = Dates.value(convert(Dates.Nanosecond, T(1)))
    total_ns = Dates.Nanosecond(ceil(x * ns_per_unit))
    return Dates.canonicalize(Dates.CompoundPeriod(total_ns))
end

# Keep only the first two comma-separated period components.
trunc_time(s::String) = count(',', s) > 1 ? join(split(s, ",")[1:2], ",") : s
# `myadd` deliberately discards its inputs and returns `zero(x1)`: the
# benchmark measures memory traffic, not arithmetic.
myadd(x1, x2, x3) = zero(x1)

# Broadcast `myadd` over the fields of `X` into `Y.y1` using plain
# `@.` (Base Broadcast / Cartesian machinery), repeated 100 times.
function at_dot_call!(X, Y)
    (; x1, x2, x3) = X
    (; y1) = Y
    for _ in 1:100 # reduce variance / impact of launch latency
        @. y1 = myadd(x1, x2, x3) # 3 reads, 1 write
    end
    return nothing
end;
| 241 | + |
# Hand-written CUDA kernel variant: compiles `custom_kernel_knl!` without
# launching, derives a launch configuration from the occupancy API, then
# launches it 100 times. (Currently unused — its call sites at the bottom
# of the script are commented out.)
function custom_kernel!(X, Y, ::Val{N}) where {N}
    (; x1, x2, x3) = X
    (; y1) = Y
    # `launch = false` returns the compiled kernel object so that
    # `launch_configuration` can inspect it before the first launch.
    kernel = CUDA.@cuda always_inline = true launch = false custom_kernel_knl!(
        y1,
        x1,
        x2,
        x3,
        Val(N),
    )
    config = CUDA.launch_configuration(kernel.fun)
    threads = min(N, config.threads)
    blocks = cld(N, threads)
    for i in 1:100 # reduce variance / impact of launch latency
        kernel(y1, x1, x2, x3, Val(N); threads, blocks)
    end
    return nothing
end;
# Device-side kernel: each thread computes one global linear index `I`
# and writes `myadd(x1[I], x2[I], x3[I])` to `y1[I]`; threads past `N`
# (from the rounded-up grid) do nothing.
function custom_kernel_knl!(y1, x1, x2, x3, ::Val{N}) where {N}
    @inbounds begin
        I = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
        if I ≤ N
            y1[I] = myadd(x1[I], x2[I], x3[I])
        end
    end
    return nothing
end;
| 269 | + |
# Benchmark driver for the broadcast-based variants. Builds an
# unmaterialized broadcast object via `@lazy`, optionally converts it to
# the point-wise flavor (`use_pw`), then evaluates it element-by-element
# 100 times — with plain loops on CPU `Array`s, or via
# `custom_kernel_knl_bc!` otherwise (GPU path).
function custom_kernel_bc!(X, Y, ::Val{N}; printtb=true, use_pw=true) where {N}
    (; x1, x2, x3) = X
    (; y1) = Y
    bc_base = @lazy @. y1 = myadd(x1, x2, x3)
    bc = use_pw ? to_pointwise_bc(bc_base) : bc_base
    if y1 isa Array
        if bc isa Base.Broadcast.Broadcasted
            # Cartesian path: iterate the broadcast's own indices.
            for i in 1:100 # reduce variance / impact of launch latency
                @inbounds @simd for j in eachindex(bc)
                    y1[j] = bc[j]
                end
            end
        else
            # Linear path: `PointWiseBC` supports plain 1:N indexing.
            for i in 1:100 # reduce variance / impact of launch latency
                @inbounds @simd for j in 1:N
                    y1[j] = bc[j]
                end
            end
        end
    else
        # GPU path: compile first (`launch = false`) so the occupancy
        # API can choose the launch configuration, then launch repeatedly.
        kernel =
            CUDA.@cuda always_inline = true launch = false custom_kernel_knl_bc!(
                y1,
                bc,
                Val(N),
            )
        config = CUDA.launch_configuration(kernel.fun)
        threads = min(N, config.threads)
        blocks = cld(N, threads)
        printtb && @show blocks, threads
        for i in 1:100 # reduce variance / impact of launch latency
            kernel(y1, bc, Val(N); threads, blocks)
        end
    end
    return nothing
end;
# Index-translation helper: a plain `Broadcasted` only supports Cartesian
# indexing, so its linear index `I` is mapped through `CartesianIndices`
# built from the output size `n`; any other object (e.g. `PointWiseBC`)
# keeps the linear index as-is.
@inline get_cart_lin_index(bc, n, I) = I
@inline function get_cart_lin_index(bc::Base.Broadcast.Broadcasted, n, I)
    carts = CartesianIndices(map(Base.OneTo, n))
    return carts[I]
end
# Device-side kernel for the broadcast variants: each thread handles one
# linear index `I`; `get_cart_lin_index` converts it to a CartesianIndex
# when `bc` is a plain `Broadcasted`, and leaves it linear otherwise.
function custom_kernel_knl_bc!(y1, bc, ::Val{N}) where {N}
    @inbounds begin
        I = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
        n = size(y1)
        if 1 ≤ I ≤ N
            ind = get_cart_lin_index(bc, n, I)
            y1[ind] = bc[ind]
        end
    end
    return nothing
end;
| 320 | + |
FT = Float32;
# NOTE(review): `FT` is never used — `zeros(...)` below defaults to
# Float64, so the benchmark arrays are Float64. If Float32 was intended,
# this should read `zeros(FT, 63,4,4,1,5400)`; confirm before changing,
# since it would alter the measured memory traffic.
arr(T) = T(zeros(63,4,4,1,5400))
# Inputs `X` (3 fields) and output `Y` (1 field) as 5-D arrays.
X_array = (;x1 = arr(ArrayType),x2 = arr(ArrayType),x3 = arr(ArrayType));
Y_array = (;y1 = arr(ArrayType),);
# Flattened (`vec`) counterparts of the same fields, which enable the
# linear-indexing code paths.
to_vec(ξ) = (;zip(propertynames(ξ), map(θ -> vec(θ), values(ξ)))...);
X_vector = to_vec(X_array);
Y_vector = to_vec(Y_array);
# Warm up (compile) each variant once before timing.
at_dot_call!(X_array, Y_array)
at_dot_call!(X_vector, Y_vector)
# custom_kernel!(X_vector, Y_vector, Val(length(X_vector.x1)))
custom_kernel_bc!(X_vector, Y_vector, Val(length(X_vector.x1)))
custom_kernel_bc!(X_array, Y_array, Val(length(X_vector.x1)); use_pw=false)
custom_kernel_bc!(X_array, Y_array, Val(length(X_vector.x1)); use_pw=true)

# Timed runs (results recorded in the header comment).
@pretty_belapsed at_dot_call!($X_array, $Y_array) # slow
@pretty_belapsed at_dot_call!($X_vector, $Y_vector) # fast
# @pretty_belapsed custom_kernel!($X_vector, $Y_vector, $(Val(length(X_vector.x1))))
@pretty_belapsed custom_kernel_bc!($X_vector, $Y_vector, $(Val(length(X_vector.x1)));printtb=false)
@pretty_belapsed custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1)));printtb=false, use_pw=false)
@pretty_belapsed custom_kernel_bc!($X_array, $Y_array, $(Val(length(X_vector.x1)));printtb=false, use_pw=true)
| 341 | + |
#! format: on