1
1
import MultiBroadcastFusion as MBF
2
2
import MultiBroadcastFusion: fused_direct
3
+ import .. RecursiveApply
3
4
4
5
# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
5
6
# via https://github.com/CliMA/MultiBroadcastFusion.jl
@@ -11,6 +12,25 @@ MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct
11
12
12
13
abstract type DataStyle <: Base.BroadcastStyle end
13
14
15
+ """
16
+ parent_array_type
17
+
18
+ Returns a UnionAll array type given the inputs.
19
+ For example: `Array`, `CuArray` etc.
20
+
21
+ # Note
22
+
23
+ The returned type must be a UnionAll array type
24
+ because we need to be able to promote broadcast
25
+ expressions with fields containing different number
26
+ of variables. The number of fields returns depends
27
+ on the function being broadcasted over, and we do
28
+ not have this number here.
29
+
30
+ # TODO: make this note more precise
31
+ """
32
+ function parent_array_type end
33
+
14
34
abstract type Data0DStyle <: DataStyle end
15
35
struct DataFStyle{A} <: Data0DStyle end
16
36
DataStyle (:: Type{DataF{S, A}} ) where {S, A} = DataFStyle {parent_array_type(A)} ()
@@ -291,45 +311,59 @@ function Base.similar(
291
311
bc:: BroadcastedUnionDataF{<:Any, A} ,
292
312
:: Type{Eltype} ,
293
313
) where {A, Eltype}
294
- PA = parent_array_type (A)
295
- array = similar (PA, (typesize (eltype (A), Eltype)))
296
- return DataF {Eltype} (array)
314
+ Nf = typesize (eltype (A), Eltype)
315
+ _size = ()
316
+ as = ArraySize {field_dim(DataF), Nf, _size} ()
317
+ fa = similar (rebuild_field_array_type (A, as), _size)
318
+ return DataF {Eltype} (fa)
297
319
end
298
320
299
321
function Base. similar (
300
322
bc:: BroadcastedUnionIJFH{<:Any, Nij, Nh, A} ,
301
323
:: Type{Eltype} ,
302
324
) where {Nij, Nh, A, Eltype}
303
- PA = parent_array_type (A)
304
- array = similar (PA, (Nij, Nij, typesize (eltype (A), Eltype), Nh))
305
- return IJFH {Eltype, Nij, Nh} (array)
325
+ Nf = typesize (eltype (A), Eltype)
326
+ _size = (Nij, Nij, Nh)
327
+ as = ArraySize {field_dim(IJFH), Nf, _size} ()
328
+ fa = similar (rebuild_field_array_type (A, as), _size)
329
+ return IJFH {Eltype, Nij, Nh} (fa)
306
330
end
307
331
308
332
function Base. similar (
309
333
bc:: BroadcastedUnionIFH{<:Any, Ni, Nh, A} ,
310
334
:: Type{Eltype} ,
311
335
) where {Ni, Nh, A, Eltype}
312
- PA = parent_array_type (A)
313
- array = similar (PA, (Ni, typesize (eltype (A), Eltype), Nh))
314
- return IFH {Eltype, Ni, Nh} (array)
336
+ Nf = typesize (eltype (A), Eltype)
337
+ _size = (Ni, Nh)
338
+ as = ArraySize {field_dim(IFH), Nf, _size} ()
339
+ fa = similar (rebuild_field_array_type (A, as), _size)
340
+ return IFH {Eltype, Ni, Nh} (fa)
315
341
end
316
342
317
343
function Base. similar (
318
344
:: BroadcastedUnionIJF{<:Any, Nij, A} ,
319
345
:: Type{Eltype} ,
320
346
) where {Nij, A, Eltype}
321
347
Nf = typesize (eltype (A), Eltype)
322
- array = MArray {Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf} (undef)
323
- return IJF {Eltype, Nij} (array)
348
+ # array = MArray{Tuple{Nij, Nij, Nf}, eltype(A), 3, Nij * Nij * Nf}(undef)
349
+ MAT = MArray{Tuple{Nij, Nij}, eltype (A), 2 , Nij * Nij}
350
+ _size = (Nij, Nij)
351
+ as = ArraySize {field_dim(IJF), Nf, ()} ()
352
+ fa = similar (rebuild_field_array_type (A, as, MAT), _size)
353
+ return IJF {Eltype, Nij} (fa)
324
354
end
325
355
326
356
function Base. similar (
327
357
:: BroadcastedUnionIF{<:Any, Ni, A} ,
328
358
:: Type{Eltype} ,
329
359
) where {Ni, A, Eltype}
330
360
Nf = typesize (eltype (A), Eltype)
331
- array = MArray {Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf} (undef)
332
- return IF {Eltype, Ni} (array)
361
+ # array = MArray{Tuple{Ni, Nf}, eltype(A), 2, Ni * Nf}(undef)
362
+ MAT = MArray{Tuple{Ni}, eltype (A), 2 , Ni}
363
+ _size = (Ni, )
364
+ as = ArraySize {field_dim(IF), Nf, ()} () # size is unused
365
+ fa = similar (rebuild_field_array_type (A, as, MAT), _size)
366
+ return IF {Eltype, Ni} (fa)
333
367
end
334
368
335
369
Base. similar (
@@ -342,12 +376,10 @@ function Base.similar(
342
376
:: Type{Eltype} ,
343
377
:: Val{newNv} ,
344
378
) where {Nv, A, Eltype, newNv}
345
- PA = parent_array_type (A)
346
- # @show PA
347
379
Nf = typesize (eltype (A), Eltype)
348
- # @show (newNv, Nf )
349
- # array = similar(PA, (newNv, Nf) )
350
- fa = FieldArray {field_dim(VF)} ( ntuple (i -> similar (PA, newNv ), Nf) )
380
+ _size = (newNv, )
381
+ as = ArraySize {field_dim(VF), Nf, _size} ( )
382
+ fa = similar (rebuild_field_array_type (A, as ), _size )
351
383
return VF {Eltype, newNv, typeof(fa)} (fa)
352
384
end
353
385
@@ -361,9 +393,11 @@ function Base.similar(
361
393
:: Type{Eltype} ,
362
394
:: Val{newNv} ,
363
395
) where {Nv, Ni, Nh, A, Eltype, newNv}
364
- PA = parent_array_type (A)
365
- array = similar (PA, (newNv, Ni, typesize (eltype (A), Eltype), Nh))
366
- return VIFH {Eltype, newNv, Ni, Nh} (array)
396
+ Nf = typesize (eltype (A), Eltype)
397
+ _size = (newNv, Ni, Nh)
398
+ as = ArraySize {field_dim(VIFH), Nf, _size} ()
399
+ fa = similar (rebuild_field_array_type (A, as), _size)
400
+ return VIFH {Eltype, newNv, Ni, Nh} (fa)
367
401
end
368
402
369
403
Base. similar (
@@ -378,16 +412,10 @@ function Base.similar(
378
412
) where {Nv, Nij, Nh, A, Eltype, newNv}
379
413
T = eltype (A)
380
414
Nf = typesize (eltype (A), Eltype)
381
- # fat = rebuild_type(A, Val(field_dim(VIJFH)), Val(Nf), Val(4))
382
415
_size = (newNv, Nij, Nij, Nh)
383
416
as = ArraySize {field_dim(VIJFH), Nf, _size} ()
384
- # fat = if A isa AbstractArray
385
- # field_array_type(A, as)
386
- # else
387
- # end
388
- array = similar (rebuild_field_array_type (A, as), _size)
389
- vd = VIJFH {Eltype, newNv, Nij, Nh} (array)
390
- return vd
417
+ fa = similar (rebuild_field_array_type (A, as), _size)
418
+ return VIJFH {Eltype, newNv, Nij, Nh} (fa)
391
419
end
392
420
393
421
# ============= FusedMultiBroadcast
0 commit comments