|
1 | 1 | module NDIteration
|
2 | 2 |
|
| 3 | +import Base.MultiplicativeInverses: SignedMultiplicativeInverse |
| 4 | + |
| 5 | +# CartesianIndex uses Int instead of Int32 |
| 6 | + |
# Placeholder inverse with divisor 0. It is assembled field-by-field through
# `Expr(:new, ...)`, which does not go through any constructor
# (presumably because the regular constructor rejects a zero divisor -- TODO confirm).
@eval EmptySMI() = $(Expr(:new, SignedMultiplicativeInverse{Int32}, Int32(0), typemax(Int32), 0 % Int8, 0 % UInt8))

# Return the Int32 multiplicative inverse for `i`; a zero `i` yields the
# `EmptySMI()` placeholder instead of calling the regular constructor.
function SMI(i)
    if i == 0
        return EmptySMI()
    end
    return SignedMultiplicativeInverse{Int32}(i)
end
| 9 | + |
"""
    FastCartesianIndices{N}

`N`-dimensional array type with `CartesianIndex{N}` elements that stores one
`SignedMultiplicativeInverse{Int32}` per dimension.
"""
struct FastCartesianIndices{N} <: AbstractArray{CartesianIndex{N}, N}
    # One inverse per dimension; `size` reads each inverse's `divisor` as the extent.
    inverses::NTuple{N, SignedMultiplicativeInverse{Int32}}
end

# Build from a tuple of per-dimension values: each entry is converted to Int32
# and wrapped via `SMI`, then handed to the field constructor.
function FastCartesianIndices(indices::NTuple{N}) where {N}
    inverses = map(indices) do extent
        SMI(Int32(extent))
    end
    return FastCartesianIndices(inverses)
end
| 18 | + |
# The size along each dimension is the `divisor` stored in that dimension's
# inverse, returned as an N-tuple in dimension order.
function Base.size(FCI::FastCartesianIndices{N}) where {N}
    return map(smi -> smi.divisor, FCI.inverses)
end
| 24 | + |
# Zero-dimensional case: indexing without any subscripts yields the empty CartesianIndex.
@inline Base.getindex(::FastCartesianIndices{0}) = CartesianIndex()
| 28 | + |
# Cartesian indexing with one `Int` per dimension. After the (elidable) bounds
# check, every subscript is looked up in `Base.OneTo(divisor)` for its dimension
# and the results are packed into a `CartesianIndex`.
@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where {N}
    @boundscheck checkbounds(iter, I...)
    subscripts = map(iter.inverses, I) do smi, i
        axis = Base.OneTo(smi.divisor)
        @inbounds axis[i]
    end
    return CartesianIndex(subscripts)
end
| 36 | + |
# Base case for an empty tuple of inverses: the remaining zero-based index is
# returned as a single one-based subscript.
# (Previously misspelled `_ind2sub_recuse`, so it never took part in the
# `_ind2sub_recurse` dispatch and an empty tuple reached the generic method,
# which indexes `inds[1]`.)
_ind2sub_recurse(::Tuple{}, ind) = (ind + 1,)

# Base case for the last remaining dimension: no division is needed, the
# subscript comes straight from `_lookup`.
function _ind2sub_recurse(indslast::NTuple{1}, ind)
    @inline
    return (_lookup(ind, indslast[1]),)
end
| 42 | + |
# Recursive step: peel off the leading dimension. `_div` supplies the quotient
# carried on to the remaining dimensions together with the first index and the
# extent of this dimension; the subscript for this dimension is the remainder
# shifted by that first index.
function _ind2sub_recurse(inds, ind)
    @inline
    leading = inds[1]
    quotient, first_index, extent = _div(ind, leading)
    subscript = ind - extent * quotient + first_index
    remaining = _ind2sub_recurse(Base.tail(inds), quotient)
    return (subscript, remaining...)
end
| 49 | + |
# Subscript for the final dimension: the zero-based index shifted to one-based.
# The inverse argument is used for dispatch only.
function _lookup(ind, inv::SignedMultiplicativeInverse)
    return ind + 1
end

# Divide `ind` (reinterpreted modulo Int32 via `%`) by the precomputed inverse.
# Returns `(quotient, 1, divisor)`; throws `DivideError` when the stored
# divisor is zero.
function _div(ind, inv::SignedMultiplicativeInverse)
    if inv.divisor == 0
        throw(DivideError())
    end
    quotient = div(ind % Int32, inv)
    return quotient, 1, inv.divisor
end
| 55 | + |
# Linear-to-subscript conversion for FastCartesianIndices: shift the one-based
# linear index to zero-based and recurse over the stored inverses.
function Base._ind2sub(inv::FastCartesianIndices, ind)
    @inline
    zero_based = ind - 1
    return _ind2sub_recurse(inv.inverses, zero_based)
end
| 60 | + |
3 | 61 | export _Size, StaticSize, DynamicSize, get
|
4 | 62 | export NDRange, blocks, workitems, expand
|
5 | 63 | export DynamicCheck, NoDynamicCheck
|
@@ -50,18 +108,30 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
|
50 | 108 | blocks::DynamicBlock
|
51 | 109 | workitems::DynamicWorkitems
|
52 | 110 |
|
53 |
| - function NDRange{N, B, W}() where {N, B, W} |
54 |
| - new{N, B, W, Nothing, Nothing}(nothing, nothing) |
55 |
| - end |
56 |
| - |
57 |
| - function NDRange{N, B, W}(blocks, workitems) where {N, B, W} |
| 111 | + function NDRange{N, B, W}(blocks::Union{Nothing, FastCartesianIndices{N}}, workitems::Union{Nothing, FastCartesianIndices{N}}) where {N, B, W} |
58 | 112 | new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
|
59 | 113 | end
|
60 | 114 | end
|
61 | 115 |
|
62 |
| -@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::CartesianIndices{N} |
# Argument-free constructor: both `blocks` and `workitems` are stored as `nothing`.
NDRange{N, B, W}() where {N, B, W} = NDRange{N, B, W}(nothing, nothing)
| 119 | + |
# Accept plain `CartesianIndices` for both arguments and convert each to
# `FastCartesianIndices`; only `size(...)` of each argument is forwarded.
function NDRange{N, B, W}(blocks::CartesianIndices, workitems::CartesianIndices) where {N, B, W}
    fast_blocks = FastCartesianIndices(size(blocks))
    fast_workitems = FastCartesianIndices(size(workitems))
    return NDRange{N, B, W}(fast_blocks, fast_workitems)
end
| 123 | + |
# `blocks` is `nothing` and passed through untouched; `workitems` is converted
# to `FastCartesianIndices` from its size.
function NDRange{N, B, W}(blocks::Nothing, workitems::CartesianIndices) where {N, B, W}
    fast_workitems = FastCartesianIndices(size(workitems))
    return NDRange{N, B, W}(blocks, fast_workitems)
end
| 127 | + |
# `workitems` is `nothing` and passed through untouched; `blocks` is converted
# to `FastCartesianIndices` from its size.
function NDRange{N, B, W}(blocks::CartesianIndices, workitems::Nothing) where {N, B, W}
    fast_blocks = FastCartesianIndices(size(blocks))
    return NDRange{N, B, W}(fast_blocks, workitems)
end
| 131 | + |
# Work-item index space of an NDRange.
# Dynamic size: the stored field, asserted to be `FastCartesianIndices{N}`.
@inline function workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize}
    return range.workitems::FastCartesianIndices{N}
end

# Static size: rebuilt from the type parameter `W` as plain `CartesianIndices{N}`.
@inline function workitems(range::NDRange{N, B, W}) where {N, B, W <: StaticSize}
    return CartesianIndices(get(W))::CartesianIndices{N}
end
|
64 |
| -@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::CartesianIndices{N} |
# Block index space of an NDRange.
# Dynamic size: the stored field, asserted to be `FastCartesianIndices{N}`.
@inline function blocks(range::NDRange{N, B}) where {N, B <: DynamicSize}
    return range.blocks::FastCartesianIndices{N}
end

# Static size: rebuilt from the type parameter `B` as plain `CartesianIndices{N}`.
@inline function blocks(range::NDRange{N, B}) where {N, B <: StaticSize}
    return CartesianIndices(get(B))::CartesianIndices{N}
end
|
66 | 136 |
|
67 | 137 | import Base.iterate
|
|
82 | 152 |
|
# Expand a linear group index and a linear work-item index: each one is looked
# up in its index space (`blocks` / `workitems`) and the two results are
# forwarded to another `expand` method.
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::Integer) where {N}
    # Two linear-to-CartesianIndex conversions happen here, which causes two
    # sdiv operations (one per conversion).
    group = blocks(ndrange)[groupidx]
    item = workitems(ndrange)[idx]
    return expand(ndrange, group, item)

    # Reference only (unreachable): the formulation below saves one sdiv
    # but leads to a different index order...
    # previous: julia> expand(ndrange, 1, 32*32)
    #           CartesianIndex(32, 32)
    # now:      julia> expand(ndrange, 1, 32*32)
    #           CartesianIndex(1024, 1)
    # B = blocks(ndrange)::CartesianIndices
    # W = workitems(ndrange)::CartesianIndices
    # Ind = ntuple(Val(N)) do I
    #     Base.@_inline_meta
    #     b = B.indices[I]
    #     w = W.indices[I]
    #     length(b) * length(w)
    # end
    # CartesianIndices(Ind)[(groupidx-1)* prod(size(W)) + idx]
end
|
103 | 173 |
|
104 | 174 | Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
|
|
0 commit comments