445
445
# broadcast
446
446
import Base. Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
447
447
448
- struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
448
+ struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
449
+ # If `S` also track input's dimensionality, we'd better also update it.
450
+ StructArrayStyle {S,M} (:: Val{N} ) where {M,S<: AbstractArrayStyle{M} ,N} =
451
+ StructArrayStyle {typeof(S(Val(N))),N} ()
452
+ StructArrayStyle {S,M} (:: Val{N} ) where {M,S,N} = StructArrayStyle {S,N} ()
449
453
450
454
@inline combine_style_types (:: Type{A} , args... ) where A<: AbstractArray =
451
455
combine_style_types (BroadcastStyle (A), args... )
@@ -455,9 +459,9 @@ combine_style_types(s::BroadcastStyle) = s
455
459
456
460
Base. @pure cst (:: Type{SA} ) where SA = combine_style_types (array_types (SA). parameters... )
457
461
458
- BroadcastStyle (:: Type{SA} ) where SA<: StructArray = StructArrayStyle {typeof(cst(SA))} ()
462
+ BroadcastStyle (:: Type{SA} ) where SA<: StructArray = StructArrayStyle {typeof(cst(SA)),ndims(SA) } ()
459
463
460
- Base. similar (bc:: Broadcasted{StructArrayStyle{S}} , :: Type{ElType} ) where {S<: DefaultArrayStyle ,N,ElType} =
464
+ Base. similar (bc:: Broadcasted{<: StructArrayStyle{S}} , :: Type{ElType} ) where {S<: DefaultArrayStyle ,N,ElType} =
461
465
isstructtype (ElType) ? similar (StructArray{ElType}, axes (bc)) : similar (Array{ElType}, axes (bc))
462
466
463
467
# for aliasing analysis during broadcast
0 commit comments