diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index ef9f3a31..4a128313 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -15,9 +15,10 @@ neutral_element(op, T) = Please pass it as an explicit argument to `GPUArrays.mapreducedim!`, or register it globally by defining `GPUArrays.neutral_element(::typeof($op), T)`.""") neutral_element(::typeof(Base.:(|)), T) = zero(T) +neutral_element(::typeof(Base.:(⊻)), T) = zero(T) +neutral_element(::typeof(Base.:(&)), T) = ~zero(T) neutral_element(::typeof(Base.:(+)), T) = zero(T) neutral_element(::typeof(Base.add_sum), T) = zero(T) -neutral_element(::typeof(Base.:(&)), T) = one(T) neutral_element(::typeof(Base.:(*)), T) = one(T) neutral_element(::typeof(Base.mul_prod), T) = one(T) neutral_element(::typeof(Base.min), T) = typemax(T) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index d575edfe..a1bdc83e 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -66,6 +66,11 @@ end (0,)=>[1]] @test compare(A->reduce(+, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) @test compare(A->reduce(*, A; dims=dims, init=one(ET)), AT, rand(range, sz)) + if ET <: Integer + @test compare(A->reduce(&, A; dims=dims, init=~zero(ET)), AT, rand(range, sz)) + @test compare(A->reduce(|, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) + @test compare(A->reduce(⊻, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) + end end end end