-
-
Notifications
You must be signed in to change notification settings - Fork 5.6k
adds the nth
function for iterables
#56580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
b565783
ed7a984
1bfc080
b3fca0b
26913e6
53b58f2
97f8104
1bd0cdb
7422eca
c49db01
4e49673
0082a01
6a6cfa1
111a7e7
eb605f2
320a7d4
8e76592
8be01c0
9afd84f
11ebb19
1bcaf4a
cff3b83
0b05ca0
3e522c7
5269053
5d26214
0cbf637
6f76d64
1d14912
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ Methods for working with Iterators. | |
baremodule Iterators | ||
|
||
# small dance to make this work from Base or Intrinsics | ||
import Base: @__MODULE__, parentmodule | ||
const Base = parentmodule(@__MODULE__) | ||
using .Base: | ||
@inline, Pair, Pairs, AbstractDict, IndexLinear, IndexStyle, AbstractVector, Vector, | ||
|
@@ -16,7 +15,7 @@ using .Base: | |
(:), |, +, -, *, !==, !, ==, !=, <=, <, >, >=, =>, missing, | ||
any, _counttuple, eachindex, ntuple, zero, prod, reduce, in, firstindex, lastindex, | ||
tail, fieldtypes, min, max, minimum, zero, oneunit, promote, promote_shape, LazyString, | ||
afoldl | ||
afoldl, mod1 | ||
using .Core | ||
using Core: @doc | ||
|
||
|
@@ -32,7 +31,7 @@ import Base: | |
getindex, setindex!, get, iterate, | ||
popfirst!, isdone, peek, intersect | ||
|
||
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, flatmap, partition | ||
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, flatmap, partition, nth | ||
public accumulate, filter, map, peel, reverse, Stateful | ||
|
||
""" | ||
|
@@ -1602,4 +1601,81 @@ end | |
# be the same as the keys, so this is a valid optimization (see #51631) | ||
pairs(s::AbstractString) = IterableStatePairs(s) | ||
|
||
""" | ||
nth(itr, n::Integer) | ||
|
||
Get the `n`th element of an iterable collection. Throw a `BoundsError`[@ref] if not existing. | ||
Will advance any `Stateful`[@ref] iterator. | ||
|
||
ghyatzo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
See also: [`first`](@ref), [`last`](@ref), [`nth`](@ref) | ||
ghyatzo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Examples | ||
```jldoctest | ||
julia> Iterators.nth(2:2:10, 4) | ||
8 | ||
|
||
julia> Iterators.nth(reshape(1:30, (5,6)), 6) | ||
6 | ||
|
||
julia> stateful = Iterators.Stateful(1:10); nth(stateful, 7) | ||
7 | ||
|
||
julia> first(stateful) | ||
8 | ||
``` | ||
""" | ||
nth(itr, n::Integer) = _nth(itr, n) | ||
|
||
# Count | ||
nth(itr::Count, n::Integer) = n > 0 ? itr.start + itr.step * (n - 1) : throw(ArgumentError("n must be positive.")) | ||
# Repeated | ||
nth(itr::Repeated, ::Integer) = itr.x | ||
# Take(Repeated) | ||
nth(itr::Take{<:Repeated}, n::Integer) = | ||
n > itr.n ? throw(BoundsError(itr, n)) : nth(itr.xs, n) | ||
|
||
# infinite cycle | ||
function nth(itr::Cycle{I}, n::Integer) where {I} | ||
if IteratorSize(I) isa Union{HasShape, HasLength} | ||
_nth(itr.xs, mod1(n, length(itr.xs))) | ||
else | ||
_nth(itr, n) | ||
end | ||
end | ||
|
||
# finite cycle: in reality a Flatten{Take{Repeated{O}}} iterator | ||
function nth(itr::Flatten{Take{Repeated{O}}}, n::Integer) where {O} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a clear need for all these specialisations? They seem quite exotic... do they show up? Are they really almost always faster? Are they in fact all well-tested? When I looked at the benchmarks earlier, the total time was driven by one case, so didn't answer these questions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be more concrete, can you please post benchmarks for each of the methods in the present PR, comparing to the most basic implementation? Ideally at varying lengths etc. to see if small/large is different.
That is also what the comment should say, not repeat the signature. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you refer to the julia> @benchmark nth($repeateditr, 9235)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 3.042 ns … 8.750 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 3.167 ns ┊ GC (median): 0.00%
Time (mean ± σ): 3.172 ns ± 0.097 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▅ █ ▅ ▃ ▁
▃▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▇ █
3.04 ns Histogram: log(frequency) by time 3.29 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark _nth($repeateditr, 9235)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 3.333 ns … 14.166 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 3.458 ns ┊ GC (median): 0.00%
Time (mean ± σ): 3.468 ns ± 0.331 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▂ █ █ ▃ ▂▃ ▂ ▁ ▁ ▂
▃▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁██▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁▇▁▁▁▁▇ █
3.33 ns Histogram: log(frequency) by time 3.75 ns <
Memory estimate: 0 bytes, allocs estimate: 0. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I refer to every single specialisation. Each one needs to be individually justified. Again, on long cases, and short cases, and weird types (cycle of a string? cycle of a tuple?). And ideally in some file others can easily try out on other machines. Those that survive also need to be individually tested. Again, with weird cases (cycle of an empty vector?) poking for weak points. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks this looks great, will read slowly. Haven't thought hard re errors, it does seem friendly to have something uniform & deliberate. But how much complexity it's worth I don't know. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cycles of strings might need an extra specialization since the issue is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if the performance could be matched by specializing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Upon further reflection, I don't think it can work. as an example, if we have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's been a while. nth_str(itr::Cycle{<:AbstractString}, n::Integer) =
(n < ncodeunits(itr.xs) / 2 || isempty(itr.xs)) ? _nth(itr, n) : _nth(itr.xs, mod1(n, length(itr.xs)))
function nth_str(itr::Flatten{Take{Repeated{<:AbstractString}}}, n::Integer)
cycles = itr.it.n
repeated = itr.it.xs.x
n < ncodeunits(repeated) / 2 || isempty(repeated) && return _nth(itr, n)
k = length(repeated)
n > k * cycles && first(())
_ntr(repeated, mod1(n, k))
end which only goes down the "optmized way" if we're at least a bit into the string, to make the expensive call to But turns out it's not really worth it in my opinion: Bench 1:2N#=
# Iterators.cycle("A"^N) # 1 code unit characters
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
518.2 ns / 303.2 ns 1.71x n/N = 1//1
[ Info: cycle opt / string cycle opt
303.2 ns / 301.0 ns 1.01x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
68500.0 ns / 21833.0 ns 3.14x n/N = 1//1
[ Info: cycle opt / string cycle opt
21833.0 ns / 21083.0 ns 1.04x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
7025709.0 ns / 1918292.0 ns 3.66x n/N = 1//1
[ Info: cycle opt / string cycle opt
1918292.0 ns / 1975584.0 ns 0.97x n/N = 1//1
#=
# Iterators.cycle("μ"^N) # 2 code units chars
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
622.8 ns / 622.6 ns 1.00x n/N = 1//1
[ Info: cycle opt / string cycle opt
622.6 ns / 495.3 ns 1.26x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
75375.0 ns / 51541.0 ns 1.46x n/N = 1//1
[ Info: cycle opt / string cycle opt
51541.0 ns / 44083.0 ns 1.17x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
7625583.0 ns / 4870583.0 ns 1.57x n/N = 1//1
[ Info: cycle opt / string cycle opt
4870583.0 ns / 4349917.0 ns 1.12x n/N = 1//1
#=
# Iterators.cycle(String(rand(UInt8, N))) # random code units
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
566.3 ns / 366.6 ns 1.54x n/N = 1//1
[ Info: cycle opt / string cycle opt
366.6 ns / 360.9 ns 1.02x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
70083.0 ns / 28708.0 ns 2.44x n/N = 1//1
[ Info: cycle opt / string cycle opt
28708.0 ns / 26666.0 ns 1.08x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:2N.
[ Info: baseline / cycle optimization
7351625.0 ns / 2519292.0 ns 2.92x n/N = 1//1
[ Info: cycle opt / string cycle opt
2519292.0 ns / 2466583.0 ns 1.02x n/N = 1//1 Bench 1:N#=
# Iterators.cycle("A"^N) # 1 code unit characters
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
124.9 ns / 149.1 ns 0.84x n/N = 1//1
[ Info: cycle opt / string cycle opt
149.1 ns / 138.5 ns 1.08x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
16208.0 ns / 10875.0 ns 1.49x n/N = 1//1
[ Info: cycle opt / string cycle opt
10875.0 ns / 10125.0 ns 1.07x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
1738417.0 ns / 959291.0 ns 1.81x n/N = 1//1
[ Info: cycle opt / string cycle opt
959291.0 ns / 1004709.0 ns 0.95x n/N = 1//1
#=
# Iterators.cycle("μ"^N) # 2 code units chars
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
160.2 ns / 307.8 ns 0.52x n/N = 1//1
[ Info: cycle opt / string cycle opt
307.8 ns / 178.5 ns 1.72x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
18000.0 ns / 25708.0 ns 0.70x n/N = 1//1
[ Info: cycle opt / string cycle opt
25708.0 ns / 18333.0 ns 1.40x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
1890125.0 ns / 2435291.0 ns 0.78x n/N = 1//1
[ Info: cycle opt / string cycle opt
2435291.0 ns / 1896791.0 ns 1.28x n/N = 1//1
#=
# Iterators.cycle(String(rand(UInt8, N))) # random code units
=#
[ Info: Cycle{String}
[ Info: N = 10 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
146.0 ns / 191.1 ns 0.76x n/N = 1//1
[ Info: cycle opt / string cycle opt
191.1 ns / 171.4 ns 1.12x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 100 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
17000.0 ns / 12250.0 ns 1.39x n/N = 1//1
[ Info: cycle opt / string cycle opt
12250.0 ns / 10959.0 ns 1.12x n/N = 1//1
[ Info: Cycle{String}
[ Info: N = 1000 - summed time of accessing every element from 1:N.
[ Info: baseline / cycle optimization
1819666.0 ns / 1256500.0 ns 1.45x n/N = 1//1
[ Info: cycle opt / string cycle opt
1256500.0 ns / 1191125.0 ns 1.05x n/N = 1//1 Codeimport Base: Iterators.Cycle, IteratorSize, HasShape, HasLength, Iterators.Flatten, Iterators.Repeated,
Iterators.Take, Iterators.drop
using Printf
@usingany BenchmarkTools
Ns = [10, 100, 1000]
timemult(b1,b2, N, n) = @sprintf("%.1f ns / %.1f ns \t %.2fx \t n/N = %s",
b1.time*1e9,
b2.time*1e9,
b1.time/b2.time,
string(n//N)
)
function compare_infinite(itr, N, maxn = N)
@info "N = $N - baseline: simple version, N length of inner iterable, n element accessed."
_b1 = @btimed _nth($itr, 1)
b1 = @btimed nth($itr, 1)
@assert b1.value == _b1.value
println(timemult(_b1,b1, N, 1) )
_b3 = @btimed _nth($itr, 3)
b3 = @btimed nth($itr, 3)
@assert b3.value == _b3.value
println(timemult(_b3,b3, N, 3) )
_bmid = @btimed _nth($itr, $maxn ÷ 2)
bmid = @btimed nth($itr, $maxn ÷ 2)
@assert bmid.value == _bmid.value
println(timemult(_bmid, bmid , N, maxn÷2))
_bend = @btimed _nth($itr, $maxn-1)
bend = @btimed nth($itr, $maxn-1)
@assert bend.value == _bend.value
println(timemult(_bend, bend , N, maxn-1))
_bshortafter = @btimed _nth($itr, floor(Int, 2*$maxn))
bshortafter = @btimed nth($itr, floor(Int, 2*$maxn))
@assert bshortafter.value == _bshortafter.value
println(timemult(_bshortafter, bshortafter , N, 2maxn))
_blongafter = @btimed _nth($itr, floor(Int, 5*$maxn))
blongafter = @btimed nth($itr, floor(Int, 5*$maxn))
@assert blongafter.value == _blongafter.value
println(timemult(_blongafter, blongafter, N, 5maxn))
end
function compare_finite(itr, N)
@info "N = $N - baseline: simple version, N length of inner iterable, n element accessed."
_b1 = @btimed _nth($itr, 1)
b1 = @btimed nth($itr, 1)
@assert b1.value == _b1.value
println(timemult(_b1,b1, N, 1) )
_b3 = @btimed _nth($itr, 3)
b3 = @btimed nth($itr, 3)
@assert b3.value == _b3.value
println(timemult(_b3,b3, N, 3) )
_bmid = @btimed _nth($itr, $N ÷ 2)
bmid = @btimed nth($itr, $N ÷ 2)
@assert bmid.value == _bmid.value
println(timemult(_bmid, bmid , N, N÷2))
_bend = @btimed _nth($itr, $N-1)
bend = @btimed nth($itr, $N-1)
@assert bend.value == _bend.value
println(timemult(_bend, bend , N, N-1))
end
function compare_all_string(itr, N)
@info "N = $N - summed time of accessing every element from 1:N."
_b = @btimed for i in 1:($N)
_nth($itr, i)
end
b = @btimed for i in 1:($N)
nth($itr, i)
end
b_str = @btimed for i in 1:($N)
nth_str($itr, i)
end
@info "baseline / cycle optimization"
println(timemult(_b, b , N, N))
@info "cycle opt / string cycle opt"
println(timemult(b, b_str, N, N))
end
inf_cycle_iters = [
# N -> Iterators.cycle(collect(1:N)),
N -> Iterators.cycle("A"^N),
N -> Iterators.cycle("μ"^N),
N -> Iterators.cycle(String(rand(UInt8, N)))
# N -> Iterators.cycle(ntuple(N) do i; rand() > 0.5 ? "AAAAA" : 12345 end),
# N -> Iterators.cycle(ntuple(N) do i; i end)
]
for itr_ in inf_cycle_iters, N in Ns
itr = itr_(N)
println()
@info typeof(itr)
compare_all_string(itr, N)
end
for itr_ in inf_cycle_iters, N in Ns
itr = itr_(N)
println()
@info typeof(itr)
compare_infinite(itr, N, N)
end
inf_iters = [
Iterators.repeated(1),
Iterators.countfrom(10, 3),
]
for itr in inf_iters, N in Ns
println()
@info typeof(itr)
compare_infinite(itr, 1, N)
end
cycle_iters = [
N -> Iterators.cycle(collect(1:10), N ÷ 10),
N -> Iterators.cycle("A"^10, N ÷ 10),
# N -> Iterators.cycle(ntuple(N) do i; rand() > 0.5 ? "AAAAA" : 12345 end),
N -> Iterators.cycle(ntuple(10) do i; i end, N÷10)
]
for itr_ in cycle_iters, N in Ns
itr = itr_(N)
println()
@info typeof(itr)
compare_finite(itr, N)
end tldr: the cycle optimization version loses a bit only on small strings if we don't cycle through the whole string at least once (and a bit more). But I would argue that if you expect to only index the first half of a string over and over again, then an |
||
if IteratorSize(O) isa Union{HasShape, HasLength} | ||
cycles = itr.it.n | ||
repeated = itr.it.xs.x | ||
k = length(repeated) | ||
n > k*cycles ? throw(BoundsError(itr, n)) : _nth(repeated, mod1(n, k)) | ||
else | ||
_nth(itr, n) | ||
end | ||
end | ||
|
||
_nth(itr, n) = first(drop(itr, n-1)) | ||
_nth(itr::AbstractArray, n) = itr[begin + n-1] | ||
""" | ||
nth(n::Integer) | ||
|
||
Return a function that gets the `n`-th element from any iterator passed to it. | ||
Equivalent to `Base.Fix2(nth, n)` or `itr -> nth(itr, n)`. | ||
|
||
See also: [`nth`](@ref), [`Base.Fix2`](@ref) | ||
# Examples | ||
```jldoctest | ||
julia> fifth_element = Iterators.nth(5) | ||
(::Base.Fix2{typeof(Base.Iterators.nth), Int64}) (generic function with 2 methods) | ||
|
||
julia> fifth_element(reshape(1:30, (5,6))) | ||
5 | ||
|
||
julia> map(fifth_element, ("Willis", "Jovovich", "Oldman")) | ||
('i', 'v', 'a') | ||
``` | ||
""" | ||
nth(n::Integer) = Base.Fix2(nth, n) | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.