Skip to content

Commit d43195f

Browse files
committed
parallel partitioned shuffle
add ppshuffle, pprandperm to stdlib.Random (ppmisc.jl)
1 parent c239e99 commit d43195f

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed

stdlib/Random/src/Random.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ include("RNGs.jl")
303303
include("generation.jl")
304304
include("normal.jl")
305305
include("misc.jl")
306+
include("ppmisc.jl")
306307
include("XoshiroSimd.jl")
307308

308309
## rand & rand! & seed! docstrings

stdlib/Random/src/ppmisc.jl

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
#
3+
# Parallel Partitioned Shuffle
4+
#
5+
using Base.Threads
6+
export ppshuffle!, ppshuffle, pprandperm!, pprandperm
7+
8+
## ppshuffle! & ppshuffle
9+
"""
    _ppshuffle!(rng::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}, mask::Union{UInt8, UInt16})

Parallel Partitioned Shuffle: construct in `B` a randomly permuted copy of `A`
and return `B`.

1. partition the input randomly (a counting pass followed by a scatter pass)
2. shuffle the partitions concurrently

Arg `mask` determines the number of partitions (`mask + 1`) to be used;
`mask + 1` must be a power of two.

`A` and `B` must have the same length and use one-based indexing. They must not
share memory, because `B` is written while `A` is still being read.

Throws an `ArgumentError` if `mask + 1` is not a power of two or if an argument
is not one-based, and a `DimensionMismatch` if the lengths of `A` and `B` differ.
"""
function _ppshuffle!(r::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}, mask::Tu) where {T, Tu<:Union{UInt8, UInt16}}
    # number of partitions; `mask + 1` is promoted to a machine integer, so it cannot wrap
    nparts = mask + 1
    # validate with real exceptions: `@assert` may be compiled out, and the loops
    # below rely on these invariants for their `@inbounds` accesses
    ispow2(nparts) || throw(ArgumentError("invalid mask $(mask)"))
    length(A) == length(B) || throw(DimensionMismatch(
        "length of B ($(length(B))) must match length of A ($(length(A)))"))
    Base.require_one_based_indexing(A, B)

    n = length(A)
    s = Random.SamplerType{Tu}()

    # The input is split into `nchunks` contiguous chunks, one per loop iteration.
    # Every chunk owns a generator seeded from `r`; the counting pass consumes a
    # copy of it and the scatter pass consumes the generator itself, so both
    # passes see exactly the same partition assignment. Indexing by chunk number
    # (rather than by `threadid()`) keeps all accesses within `1:nchunks`
    # regardless of how threads are numbered or how tasks are scheduled.
    nchunks = nthreads()
    rngs = [Xoshiro(rand(r, UInt128)) for _ in 1:nchunks]

    # an array to count partition hits by chunks
    #   partitions map to rows (pid)
    #   chunks map to cols (c)
    # extra rows pad every column (intended to keep the counters of different
    # chunks apart in memory); row `nparts + 1` of column 1 also stores the end
    # marker of the last partition
    nrows = nparts < 8 ? 8 : nparts + 8
    hits = zeros(Int, nrows, nchunks)

    # 1st pass
    # assign input to partitions uniformly at random
    # count cells hit by each chunk in every partition
    @threads for c in 1:nchunks
        rng = copy(rngs[c])
        for _ in (div((c - 1) * n, nchunks) + 1):div(c * n, nchunks)
            pid = rand(rng, s) & mask + 1
            @inbounds hits[pid, c] += 1
        end
    end

    # cumsum partition hits
    # to mark the (inclusive) end of the space reserved by each chunk in every partition
    prev = 0
    for pid in 1:nparts, c in 1:nchunks
        @inbounds prev = hits[pid, c] += prev
    end
    # mark the end of the last partition
    hits[nparts + 1, 1] = n

    # 2nd pass
    # scatter input across partitions, replaying the assignment of the 1st pass
    # each reserved range is filled from its end towards its start, so afterwards
    # `hits[pid, 1]` holds the index just before the first cell of partition `pid`
    @threads for c in 1:nchunks
        rng = rngs[c]
        for i in (div((c - 1) * n, nchunks) + 1):div(c * n, nchunks)
            pid = rand(rng, s) & mask + 1
            @inbounds B[hits[pid, c]] = A[i]
            @inbounds hits[pid, c] -= 1
        end
    end

    # input is partitioned
    # shuffle partitions in parallel; `r` resolves to the RNG of the running task
    @threads for pid in 1:nparts
        chunk = view(B, (hits[pid, 1] + 1):hits[pid + 1, 1])
        shuffle!(r, chunk)
    end
    return B
end
75+
"""
    ppshuffle!([rng::TaskLocalRNG=default_rng(),] B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer}

A multi-threaded implementation of [`shuffle!`](@ref).
Construct in `B` a permuted copy of `A` and return `B`.
Optional arg `rng` specifies a random number generator (see [`TaskLocalRNG`](@ref)).

The number of partitions used internally is derived from the size of `A` in
bytes: one partition per 2 MiB, rounded up to a power of two, with at least 2
and at most 65536 partitions.

# Examples
```jldoctest
julia> b = Vector{Int}(undef, 16);

julia> ppshuffle!(b, 1:16);

julia> isperm(b)
true
```
"""
function ppshuffle!(r::TaskLocalRNG, B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer}
    # the partition mask is at most a UInt16, so never ask for more than 2^16
    # partitions; without the upper bound `UInt16(nparts - 1)` would throw an
    # InexactError for very large inputs
    maxparts = Int(typemax(UInt16)) + 1
    nparts = clamp((length(A) * sizeof(T)) >> 21, 2, maxparts)
    nparts = nextpow(2, nparts)
    # use the narrowest mask type that can represent `nparts - 1`
    mask = nparts <= Int(typemax(UInt8)) + 1 ? UInt8(nparts - 1) : UInt16(nparts - 1)
    return _ppshuffle!(r, B, A, mask)
end
ppshuffle!(B::AbstractArray{T}, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer} = ppshuffle!(default_rng(), B, A)
100+
101+
"""
    ppshuffle([rng::TaskLocalRNG=default_rng(),] A::AbstractArray{<:Integer})

A multi-threaded implementation of [`shuffle`](@ref).
Expected to run noticeably faster for large `A`.

Return a randomly permuted copy of `A`; `A` itself is left untouched. The result
is allocated with `similar(A)`. The optional `rng` argument specifies a random
number generator (see [`TaskLocalRNG`](@ref)).
To write the permuted copy into an existing array, see [`ppshuffle!`](@ref).
To obtain randomly permuted indices, see [`pprandperm`](@ref).

# Examples
```jldoctest
julia> isperm(ppshuffle(Vector(1:16)))
true
```
"""
function ppshuffle(r::TaskLocalRNG, A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer}
    # destination with the same shape and element type as the input
    out = similar(A)
    return ppshuffle!(r, out, A)
end
function ppshuffle(A::Union{AbstractArray{T}, Base.OneTo{T}}) where {T<:Integer}
    return ppshuffle(default_rng(), A)
end
121+
122+
123+
## pprandperm! & pprandperm
124+
"""
    pprandperm([rng::TaskLocalRNG=default_rng(),] n::Integer)

A multi-threaded implementation of [`randperm`](@ref).
Expected to run noticeably faster for large `n`.

Construct a random permutation of length `n` by shuffling `Base.OneTo(n)` with
[`ppshuffle`](@ref). The optional `rng` argument specifies a random number
generator (see [`TaskLocalRNG`](@ref)).
The element type of the result is the same as the type of `n`.

# Examples
```jldoctest
julia> isperm(pprandperm(1024))
true
```
"""
function pprandperm(r::TaskLocalRNG, n::T) where {T<:Integer}
    indices = Base.OneTo(n)
    return ppshuffle(r, indices)
end
function pprandperm(n::T) where {T<:Integer}
    indices = Base.OneTo(n)
    return ppshuffle(indices)
end
143+
"""
    pprandperm!([rng::TaskLocalRNG=default_rng(),] A::AbstractArray{<:Integer})

A multi-threaded implementation of [`randperm!`](@ref).
Expected to run noticeably faster for large `A`.

Construct in `A` a random permutation of length `length(A)` and return `A`.
The optional `rng` argument specifies a random
number generator (see [`TaskLocalRNG`](@ref)).
To randomly permute an arbitrary vector, see
[`ppshuffle`](@ref) or [`ppshuffle!`](@ref).

Throws an `InexactError` if `length(A)` cannot be represented in `eltype(A)`.

# Examples
```jldoctest
julia> A = Vector{Int}(undef, 1024);

julia> pprandperm!(A);

julia> isperm(A)
true
```
"""
function pprandperm!(r::TaskLocalRNG, A::AbstractArray{T}) where {T<:Integer}
    # checked conversion: a wrapping `length(A) % T` would silently build a
    # shorter range when the length does not fit in the element type
    n = convert(T, length(A))
    return ppshuffle!(r, A, Base.OneTo(n))
end
pprandperm!(A::AbstractArray{T}) where {T<:Integer} = pprandperm!(default_rng(), A)

0 commit comments

Comments
 (0)