|
91 | 91 | Base.@pure num_types(::Type{<:TypeSortedCollection{<:Any, N}}) where {N} = N
|
92 | 92 | num_types(x::TypeSortedCollection) = num_types(typeof(x))
|
93 | 93 |
|
| 94 | +const TSCOrAbstractVector{N} = Union{<:TypeSortedCollection{<:Any, N}, AbstractVector} |
| 95 | + |
94 | 96 | Base.isempty(x::TypeSortedCollection) = all(isempty, x.data)
|
95 | 97 | Base.empty!(x::TypeSortedCollection) = foreach(empty!, x.data)
|
96 | 98 | Base.length(x::TypeSortedCollection) = mapreduce(length, +, 0, x.data)
|
97 | 99 | Base.indices(x::TypeSortedCollection) = x.indices # semantics are a little different from Array, but OK
|
98 | 100 |
|
99 | 101 | # Trick from StaticArrays:
|
100 |
| -@inline first_tsc(a1::TypeSortedCollection, as::Union{<:TypeSortedCollection, AbstractVector}...) = a1 |
101 |
| -@inline first_tsc(a1, as::Union{<:TypeSortedCollection, AbstractVector}...) = first_tsc(as...) |
| 102 | +@inline first_tsc(a1::TypeSortedCollection, as...) = a1 |
| 103 | +@inline first_tsc(a1, as...) = first_tsc(as...) |
| 104 | + |
| 105 | +Base.@pure first_tsc_type(a1::Type{<:TypeSortedCollection}, as::Type...) = a1 |
| 106 | +Base.@pure first_tsc_type(a1::Type, as::Type...) = first_tsc_type(as...) |
102 | 107 |
|
103 | 108 | # inspired by Base.ith_all
|
104 | 109 | @inline _getindex_all(::Val, j, vecindex) = ()
|
105 | 110 | Base.@propagate_inbounds _getindex_all(vali::Val{i}, j, vecindex, a1, as...) where {i} = (_getindex(vali, j, vecindex, a1), _getindex_all(vali, j, vecindex, as...)...)
|
| 111 | +@inline _getindex(::Val, j, vecindex, a) = a # for anything that's not an AbstractVector or TypeSortedCollection, don't index (for use in broadcast!) |
106 | 112 | @inline _getindex(::Val, j, vecindex, a::AbstractVector) = a[vecindex]
|
107 | 113 | @inline _getindex(::Val{i}, j, vecindex, a::TypeSortedCollection) where {i} = a.data[i][j]
|
108 | 114 | @inline _setindex!(::Val, j, vecindex, a::AbstractVector, val) = a[vecindex] = val
|
109 | 115 | @inline _setindex!(::Val{i}, j, vecindex, a::TypeSortedCollection, val) where {i} = a.data[i][j] = val
|
110 | 116 |
|
111 | 117 | @inline lengths_match(a1) = true
|
112 |
| -@inline lengths_match(a1, a2, as...) = length(a1) == length(a2) && lengths_match(a2, as...) |
| 118 | +@inline lengths_match(a1::TSCOrAbstractVector, a2::TSCOrAbstractVector, as...) = length(a1) == length(a2) && lengths_match(a2, as...) |
| 119 | +@inline lengths_match(a1::TSCOrAbstractVector, a2, as...) = lengths_match(a1, as...) # case: a2 is not indexable: skip it |
113 | 120 | @noinline lengths_match_fail() = throw(DimensionMismatch("Lengths of input collections do not match."))
|
114 | 121 |
|
115 |
| -@inline indices_match(::Val, indices::Vector{Int}, ::AbstractVector) = true |
| 122 | +@inline indices_match(::Val, indices::Vector{Int}, ::Any) = true |
116 | 123 | @inline function indices_match(::Val{i}, indices::Vector{Int}, tsc::TypeSortedCollection) where {i}
|
117 | 124 | tsc_indices = tsc.indices[i]
|
118 | 125 | length(indices) == length(tsc_indices) || return false
|
|
124 | 131 | @inline indices_match(vali::Val, indices::Vector{Int}, a1, as...) = indices_match(vali, indices, a1) && indices_match(vali, indices, as...)
|
125 | 132 | @noinline indices_match_fail() = throw(ArgumentError("Indices of TypeSortedCollections do not match."))
|
126 | 133 |
|
127 |
| -@generated function Base.map!(f, dest::Union{TypeSortedCollection{<:Any, N}, AbstractArray}, args::Union{TypeSortedCollection{<:Any, N}, AbstractArray}...) where {N} |
| 134 | +@generated function Base.map!(f, dest::TSCOrAbstractVector{N}, args::TSCOrAbstractVector{N}...) where {N} |
128 | 135 | expr = Expr(:block)
|
129 | 136 | push!(expr.args, :(Base.@_inline_meta))
|
130 | 137 | push!(expr.args, :(leading_tsc = first_tsc(dest, args...)))
|
|
134 | 141 | push!(expr.args, quote
|
135 | 142 | let inds = leading_tsc.indices[$i]
|
136 | 143 | @boundscheck indices_match($vali, inds, dest, args...) || indices_match_fail()
|
137 |
| - for j in linearindices(inds) |
| 144 | + @inbounds for j in linearindices(inds) |
138 | 145 | vecindex = inds[j]
|
139 | 146 | _setindex!($vali, j, vecindex, dest, f(_getindex_all($vali, j, vecindex, args...)...))
|
140 | 147 | end
|
|
147 | 154 | end
|
148 | 155 | end
|
149 | 156 |
|
150 |
| -@generated function Base.foreach(f, As::Union{<:TypeSortedCollection{<:Any, N}, AbstractVector}...) where {N} |
| 157 | +@generated function Base.foreach(f, As::TSCOrAbstractVector{N}...) where {N} |
151 | 158 | expr = Expr(:block)
|
152 | 159 | push!(expr.args, :(Base.@_inline_meta))
|
153 | 160 | push!(expr.args, :(leading_tsc = first_tsc(As...)))
|
|
187 | 194 | end
|
188 | 195 | end
|
189 | 196 |
|
| 197 | +## broadcast! |
| 198 | +Base.Broadcast._containertype(::Type{<:TypeSortedCollection}) = TypeSortedCollection |
| 199 | +Base.Broadcast.promote_containertype(::Type{TypeSortedCollection}, _) = TypeSortedCollection |
| 200 | +Base.Broadcast.promote_containertype(_, ::Type{TypeSortedCollection}) = TypeSortedCollection |
| 201 | +Base.Broadcast.promote_containertype(::Type{TypeSortedCollection}, ::Type{Array}) = TypeSortedCollection # handle ambiguities with `Array` |
| 202 | +Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TypeSortedCollection}) = TypeSortedCollection # handle ambiguities with `Array` |
| 203 | + |
| 204 | +@generated function Base.Broadcast.broadcast_c!(f, ::Type, ::Type{TypeSortedCollection}, dest::AbstractVector, A, Bs...) |
| 205 | + T = first_tsc_type(A, Bs...) |
| 206 | + N = num_types(T) |
| 207 | + expr = Expr(:block) |
| 208 | + push!(expr.args, :(Base.@_inline_meta)) # TODO: good idea? |
| 209 | + push!(expr.args, :(leading_tsc = first_tsc(A, Bs...))) |
| 210 | + push!(expr.args, :(@boundscheck lengths_match(dest, A, Bs...) || lengths_match_fail())) |
| 211 | + for i = 1 : N |
| 212 | + vali = Val(i) |
| 213 | + push!(expr.args, quote |
| 214 | + let inds = leading_tsc.indices[$i] |
| 215 | + @boundscheck indices_match($vali, inds, A, Bs...) || indices_match_fail() |
| 216 | + @inbounds for j in linearindices(inds) |
| 217 | + vecindex = inds[j] |
| 218 | + _setindex!($vali, j, vecindex, dest, f(_getindex_all($vali, j, vecindex, A, Bs...)...)) |
| 219 | + end |
| 220 | + end |
| 221 | + end) |
| 222 | + end |
| 223 | + quote |
| 224 | + $expr |
| 225 | + dest |
| 226 | + end |
| 227 | +end |
| 228 | + |
190 | 229 | end # module
|
0 commit comments