Skip to content

Commit 13fe60c

Browse files
Remove julia 1.11 conditionals
1 parent adda6bf commit 13fe60c

File tree

7 files changed

+60
-175
lines changed

7 files changed

+60
-175
lines changed

ext/cuda/data_layouts_copyto.jl

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,26 @@ function knl_copyto_linear!(dest, src, us)
1616
return nothing
1717
end
1818

19-
if VERSION v"1.11.0-beta"
20-
# https://github.com/JuliaLang/julia/issues/56295
21-
# Julia 1.11's Base.Broadcast currently requires
22-
# multiple integer indexing, wheras Julia 1.10 did not.
23-
# This means that we cannot reserve linear indexing to
24-
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
25-
# (including the GPU-variant related issue resolution efforts:
26-
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
27-
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
28-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
29-
us = DataLayouts.UniversalSize(dest)
30-
if Nv > 0 && Nh > 0
19+
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
20+
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
21+
us = DataLayouts.UniversalSize(dest)
22+
if Nv > 0 && Nh > 0
23+
if DataLayouts.has_uniform_datalayouts(bc) &&
24+
dest isa DataLayouts.EndsWithField
25+
bc′ = Base.Broadcast.instantiate(
26+
DataLayouts.to_non_extruded_broadcasted(bc),
27+
)
28+
args = (dest, bc′, us)
29+
threads = threads_via_occupancy(knl_copyto_linear!, args)
30+
n_max_threads = min(threads, get_N(us))
31+
p = linear_partition(prod(size(dest)), n_max_threads)
32+
auto_launch!(
33+
knl_copyto_linear!,
34+
args;
35+
threads_s = p.threads,
36+
blocks_s = p.blocks,
37+
)
38+
else
3139
args = (dest, bc, us)
3240
threads = threads_via_occupancy(knl_copyto!, args)
3341
n_max_threads = min(threads, get_N(us))
@@ -39,45 +47,9 @@ if VERSION ≥ v"1.11.0-beta"
3947
blocks_s = p.blocks,
4048
)
4149
end
42-
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
43-
return dest
44-
end
45-
else
46-
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
47-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
48-
us = DataLayouts.UniversalSize(dest)
49-
if Nv > 0 && Nh > 0
50-
if DataLayouts.has_uniform_datalayouts(bc) &&
51-
dest isa DataLayouts.EndsWithField
52-
bc′ = Base.Broadcast.instantiate(
53-
DataLayouts.to_non_extruded_broadcasted(bc),
54-
)
55-
args = (dest, bc′, us)
56-
threads = threads_via_occupancy(knl_copyto_linear!, args)
57-
n_max_threads = min(threads, get_N(us))
58-
p = linear_partition(prod(size(dest)), n_max_threads)
59-
auto_launch!(
60-
knl_copyto_linear!,
61-
args;
62-
threads_s = p.threads,
63-
blocks_s = p.blocks,
64-
)
65-
else
66-
args = (dest, bc, us)
67-
threads = threads_via_occupancy(knl_copyto!, args)
68-
n_max_threads = min(threads, get_N(us))
69-
p = partition(dest, n_max_threads)
70-
auto_launch!(
71-
knl_copyto!,
72-
args;
73-
threads_s = p.threads,
74-
blocks_s = p.blocks,
75-
)
76-
end
77-
end
78-
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
79-
return dest
8050
end
51+
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
52+
return dest
8153
end
8254

8355
# broadcasting scalar assignment

ext/cuda/data_layouts_fill.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function Base.fill!(dest::AbstractData, bc, to::ToCUDA)
1919
us = DataLayouts.UniversalSize(dest)
2020
args = (dest, bc, us)
2121
if Nv > 0 && Nh > 0
22-
if !(VERSION v"1.11.0-beta") && dest isa DataLayouts.EndsWithField
22+
if dest isa DataLayouts.EndsWithField
2323
threads = threads_via_occupancy(knl_fill_linear!, args)
2424
n_max_threads = min(threads, get_N(us))
2525
p = linear_partition(prod(size(dest)), n_max_threads)

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast)
120120
destinations = map(p -> p.first, fmb.pairs)
121121
bcs = map(p -> p.second, fmb.pairs)
122122
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
123-
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
124-
!(VERSION v"1.11.0-beta")
123+
all(d -> d isa DataLayouts.EndsWithField, destinations)
125124
pairs′ = map(fmb.pairs) do p
126125
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
127126
Pair(p.first, Base.Broadcast.instantiate(bc′))

src/DataLayouts/DataLayouts.jl

Lines changed: 17 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,88 +2052,23 @@ end
20522052
const EndsWithField{S} =
20532053
Union{IJHF{S}, IHF{S}, IJF{S}, IF{S}, VF{S}, VIJHF{S}, VIHF{S}}
20542054

2055-
if VERSION v"1.11.0-beta"
2056-
### --------------- Support for multi-dimensional indexing
2057-
# TODO: can we remove this? It's not needed for Julia 1.10,
2058-
# but seems needed in Julia 1.11.
2059-
@inline Base.getindex(
2060-
data::Union{
2061-
IJF,
2062-
IJFH,
2063-
IJHF,
2064-
IFH,
2065-
IHF,
2066-
VIJFH,
2067-
VIJHF,
2068-
VIFH,
2069-
VIHF,
2070-
VF,
2071-
IF,
2072-
},
2073-
I::Vararg{Int, N},
2074-
) where {N} = Base.getindex(
2075-
data,
2076-
CartesianIndex(to_universal_index(singleton(data), I)),
2077-
)
2078-
2079-
@inline Base.setindex!(
2080-
data::Union{
2081-
IJF,
2082-
IJFH,
2083-
IJHF,
2084-
IFH,
2085-
IHF,
2086-
VIJFH,
2087-
VIJHF,
2088-
VIFH,
2089-
VIHF,
2090-
VF,
2091-
IF,
2092-
},
2093-
val,
2094-
I::Vararg{Int, N},
2095-
) where {N} = Base.setindex!(
2096-
data,
2097-
val,
2098-
CartesianIndex(to_universal_index(singleton(data), I)),
2099-
)
2100-
2101-
# Certain datalayouts support special indexing.
2102-
# Like VF datalayouts with `getindex(::VF, v::Integer)`
2103-
#! format: off
2104-
@inline to_universal_index(::VFSingleton, I::NTuple{1, T}) where {T} = (T(1), T(1), T(1), I[1], T(1))
2105-
@inline to_universal_index(::IFSingleton, I::NTuple{1, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2106-
@inline to_universal_index(::IFSingleton, I::NTuple{2, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2107-
@inline to_universal_index(::IFSingleton, I::NTuple{3, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2108-
@inline to_universal_index(::IFSingleton, I::NTuple{4, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2109-
@inline to_universal_index(::IFSingleton, I::NTuple{5, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2110-
@inline to_universal_index(::IJFSingleton, I::NTuple{2, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2111-
@inline to_universal_index(::IJFSingleton, I::NTuple{3, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2112-
@inline to_universal_index(::IJFSingleton, I::NTuple{4, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2113-
@inline to_universal_index(::IJFSingleton, I::NTuple{5, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2114-
@inline to_universal_index(::AbstractDataSingleton, I::NTuple{5}) = I
2115-
#! format: on
2116-
### ---------------
2117-
else
2118-
# Only support datalayouts that end with fields, since those
2119-
# are the only layouts where we can efficiently compute the
2120-
# strides.
2121-
@propagate_inbounds function Base.getindex(
2122-
data::EndsWithField{S},
2123-
I::Integer,
2124-
) where {S}
2125-
s_array = farray_size(data)
2126-
@inbounds get_struct_linear(parent(data), S, I, s_array)
2127-
end
2128-
@propagate_inbounds function Base.setindex!(
2129-
data::EndsWithField{S},
2130-
val,
2131-
I::Integer,
2132-
) where {S}
2133-
s_array = farray_size(data)
2134-
@inbounds set_struct_linear!(parent(data), convert(S, val), I, s_array)
2135-
end
2136-
2055+
# Only support datalayouts that end with fields, since those
2056+
# are the only layouts where we can efficiently compute the
2057+
# strides.
2058+
@propagate_inbounds function Base.getindex(
2059+
data::EndsWithField{S},
2060+
I::Integer,
2061+
) where {S}
2062+
s_array = farray_size(data)
2063+
@inbounds get_struct_linear(parent(data), S, I, s_array)
2064+
end
2065+
@propagate_inbounds function Base.setindex!(
2066+
data::EndsWithField{S},
2067+
val,
2068+
I::Integer,
2069+
) where {S}
2070+
s_array = farray_size(data)
2071+
@inbounds set_struct_linear!(parent(data), convert(S, val), I, s_array)
21372072
end
21382073

21392074
"""

src/DataLayouts/copyto.jl

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,25 @@
11
#####
22
##### Dispatching and edge cases
33
#####
4-
if VERSION v"1.11.0-beta"
5-
# https://github.com/JuliaLang/julia/issues/56295
6-
# Julia 1.11's Base.Broadcast currently requires
7-
# multiple integer indexing, wheras Julia 1.10 did not.
8-
# This means that we cannot reserve linear indexing to
9-
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
10-
# (including the GPU-variant related issue resolution efforts:
11-
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
12-
function Base.copyto!(
13-
dest::AbstractData{S},
14-
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
15-
) where {S}
16-
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
17-
call_post_op_callback() && post_op_callback(dest, dest, bc)
18-
dest
19-
end
20-
else
21-
function Base.copyto!(
22-
dest::AbstractData{S},
23-
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
24-
) where {S}
25-
dev = device_dispatch(parent(dest))
26-
if dev isa ToCPU &&
27-
has_uniform_datalayouts(bc) &&
28-
dest isa EndsWithField &&
29-
!(dest isa DataF)
30-
# Specialize on linear indexing when possible:
31-
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
32-
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
33-
dest[I] = convert(S, bc′[I])
34-
end
35-
else
36-
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
4+
function Base.copyto!(
5+
dest::AbstractData{S},
6+
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
7+
) where {S}
8+
dev = device_dispatch(parent(dest))
9+
if dev isa ToCPU &&
10+
has_uniform_datalayouts(bc) &&
11+
dest isa EndsWithField &&
12+
!(dest isa DataF)
13+
# Specialize on linear indexing when possible:
14+
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
15+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
16+
dest[I] = convert(S, bc′[I])
3717
end
38-
call_post_op_callback() && post_op_callback(dest, dest, bc)
39-
return dest
18+
else
19+
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
4020
end
21+
call_post_op_callback() && post_op_callback(dest, dest, bc)
22+
return dest
4123
end
4224

4325
# Specialize on non-Broadcasted objects

src/DataLayouts/fill.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
function Base.fill!(dest::AbstractData, val)
22
dev = device_dispatch(parent(dest))
3-
if !(VERSION v"1.11.0-beta") &&
4-
dest isa EndsWithField &&
5-
dev isa ClimaComms.AbstractCPUDevice
3+
if dest isa EndsWithField && dev isa ClimaComms.AbstractCPUDevice
64
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
75
dest[I] = val
86
end

src/DataLayouts/fused_copyto.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ function Base.copyto!(
5252
dev = device_dispatch(parent(dest1))
5353
if dev isa ClimaComms.AbstractCPUDevice &&
5454
all(bc -> has_uniform_datalayouts(bc), bcs) &&
55-
all(d -> d isa EndsWithField, destinations) &&
56-
!(VERSION v"1.11.0-beta")
55+
all(d -> d isa EndsWithField, destinations)
5756
pairs′ = map(fmb_inst.pairs) do p
5857
bc′ = to_non_extruded_broadcasted(p.second)
5958
Pair(p.first, bc′)

0 commit comments

Comments
 (0)