|
| 1 | +# This file is a part of Julia. License is MIT: https://julialang.org/license |
| 2 | +# |
| 3 | +# Parallel Partitioned Shuffle |
| 4 | +# |
| 5 | +using Base.Threads |
| 6 | +export ppshuffle!, ppshuffle, pprandperm!, pprandperm |
| 7 | + |
| 8 | +## ppshuffle! & ppshuffle |
| 9 | + |
| 10 | +""" |
| 11 | + _ppshuffle!(rng::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}, mask<:Union{UInt8, UInt16}) |
| 12 | +
|
| 13 | +Parallel Partitioned Shuffle |
| 14 | +1. partition input randomly |
| 15 | +2. shuffle partitions concurrently |
| 16 | +
|
| 17 | +Arg `mask` determines number of partitions (mask + 1) to be used. |
| 18 | +""" |
| 19 | +function _ppshuffle!(r::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}, mask::Tu) where {T, Tu<:Union{UInt8, UInt16}} |
| 20 | + # determine number of partitions |
| 21 | + nparts = mask + 1 |
| 22 | + @assert ispow2(nparts) "invalid mask $(mask)" |
| 23 | + @assert length(A) == length(B) |
| 24 | + |
| 25 | + n = length(A) |
| 26 | + s = Random.SamplerType{Tu}() |
| 27 | + |
| 28 | + # an array to count partition hits by threads |
| 29 | + # partitions map to rows (pid) |
| 30 | + # threads map to cols (tid) |
| 31 | + # we add an extra cache line for each thread |
| 32 | + nrows = nparts < 8 ? 8 : nparts + 8 |
| 33 | + hits = zeros(Int, nrows, nthreads()) |
| 34 | + |
| 35 | + # save initial random state |
| 36 | + r0 = copy(r) |
| 37 | + # 1st pass |
| 38 | + # assign input to partitions uniformly at random |
| 39 | + # count cells hit by each thread in every partition |
| 40 | + @threads :static for i in 1:n |
| 41 | + local tid, pid = threadid(), rand(r, s) & mask + 1 |
| 42 | + @inbounds hits[pid, tid] += 1 |
| 43 | + end |
| 44 | + |
| 45 | + # cumsum partition hits |
| 46 | + # to mark boundaries of space reserved by each thread in every partition |
| 47 | + # note that the 1st column will contain boundaries of entire partitions |
| 48 | + prev = 0 |
| 49 | + for pid = 1:nparts, tid = 1:nthreads() |
| 50 | + @inbounds prev = hits[pid, tid] += prev |
| 51 | + end |
| 52 | + # mark the end of the last partition |
| 53 | + hits[nparts + 1, 1] = n |
| 54 | + |
| 55 | + # recover random state |
| 56 | + copy!(r, r0) |
| 57 | + # 2nd pass |
| 58 | + # scatter input accross partitions uniformly at random |
| 59 | + # note that input distribution is identical as in the 1st pass |
| 60 | + # since we recovered the initial random state |
| 61 | + @threads :static for i in 1:n |
| 62 | + local tid, pid = threadid(), rand(r, s) & mask + 1 |
| 63 | + @inbounds B[hits[pid, tid]] = A[i] |
| 64 | + @inbounds hits[pid, tid] -= 1 |
| 65 | + end |
| 66 | + |
| 67 | + # input is partitioned |
| 68 | + # shuffle partitions in parallel |
| 69 | + @threads :static for pid in 1:nparts |
| 70 | + @inbounds local chunk = view(B, hits[pid, 1] + 1:hits[pid + 1, 1]) |
| 71 | + shuffle!(r, chunk) |
| 72 | + end |
| 73 | + B |
| 74 | +end |
| 75 | + |
| 76 | +""" |
| 77 | + ppshuffle!([rng::TaskLocalRNG=default_rng(),] B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}, mask<:Union{UInt8, UInt16}) |
| 78 | +
|
| 79 | +A multi-threaded implementation of [`shuffle!`](@ref). |
| 80 | +Construct in `B` a permuted copy of `A`. |
| 81 | +Optional arg `rng` specifies a random number generator (see [`TaskLocalRNG`](@ref)). |
| 82 | +
|
| 83 | +# Examples |
| 84 | +```jldoctest |
| 85 | +julia> b = Vector{Int}(undef, 16); |
| 86 | +
|
| 87 | +julia> ppshuffle!(b, 1:16); |
| 88 | +
|
| 89 | +julia> isperm(b) |
| 90 | +true |
| 91 | +``` |
| 92 | +""" |
| 93 | +function ppshuffle!(r::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer} |
| 94 | + nparts = max(2, (length(A) * sizeof(T)) >> 21) |
| 95 | + nparts = nextpow(2, nparts) |
| 96 | + mask = nparts <= typemax(UInt8) + 1 ? UInt8(nparts - 1) : UInt16(nparts - 1) |
| 97 | + _ppshuffle!(r, B, A, mask) |
| 98 | +end |
| 99 | +ppshuffle!(B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer} = ppshuffle!(default_rng(), B, A) |
| 100 | + |
| 101 | + |
| 102 | +""" |
| 103 | + ppshuffle([rng=default_rng(),] A::AbstractArray) |
| 104 | +
|
| 105 | +A multi-threaded implementation of [`shuffle`](@ref). |
| 106 | +Expected to run noticeably faster for `A` large. |
| 107 | +
|
| 108 | +Return a randomly permuted copy of `A`. The optional `rng` argument specifies a random |
| 109 | +number generator (see [`TaskLocalRNG`](@ref)). |
| 110 | +To permute `A` in-place, see [`ppshuffle!`](@ref). To obtain randomly permuted |
| 111 | +indices, see [`pprandperm`](@ref). |
| 112 | +
|
| 113 | +# Examples |
| 114 | +```jldoctest |
| 115 | +julia> isperm(ppshuffle(Vector(1:16))) |
| 116 | +true |
| 117 | +``` |
| 118 | +""" |
| 119 | +ppshuffle(r::TaskLocalRNG, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer} = ppshuffle!(r, similar(A), A) |
| 120 | +ppshuffle(A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer} = ppshuffle(default_rng(), A) |
| 121 | + |
| 122 | + |
| 123 | +## pprandperm! & pprandperm |
| 124 | + |
| 125 | +""" |
| 126 | + pprandperm([rng::TaskLocalRNG=default_rng(),] n::{T<:Integer}) |
| 127 | +
|
| 128 | +A multi-threaded implementation of [`randperm`](@ref). |
| 129 | +Expected to run noticeably faster for `n` large. |
| 130 | +
|
| 131 | +Construct a random permutation of length `n`. The optional `rng` |
| 132 | +argument specifies a random number generator (see [`TaskLocalRNG`](@ref)). |
| 133 | +The element type of the result is the same as the type of `n`. |
| 134 | +
|
| 135 | +# Examples |
| 136 | +```jldoctest |
| 137 | +julia> isperm(pprandperm(1024)) |
| 138 | +true |
| 139 | +``` |
| 140 | +""" |
| 141 | +pprandperm(r::TaskLocalRNG, n::T) where {T<:Integer} = ppshuffle(r, Base.OneTo(n)) |
| 142 | +pprandperm(n::T) where {T<:Integer} = ppshuffle(Base.OneTo(n)) |
| 143 | + |
| 144 | +""" |
| 145 | + pprandperm!([rng=default_rng(),] A::Array{<:Integer}) |
| 146 | +
|
| 147 | +A multi-threaded implementation of [`randperm!`](@ref). |
| 148 | +Expected to run noticeably faster for `A` large. |
| 149 | +
|
| 150 | +Construct in `A` a random permutation of length `length(A)`. |
| 151 | +The optional `rng` argument specifies a random |
| 152 | +number generator (see [`TaskLocalRNG`](@ref)). |
| 153 | +To randomly permute an arbitrary vector, see |
| 154 | +[`ppshuffle`](@ref) or [`ppshuffle!`](@ref). |
| 155 | +
|
| 156 | +# Examples |
| 157 | +```jldoctest |
| 158 | +julia> A = Vector{Int}(undef, 1024); |
| 159 | +
|
| 160 | +julia> pprandperm!(A); |
| 161 | +
|
| 162 | +julia> isperm(A) |
| 163 | +true |
| 164 | +``` |
| 165 | +""" |
| 166 | +pprandperm!(r::TaskLocalRNG, A::AbstractArray{T}) where {T<:Integer} = ppshuffle!(r, A, Base.OneTo(length(A)%eltype(A))) |
| 167 | +pprandperm!(A::AbstractArray{T}) where {T<:Integer} = pprandperm!(default_rng(), A) |
0 commit comments