Skip to content

Commit 24ef689

Browse files
wip
1 parent 236262b commit 24ef689

31 files changed

+592
-330
lines changed

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import ClimaCore.RecursiveApply:
1818
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
1919
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
2020
import ClimaCore.DataLayouts: universal_size, UniversalSize
21+
import ClimaCore.DataLayouts: ArraySize
2122

2223
include(joinpath("cuda", "cuda_utils.jl"))
2324
include(joinpath("cuda", "data_layouts.jl"))

src/DataLayouts/DataLayouts.jl

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,18 @@ function Base.show(io::IO, data::AbstractData)
129129
(rows, cols) = displaysize(io)
130130
println(io, summary(data))
131131
print(io, " "^indent_width)
132+
# @show similar(parent_array_type(data))
133+
# fa = map(x -> vec(x), field_arrays(data))
132134
print(
133135
IOContext(
134136
io,
135137
:compact => true,
136138
:limit => true,
137139
:displaysize => (rows, cols - indent_width),
138140
),
139-
map(x -> vec(x), field_arrays(data)),
141+
# collect(field_array(data)),
142+
parent(data),
143+
# map(x -> vec(x), field_arrays(data)),
140144
)
141145
return io
142146
end
@@ -619,10 +623,7 @@ function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
619623
array = FieldArray{field_dim(IJF)}(ntuple(f->MArray{Tuple{Nij, Nij}, T, 2, Nij * Nij}(undef), Nf))
620624
IJF{S, Nij}(array)
621625
end
622-
function SArray(ijf::IJF{S, Nij, FieldArray{FD, N, T}}) where {S, Nij, FD, N, T <: MArray}
623-
IJF{S, Nij}(SArray(field_array(ijf)))
624-
end
625-
function SArray(ijf::IJF{S, Nij, <:MArray}) where {S, Nij}
626+
function SArray(ijf::IJF{S, Nij, <:FieldArray}) where {S, Nij}
626627
IJF{S, Nij}(SArray(field_array(ijf)))
627628
end
628629

@@ -681,15 +682,15 @@ end
681682
function IF{S, Ni}(::Type{MArray}, ::Type{T}) where {S, Ni, T}
682683
Nf = typesize(T, S)
683684
# array = MArray{Tuple{Ni, Nf}, T, 2, Ni * Nf}(undef)
684-
array = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
685-
IF{S, Ni}(array)
685+
fa = FieldArray{field_dim(IF)}(ntuple(f->MArray{Tuple{Ni}, T, 1, Ni}(undef), Nf))
686+
IF{S, Ni}(fa)
686687
end
687-
function SArray(data::IF{S, Ni, <:FieldArray{<:Any, <:Any, T}}) where {S, Ni, T <: MArray}
688-
IF{S, Ni}(SArray(field_array(data)))
689-
end
690-
function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
688+
function SArray(data::IF{S, Ni, <:FieldArray}) where {S, Ni}
691689
IF{S, Ni}(SArray(field_array(data)))
692690
end
691+
# function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
692+
# IF{S, Ni}(SArray(field_array(data)))
693+
# end
693694

694695
@inline function column(data::IF{S, Ni}, i) where {S, Ni}
695696
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
@@ -816,14 +817,16 @@ Base.length(data::VIJFH) = get_Nv(data) * get_Nh(data)
816817
@boundscheck (1 <= v <= Nv && 1 <= h <= Nh) ||
817818
throw(BoundsError(data, (v, h)))
818819
Nf = ncomponents(data)
819-
dataview = @inbounds view(
820-
array,
821-
v,
822-
Base.Slice(Base.OneTo(Nij)),
823-
Base.Slice(Base.OneTo(Nij)),
824-
Base.Slice(Base.OneTo(Nf)),
825-
h,
826-
)
820+
sub_arrays = @inbounds ntuple(Nf) do f
821+
view(
822+
array.arrays[f],
823+
v,
824+
Base.Slice(Base.OneTo(Nij)),
825+
Base.Slice(Base.OneTo(Nij)),
826+
h,
827+
)
828+
end
829+
dataview = FieldArray{field_dim(IJF)}(sub_arrays)
827830
IJF{S, Nij}(dataview)
828831
end
829832

@@ -1113,11 +1116,15 @@ type parameters.
11131116
@inline field_dim(::Type{<:VIJFH}) = 4
11141117
@inline field_dim(::Type{<:VIFH}) = 3
11151118

1116-
@inline to_data_specific_field_array(::IJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[2], I.I[5])
1117-
@inline to_data_specific_field_array(::IFH, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[5])
1118-
@inline to_data_specific_field_array(::VIJFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[2], I.I[5])
1119-
@inline to_data_specific_field_array(::VIFH, I::CartesianIndex{5}) = CartesianIndex(I.I[4], I.I[1], I.I[5])
1120-
@inline to_data_specific_field_array(::DataSlab1D, I::CartesianIndex{5}) = CartesianIndex(I.I[1], I.I[1], I.I[5])
1119+
@inline to_data_specific_field_array(data::AbstractData, I::CartesianIndex) =
1120+
CartesianIndex(_to_data_specific_field_array(data, I.I))
1121+
@inline _to_data_specific_field_array(::VF, I::Tuple) = (I[4],)
1122+
@inline _to_data_specific_field_array(::IF, I::Tuple) = (I[1],)
1123+
@inline _to_data_specific_field_array(::IJF, I::Tuple) = (I[1], I[2])
1124+
@inline _to_data_specific_field_array(::IJFH, I::Tuple) = (I[1], I[2], I[5])
1125+
@inline _to_data_specific_field_array(::IFH, I::Tuple) = (I[1], I[5])
1126+
@inline _to_data_specific_field_array(::VIJFH, I::Tuple) = (I[4], I[1], I[2], I[5])
1127+
@inline _to_data_specific_field_array(::VIFH, I::Tuple) = (I[4], I[1], I[5])
11211128

11221129
@inline to_data_specific(data::AbstractData, I::CartesianIndex) =
11231130
CartesianIndex(_to_data_specific(data, I.I))
@@ -1349,7 +1356,7 @@ field_array(data::AbstractData{S}) where {S} = parent(data)
13491356
parent(data),
13501357
eltype(data),
13511358
Val(field_dim(data)),
1352-
to_data_specific(data, I),
1359+
to_data_specific_field_array(data, I),
13531360
)
13541361
end
13551362

@@ -1363,7 +1370,7 @@ end
13631370
parent(data),
13641371
convert(eltype(data), val),
13651372
Val(field_dim(data)),
1366-
to_data_specific(data, I),
1373+
to_data_specific_field_array(data, I),
13671374
)
13681375
end
13691376

src/DataLayouts/broadcast.jl

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import MultiBroadcastFusion as MBF
22
import MultiBroadcastFusion: fused_direct
3+
import ..RecursiveApply
34

45
# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
56
# via https://github.com/CliMA/MultiBroadcastFusion.jl
@@ -11,6 +12,25 @@ MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct
1112

1213
abstract type DataStyle <: Base.BroadcastStyle end
1314

15+
"""
16+
parent_array_type
17+
18+
Returns a UnionAll array type given the inputs.
19+
For example: `Array`, `CuArray` etc.
20+
21+
# Note
22+
23+
The returned type must be a UnionAll array type
24+
because we need to be able to promote broadcast
25+
expressions with fields containing different number
26+
of variables. The number of fields returns depends
27+
on the function being broadcasted over, and we do
28+
not have this number here.
29+
30+
# TODO: make this note more precise
31+
"""
32+
function parent_array_type end
33+
1434
abstract type Data0DStyle <: DataStyle end
1535
struct DataFStyle{A} <: Data0DStyle end
1636
DataStyle(::Type{DataF{S, A}}) where {S, A} = DataFStyle{parent_array_type(A)}()
@@ -291,45 +311,59 @@ function Base.similar(
291311
bc::BroadcastedUnionDataF{<:Any, A},
292312
::Type{Eltype},
293313
) where {A, Eltype}
294-
PA = parent_array_type(A)
295-
array = similar(PA, (typesize(eltype(A), Eltype)))
296-
return DataF{Eltype}(array)
314+
Nf = typesize(eltype(A), Eltype)
315+
_size = ()
316+
as = ArraySize{field_dim(DataF), Nf, _size}()
317+
fa = similar(rebuild_field_array_type(A, as), _size)
318+
return DataF{Eltype}(fa)
297319
end
298320

299321
function Base.similar(
300322
bc::BroadcastedUnionIJFH{<:Any, Nij, Nh, A},
301323
::Type{Eltype},
302324
) where {Nij, Nh, A, Eltype}
303-
PA = parent_array_type(A)
304-
array = similar(PA, (Nij, Nij, typesize(eltype(A), Eltype), Nh))
305-
return IJFH{Eltype, Nij, Nh}(array)
325+
Nf = typesize(eltype(A), Eltype)
326+
_size = (Nij, Nij, Nh)
327+
as = ArraySize{field_dim(IJFH), Nf, _size}()
328+
fa = similar(rebuild_field_array_type(A, as), _size)
329+
return IJFH{Eltype, Nij, Nh}(fa)
306330
end
307331

308332
function Base.similar(
309333
bc::BroadcastedUnionIFH{<:Any, Ni, Nh, A},
310334
::Type{Eltype},
311335
) where {Ni, Nh, A, Eltype}
312-
PA = parent_array_type(A)
313-
array = similar(PA, (Ni, typesize(eltype(A), Eltype), Nh))
314-
return IFH{Eltype, Ni, Nh}(array)
336+
Nf = typesize(eltype(A), Eltype)
337+
_size = (Ni, Nh)
338+
as = ArraySize{field_dim(IFH), Nf, _size}()
339+
fa = similar(rebuild_field_array_type(A, as), _size)
340+
return IFH{Eltype, Ni, Nh}(fa)
315341
end
316342

317343
function Base.similar(
318344
::BroadcastedUnionIJF{<:Any, Nij, A},
319345
::Type{Eltype},
320346
) where {Nij, A, Eltype}
321347
Nf = typesize(eltype(A), Eltype)
322-
array = MArray{Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf}(undef)
323-
return IJF{Eltype, Nij}(array)
348+
# array = MArray{Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf}(undef)
349+
MAT = MArray{Tuple{Nij, Nij}, eltype(A), 2, Nij * Nij}
350+
_size = (Nij, Nij)
351+
as = ArraySize{field_dim(IJF), Nf, ()}()
352+
fa = similar(rebuild_field_array_type(A, as, MAT), _size)
353+
return IJF{Eltype, Nij}(fa)
324354
end
325355

326356
function Base.similar(
327357
::BroadcastedUnionIF{<:Any, Ni, A},
328358
::Type{Eltype},
329359
) where {Ni, A, Eltype}
330360
Nf = typesize(eltype(A), Eltype)
331-
array = MArray{Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf}(undef)
332-
return IF{Eltype, Ni}(array)
361+
# array = MArray{Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf}(undef)
362+
MAT = MArray{Tuple{Ni}, eltype(A), 2, Ni}
363+
_size = (Ni, )
364+
as = ArraySize{field_dim(IF), Nf, ()}() # size is unused
365+
fa = similar(rebuild_field_array_type(A, as, MAT), _size)
366+
return IF{Eltype, Ni}(fa)
333367
end
334368

335369
Base.similar(
@@ -342,12 +376,10 @@ function Base.similar(
342376
::Type{Eltype},
343377
::Val{newNv},
344378
) where {Nv, A, Eltype, newNv}
345-
PA = parent_array_type(A)
346-
# @show PA
347379
Nf = typesize(eltype(A), Eltype)
348-
# @show (newNv, Nf)
349-
# array = similar(PA, (newNv, Nf))
350-
fa = FieldArray{field_dim(VF)}(ntuple(i -> similar(PA, newNv), Nf))
380+
_size = (newNv, )
381+
as = ArraySize{field_dim(VF), Nf, _size}()
382+
fa = similar(rebuild_field_array_type(A, as), _size)
351383
return VF{Eltype, newNv, typeof(fa)}(fa)
352384
end
353385

@@ -361,9 +393,11 @@ function Base.similar(
361393
::Type{Eltype},
362394
::Val{newNv},
363395
) where {Nv, Ni, Nh, A, Eltype, newNv}
364-
PA = parent_array_type(A)
365-
array = similar(PA, (newNv, Ni, typesize(eltype(A), Eltype), Nh))
366-
return VIFH{Eltype, newNv, Ni, Nh}(array)
396+
Nf = typesize(eltype(A), Eltype)
397+
_size = (newNv, Ni, Nh)
398+
as = ArraySize{field_dim(VIFH), Nf, _size}()
399+
fa = similar(rebuild_field_array_type(A, as), _size)
400+
return VIFH{Eltype, newNv, Ni, Nh}(fa)
367401
end
368402

369403
Base.similar(
@@ -378,16 +412,10 @@ function Base.similar(
378412
) where {Nv, Nij, Nh, A, Eltype, newNv}
379413
T = eltype(A)
380414
Nf = typesize(eltype(A), Eltype)
381-
# fat = rebuild_type(A, Val(field_dim(VIJFH)), Val(Nf), Val(4))
382415
_size = (newNv, Nij, Nij, Nh)
383416
as = ArraySize{field_dim(VIJFH), Nf, _size}()
384-
# fat = if A isa AbstractArray
385-
# field_array_type(A, as)
386-
# else
387-
# end
388-
array = similar(rebuild_field_array_type(A, as), _size)
389-
vd = VIJFH{Eltype, newNv, Nij, Nh}(array)
390-
return vd
417+
fa = similar(rebuild_field_array_type(A, as), _size)
418+
return VIJFH{Eltype, newNv, Nij, Nh}(fa)
391419
end
392420

393421
# ============= FusedMultiBroadcast

src/DataLayouts/copyto.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
##### Dispatching and edge cases
33
#####
44

5-
Base.copyto!(
5+
function Base.copyto!(
66
dest::AbstractData,
77
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
8-
) = Base.copyto!(dest, bc, device_dispatch(dest))
8+
)
9+
ncomponents(dest) > 0 || return dest
10+
Base.copyto!(dest, bc, device_dispatch(dest))
11+
end
912

1013
# Specialize on non-Broadcasted objects
1114
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}

0 commit comments

Comments
 (0)