Skip to content

Commit 2896140

Browse files
Remove julia 1.11 conditionals
1 parent 431d459 commit 2896140

File tree

7 files changed

+58
-170
lines changed

7 files changed

+58
-170
lines changed

ext/cuda/data_layouts_copyto.jl

Lines changed: 21 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -16,18 +16,26 @@ function knl_copyto_linear!(dest, src, us)
1616
return nothing
1717
end
1818

19-
if VERSION ≥ v"1.11.0-beta"
20-
# https://github.com/JuliaLang/julia/issues/56295
21-
# Julia 1.11's Base.Broadcast currently requires
22-
# multiple integer indexing, whereas Julia 1.10 did not.
23-
# This means that we cannot reserve linear indexing to
24-
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
25-
# (including the GPU-variant related issue resolution efforts:
26-
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
27-
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
28-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
29-
us = DataLayouts.UniversalSize(dest)
30-
if Nv > 0 && Nh > 0
19+
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
20+
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
21+
us = DataLayouts.UniversalSize(dest)
22+
if Nv > 0 && Nh > 0
23+
if DataLayouts.has_uniform_datalayouts(bc) &&
24+
dest isa DataLayouts.EndsWithField
25+
bc′ = Base.Broadcast.instantiate(
26+
DataLayouts.to_non_extruded_broadcasted(bc),
27+
)
28+
args = (dest, bc′, us)
29+
threads = threads_via_occupancy(knl_copyto_linear!, args)
30+
n_max_threads = min(threads, get_N(us))
31+
p = linear_partition(prod(size(dest)), n_max_threads)
32+
auto_launch!(
33+
knl_copyto_linear!,
34+
args;
35+
threads_s = p.threads,
36+
blocks_s = p.blocks,
37+
)
38+
else
3139
args = (dest, bc, us)
3240
threads = threads_via_occupancy(knl_copyto!, args)
3341
n_max_threads = min(threads, get_N(us))
@@ -39,43 +47,8 @@ if VERSION ≥ v"1.11.0-beta"
3947
blocks_s = p.blocks,
4048
)
4149
end
42-
return dest
43-
end
44-
else
45-
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
46-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
47-
us = DataLayouts.UniversalSize(dest)
48-
if Nv > 0 && Nh > 0
49-
if DataLayouts.has_uniform_datalayouts(bc) &&
50-
dest isa DataLayouts.EndsWithField
51-
bc′ = Base.Broadcast.instantiate(
52-
DataLayouts.to_non_extruded_broadcasted(bc),
53-
)
54-
args = (dest, bc′, us)
55-
threads = threads_via_occupancy(knl_copyto_linear!, args)
56-
n_max_threads = min(threads, get_N(us))
57-
p = linear_partition(prod(size(dest)), n_max_threads)
58-
auto_launch!(
59-
knl_copyto_linear!,
60-
args;
61-
threads_s = p.threads,
62-
blocks_s = p.blocks,
63-
)
64-
else
65-
args = (dest, bc, us)
66-
threads = threads_via_occupancy(knl_copyto!, args)
67-
n_max_threads = min(threads, get_N(us))
68-
p = partition(dest, n_max_threads)
69-
auto_launch!(
70-
knl_copyto!,
71-
args;
72-
threads_s = p.threads,
73-
blocks_s = p.blocks,
74-
)
75-
end
76-
end
77-
return dest
7850
end
51+
return dest
7952
end
8053

8154
# broadcasting scalar assignment

ext/cuda/data_layouts_fill.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
1919
us = DataLayouts.UniversalSize(dest)
2020
args = (dest, bc, us)
2121
if Nv > 0 && Nh > 0
22-
if !(VERSION ≥ v"1.11.0-beta") && dest isa DataLayouts.EndsWithField
22+
if dest isa DataLayouts.EndsWithField
2323
threads = threads_via_occupancy(knl_fill_linear!, args)
2424
n_max_threads = min(threads, get_N(us))
2525
p = linear_partition(prod(size(dest)), n_max_threads)

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,7 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast)
120120
destinations = map(p -> p.first, fmb.pairs)
121121
bcs = map(p -> p.second, fmb.pairs)
122122
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
123-
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
124-
!(VERSION ≥ v"1.11.0-beta")
123+
all(d -> d isa DataLayouts.EndsWithField, destinations)
125124
pairs′ = map(fmb.pairs) do p
126125
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
127126
Pair(p.first, Base.Broadcast.instantiate(bc′))

src/DataLayouts/DataLayouts.jl

Lines changed: 17 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -2051,88 +2051,23 @@ end
20512051
const EndsWithField{S} =
20522052
Union{IJHF{S}, IHF{S}, IJF{S}, IF{S}, VF{S}, VIJHF{S}, VIHF{S}}
20532053

2054-
if VERSION ≥ v"1.11.0-beta"
2055-
### --------------- Support for multi-dimensional indexing
2056-
# TODO: can we remove this? It's not needed for Julia 1.10,
2057-
# but seems needed in Julia 1.11.
2058-
@inline Base.getindex(
2059-
data::Union{
2060-
IJF,
2061-
IJFH,
2062-
IJHF,
2063-
IFH,
2064-
IHF,
2065-
VIJFH,
2066-
VIJHF,
2067-
VIFH,
2068-
VIHF,
2069-
VF,
2070-
IF,
2071-
},
2072-
I::Vararg{Int, N},
2073-
) where {N} = Base.getindex(
2074-
data,
2075-
CartesianIndex(to_universal_index(singleton(data), I)),
2076-
)
2077-
2078-
@inline Base.setindex!(
2079-
data::Union{
2080-
IJF,
2081-
IJFH,
2082-
IJHF,
2083-
IFH,
2084-
IHF,
2085-
VIJFH,
2086-
VIJHF,
2087-
VIFH,
2088-
VIHF,
2089-
VF,
2090-
IF,
2091-
},
2092-
val,
2093-
I::Vararg{Int, N},
2094-
) where {N} = Base.setindex!(
2095-
data,
2096-
val,
2097-
CartesianIndex(to_universal_index(singleton(data), I)),
2098-
)
2099-
2100-
# Certain datalayouts support special indexing.
2101-
# Like VF datalayouts with `getindex(::VF, v::Integer)`
2102-
#! format: off
2103-
@inline to_universal_index(::VFSingleton, I::NTuple{1, T}) where {T} = (T(1), T(1), T(1), I[1], T(1))
2104-
@inline to_universal_index(::IFSingleton, I::NTuple{1, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2105-
@inline to_universal_index(::IFSingleton, I::NTuple{2, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2106-
@inline to_universal_index(::IFSingleton, I::NTuple{3, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2107-
@inline to_universal_index(::IFSingleton, I::NTuple{4, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2108-
@inline to_universal_index(::IFSingleton, I::NTuple{5, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
2109-
@inline to_universal_index(::IJFSingleton, I::NTuple{2, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2110-
@inline to_universal_index(::IJFSingleton, I::NTuple{3, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2111-
@inline to_universal_index(::IJFSingleton, I::NTuple{4, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2112-
@inline to_universal_index(::IJFSingleton, I::NTuple{5, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
2113-
@inline to_universal_index(::AbstractDataSingleton, I::NTuple{5}) = I
2114-
#! format: on
2115-
### ---------------
2116-
else
2117-
# Only support datalayouts that end with fields, since those
2118-
# are the only layouts where we can efficiently compute the
2119-
# strides.
2120-
@propagate_inbounds function Base.getindex(
2121-
data::EndsWithField{S},
2122-
I::Integer,
2123-
) where {S}
2124-
s_array = farray_size(data)
2125-
@inbounds get_struct_linear(parent(data), S, I, s_array)
2126-
end
2127-
@propagate_inbounds function Base.setindex!(
2128-
data::EndsWithField{S},
2129-
val,
2130-
I::Integer,
2131-
) where {S}
2132-
s_array = farray_size(data)
2133-
@inbounds set_struct_linear!(parent(data), convert(S, val), I, s_array)
2134-
end
2135-
2054+
# Only support datalayouts that end with fields, since those
2055+
# are the only layouts where we can efficiently compute the
2056+
# strides.
2057+
@propagate_inbounds function Base.getindex(
2058+
data::EndsWithField{S},
2059+
I::Integer,
2060+
) where {S}
2061+
s_array = farray_size(data)
2062+
@inbounds get_struct_linear(parent(data), S, I, s_array)
2063+
end
2064+
@propagate_inbounds function Base.setindex!(
2065+
data::EndsWithField{S},
2066+
val,
2067+
I::Integer,
2068+
) where {S}
2069+
s_array = farray_size(data)
2070+
@inbounds set_struct_linear!(parent(data), convert(S, val), I, s_array)
21362071
end
21372072

21382073
"""

src/DataLayouts/copyto.jl

Lines changed: 16 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,40 +1,24 @@
11
#####
22
##### Dispatching and edge cases
33
#####
4-
if VERSION ≥ v"1.11.0-beta"
5-
# https://github.com/JuliaLang/julia/issues/56295
6-
# Julia 1.11's Base.Broadcast currently requires
7-
# multiple integer indexing, whereas Julia 1.10 did not.
8-
# This means that we cannot reserve linear indexing to
9-
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
10-
# (including the GPU-variant related issue resolution efforts:
11-
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
12-
function Base.copyto!(
13-
dest::AbstractData{S},
14-
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
15-
) where {S}
16-
return Base.copyto!(dest, bc, device_dispatch(parent(dest)))
17-
end
18-
else
19-
function Base.copyto!(
20-
dest::AbstractData{S},
21-
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
22-
) where {S}
23-
dev = device_dispatch(parent(dest))
24-
if dev isa ToCPU &&
25-
has_uniform_datalayouts(bc) &&
26-
dest isa EndsWithField &&
27-
!(dest isa DataF)
28-
# Specialize on linear indexing when possible:
29-
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
30-
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
31-
dest[I] = convert(S, bc′[I])
32-
end
33-
else
34-
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
4+
function Base.copyto!(
5+
dest::AbstractData{S},
6+
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
7+
) where {S}
8+
dev = device_dispatch(parent(dest))
9+
if dev isa ToCPU &&
10+
has_uniform_datalayouts(bc) &&
11+
dest isa EndsWithField &&
12+
!(dest isa DataF)
13+
# Specialize on linear indexing when possible:
14+
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
15+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
16+
dest[I] = convert(S, bc′[I])
3517
end
36-
return dest
18+
else
19+
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
3720
end
21+
return dest
3822
end
3923

4024
# Specialize on non-Broadcasted objects

src/DataLayouts/fill.jl

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
function Base.fill!(dest::AbstractData, val)
22
dev = device_dispatch(parent(dest))
3-
if !(VERSION ≥ v"1.11.0-beta") &&
4-
dest isa EndsWithField &&
5-
dev isa ClimaComms.AbstractCPUDevice
3+
if dest isa EndsWithField && dev isa ClimaComms.AbstractCPUDevice
64
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
75
dest[I] = val
86
end

src/DataLayouts/fused_copyto.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,7 @@ function Base.copyto!(
5252
dev = device_dispatch(parent(dest1))
5353
if dev isa ClimaComms.AbstractCPUDevice &&
5454
all(bc -> has_uniform_datalayouts(bc), bcs) &&
55-
all(d -> d isa EndsWithField, destinations) &&
56-
!(VERSION ≥ v"1.11.0-beta")
55+
all(d -> d isa EndsWithField, destinations)
5756
pairs′ = map(fmb_inst.pairs) do p
5857
bc′ = to_non_extruded_broadcasted(p.second)
5958
Pair(p.first, bc′)

0 commit comments

Comments
 (0)