Skip to content

WIP: a more principled take on dimensional reduction inits #55318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b9bdd63
a more principled take on dimensional reduction inits
mbauman Jul 31, 2024
2b000b1
add support for broadcasteds
mbauman Jul 31, 2024
07c1da1
add tests for complex reductions
mbauman Jul 31, 2024
2179df3
add tests for bitwise operators on integers
mbauman Jul 31, 2024
344ed43
fixup! add tests for bitwise operators on integers
mbauman Jul 31, 2024
8eb7d91
restore performance for linear fast dim=1+
mbauman Jul 31, 2024
11e98ed
fixup! add support for broadcasteds
mbauman Jul 31, 2024
c32c6ed
add more tests
mcabbott Jul 31, 2024
58d2b27
perf: add loop switching
mbauman Jul 31, 2024
2d595bb
perf: slice up the slow part
mbauman Jul 31, 2024
b7adc89
use builtin incremental widening slowly when needed
mbauman Aug 1, 2024
e59a4d8
fixup! use builtin incremental widening slowly when needed
mbauman Aug 2, 2024
6af4850
add tests for 44906
mbauman Aug 2, 2024
d31cba6
also revamp findmin/findmax
mbauman Aug 4, 2024
6025380
minor cleanups, slightly better naming and hoist some merges
mbauman Aug 6, 2024
3e9e9d0
fixup findmin/max/etc and add support for custom allocators
mbauman Aug 27, 2024
6bfd6c1
more tests
mbauman Aug 29, 2024
1635026
better comment and slight cleanups
mbauman Aug 29, 2024
64a27ea
Merge branch 'master' into mapreducedim-init-adienes-testing
adienes Dec 16, 2024
ce4c0cd
Merge remote-tracking branch 'origin' into mapreducedim-init-adienes-…
adienes Dec 21, 2024
48b4ad0
couple tuneups to dimensional mapreduce refactor
adienes Dec 21, 2024
aacdcee
Merge branch 'master' into mb/mapreducedim-init
mbauman May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ function _collect(::Type{T}, itr, isz::SizeUnknown) where T
end

# make a collection similar to `c` and appropriate for collecting `itr`
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T, shp)

_similar_shape(itr, ::SizeUnknown) = nothing
_similar_shape(itr, ::HasLength) = length(itr)::Integer
Expand Down Expand Up @@ -731,6 +731,12 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)

collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

struct _Allocator{T}
f::T
end
similar(f::_Allocator, ::Type{T}, dims) where {T} = f.f(T, dims)
collect_allocator(f::F, itr) where {F} = _collect(_Allocator(f), itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)

Expand Down
3 changes: 3 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ Base.similar(::Broadcasted{ArrayConflict}, ::Type{ElType}, dims) where ElType =
similar(Array{ElType, length(dims)}, dims)
Base.similar(::Broadcasted{ArrayConflict}, ::Type{Bool}, dims) =
similar(BitArray, dims)
# As well as the default behavior
Base.similar(::Broadcasted, ::Type{ElType}, dims) where ElType =
similar(Array{ElType, length(dims)}, dims)

@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ end

# BEGIN 1.12 deprecations

function reducedim_initarray end
const _dep_message_reducedim_initarray = ", these internals have been removed. To customize the array returned by dimensional reductions, implement mapreduce_similar instead"
deprecate(Base, :reducedim_initarray)

@deprecate isbindingresolved(m::Module, var::Symbol) true false

"""
Expand Down
14 changes: 5 additions & 9 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,19 +401,15 @@ minimum_fast(a; kw...) = Base.reduce(min_fast, a; kw...)
maximum_fast(f, a; kw...) = Base.mapreduce(f, max_fast, a; kw...)
minimum_fast(f, a; kw...) = Base.mapreduce(f, min_fast, a; kw...)

Base.reducedim_init(f, ::typeof(max_fast), A::AbstractArray, region) =
Base.reducedim_init(f, max, A::AbstractArray, region)
Base.reducedim_init(f, ::typeof(min_fast), A::AbstractArray, region) =
Base.reducedim_init(f, min, A::AbstractArray, region)

maximum!_fast(r::AbstractArray, A::AbstractArray; kw...) =
maximum!_fast(identity, r, A; kw...)
minimum!_fast(r::AbstractArray, A::AbstractArray; kw...) =
minimum!_fast(identity, r, A; kw...)

maximum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
Base.mapreducedim!(f, max_fast, Base.initarray!(r, f, max, init, A), A)
minimum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
Base.mapreducedim!(f, min_fast, Base.initarray!(r, f, min, init, A), A)
maximum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = init ?
Base.mapreduce!(f, max_fast, r, A) : Base.mapreducedim!(f, max_fast, r, A)

minimum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = init ?
Base.mapreduce!(f, min_fast, r, A) : Base.mapreducedim!(f, min_fast, r, A)

end
8 changes: 5 additions & 3 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted,
ifirst::Integer, ilast::Integer, blksize::Int)
if ifirst == ilast
@inbounds a1 = A[ifirst]
return mapreduce_first(f, op, a1)
throw(AssertionError("mapreduce_impl must not be called with only one element"))
elseif ilast - ifirst < blksize
# sequential portion
@inbounds a1 = A[ifirst]
Expand Down Expand Up @@ -356,6 +355,7 @@ reduce_empty(::typeof(mul_prod), ::Type{T}) where {T<:BitUnsignedSmall} = one(UI

reduce_empty(op::BottomRF, ::Type{T}) where {T} = reduce_empty(op.rf, T)
reduce_empty(op::MappingRF, ::Type{T}) where {T} = mapreduce_empty(op.f, op.rf, T)
reduce_empty(op::MappingRF{<:Any,<:BottomRF}, ::Type{T}) where {T} = mapreduce_empty(op.f, op.rf.rf, T)
reduce_empty(op::FilteringRF, ::Type{T}) where {T} = reduce_empty(op.rf, T)
reduce_empty(op::FlipArgs, ::Type{T}) where {T} = reduce_empty(op.f, T)

Expand Down Expand Up @@ -1050,7 +1050,9 @@ argmin(itr) = findmin(itr)[2]

## count

_bool(f) = x->f(x)::Bool
_assert_bool(x) = x::Bool
_bool(f) = _assert_bool ∘ f
mapreduce_empty(::ComposedFunction{typeof(_assert_bool), <:Any}, ::typeof(add_sum), itr) = false+false

"""
count([f=identity,] itr; init=0)::Integer
Expand Down
Loading
Loading