Skip to content

Commit 54e12e5

Browse files
authored
optimize reductions for OffsetRanges (#202)
* optimize reductions for OffsetUnitRanges * propagate sum to parent
1 parent 7a1e2b9 commit 54e12e5

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

src/OffsetArrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ end
323323
A
324324
end
325325

326+
Base.in(x, A::OffsetArray) = in(x, parent(A))
327+
326328
Base.strides(A::OffsetArray) = strides(parent(A))
327329
Base.elsize(::Type{OffsetArray{T,N,A}}) where {T,N,A} = Base.elsize(A)
328330
@inline Base.unsafe_convert(::Type{Ptr{T}}, A::OffsetArray{T}) where {T} = Base.unsafe_convert(Ptr{T}, parent(A))
@@ -334,6 +336,7 @@ Broadcast.broadcast_unalias(dest::OffsetArray, src::OffsetArray) = parent(dest)
334336
### Special handling for AbstractRange
335337

336338
const OffsetRange{T} = OffsetArray{T,1,<:AbstractRange{T}}
339+
const OffsetUnitRange{T} = OffsetArray{T,1,<:AbstractUnitRange{T}}
337340
const IIUR = IdentityUnitRange{S} where S<:AbstractUnitRange{T} where T<:Integer
338341

339342
Base.step(a::OffsetRange) = step(parent(a))
@@ -368,6 +371,19 @@ end
368371
# This is technically breaking, so it might be incorporated in the next major release
369372
# Base.getindex(a::OffsetRange, ::Colon) = OffsetArray(a.parent[:], a.offsets)
370373

374+
# mapreduce is faster with an IdOffsetRange than with an OffsetUnitRange
375+
# We therefore convert OffsetUnitRanges to IdOffsetRanges with the same values and axes
376+
function Base.mapreduce(f, op, As::OffsetUnitRange...; kw...)
377+
ofs = map(A -> first(axes(A,1)) - 1, As)
378+
AIds = map((A, of) -> IdOffsetRange(UnitRange(parent(A) .- of), of), As, ofs)
379+
mapreduce(f, op, AIds...; kw...)
380+
end
381+
382+
# Optimize certain reductions that treat an OffsetVector as a list
383+
for f in [:minimum, :maximum, :extrema, :sum]
384+
@eval Base.$f(r::OffsetRange) = $f(parent(r))
385+
end
386+
371387
function Base.show(io::IO, r::OffsetRange)
372388
show(io, r.parent)
373389
print(io, " with indices ", UnitRange(axes(r, 1)))

test/runtests.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ Base.convert(::Type{Int}, a::WeirdInteger) = a
374374
@test_throws OverflowError OffsetArray{Float64, 1, typeof(ao)}(ao, (-2, )) # inner Constructor
375375
@test_throws OverflowError OffsetArray(ao, (-2, )) # convinient constructor accumulate offsets
376376

377+
@testset "OffsetRange" begin
378+
local r = 1:100
379+
local a = OffsetVector(r, 4)
380+
@test first(r) in a
381+
@test !(last(r) + 1 in a)
382+
end
383+
377384
# disallow OffsetVector(::Array{<:Any, N}, offsets) where N != 1
378385
@test_throws ArgumentError OffsetVector(zeros(2,2), (2, 2))
379386
@test_throws ArgumentError OffsetVector(zeros(2,2), 2, 2)
@@ -1512,6 +1519,44 @@ end
15121519

15131520
amin, amax = extrema(parent(A))
15141521
@test clamp.(A, (amax+amin)/2, amax) == OffsetArray(clamp.(parent(A), (amax+amin)/2, amax), axes(A))
1522+
1523+
@testset "mapreduce for OffsetRange" begin
1524+
for r in Any[5:100, IdOffsetRange(1:100, 4), IdOffsetRange(4:5), # AbstractUnitRanges
1525+
2:4:14, 1.5:1.0:10.5, # AbstractRanges
1526+
]
1527+
1528+
a = OffsetVector(r, 2);
1529+
@test mapreduce(identity, +, a) == mapreduce(identity, +, r)
1530+
@test mapreduce(x -> x^2, (x,y) -> x, a) == mapreduce(x -> x^2, (x,y) -> x, r)
1531+
1532+
b = mapreduce(identity, +, a, dims = 1)
1533+
br = mapreduce(identity, +, r, dims = 1)
1534+
@test no_offset_view(b) == no_offset_view(br)
1535+
@test axes(b, 1) == first(axes(a,1)):first(axes(a,1))
1536+
1537+
@test mapreduce(identity, +, a, init = 3) == mapreduce(identity, +, r, init = 3)
1538+
if VERSION >= v"1.2"
1539+
@test mapreduce((x,y) -> x*y, +, a, a) == mapreduce((x,y) -> x*y, +, r, r)
1540+
@test mapreduce((x,y) -> x*y, +, a, a, init = 10) == mapreduce((x,y) -> x*y, +, r, r, init = 10)
1541+
end
1542+
1543+
for f in [sum, minimum, maximum]
1544+
@test f(a) == f(r)
1545+
1546+
b = f(a, dims = 1);
1547+
br = f(r, dims = 1)
1548+
@test no_offset_view(b) == no_offset_view(br)
1549+
@test axes(b, 1) == first(axes(a,1)):first(axes(a,1))
1550+
1551+
b = f(a, dims = 2);
1552+
br = f(r, dims = 2)
1553+
@test no_offset_view(b) == no_offset_view(br)
1554+
@test axes(b, 1) == axes(a,1)
1555+
end
1556+
1557+
@test extrema(a) == extrema(r)
1558+
end
1559+
end
15151560
end
15161561

15171562
# v = OffsetArray([1,1e100,1,-1e100], (-3,))*1000

0 commit comments

Comments
 (0)