From 5b7cad567bfe26803744071a035640ed321d884b Mon Sep 17 00:00:00 2001 From: "H.A." Date: Tue, 1 Jul 2025 14:19:41 +0200 Subject: [PATCH 1/2] Defining zero(T) as the neutral element for xor. --- src/host/mapreduce.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index ef9f3a31..8692c68a 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -15,6 +15,7 @@ neutral_element(op, T) = Please pass it as an explicit argument to `GPUArrays.mapreducedim!`, or register it globally by defining `GPUArrays.neutral_element(::typeof($op), T)`.""") neutral_element(::typeof(Base.:(|)), T) = zero(T) +neutral_element(::typeof(Base.:(⊻)), T) = zero(T) neutral_element(::typeof(Base.:(+)), T) = zero(T) neutral_element(::typeof(Base.add_sum), T) = zero(T) neutral_element(::typeof(Base.:(&)), T) = one(T) From e37cd24dc99826612e68592eb8fc5a09839d18a7 Mon Sep 17 00:00:00 2001 From: "ha.git" Date: Tue, 1 Jul 2025 15:33:57 +0200 Subject: [PATCH 2/2] Modify neutral_element for & to ~zero(T). Included logical operators in the test suite. --- src/host/mapreduce.jl | 2 +- test/testsuite/reductions.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index 8692c68a..4a128313 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -16,9 +16,9 @@ neutral_element(op, T) = or register it globally by defining `GPUArrays.neutral_element(::typeof($op), T)`.""") neutral_element(::typeof(Base.:(|)), T) = zero(T) neutral_element(::typeof(Base.:(⊻)), T) = zero(T) +neutral_element(::typeof(Base.:(&)), T) = ~zero(T) neutral_element(::typeof(Base.:(+)), T) = zero(T) neutral_element(::typeof(Base.add_sum), T) = zero(T) -neutral_element(::typeof(Base.:(&)), T) = one(T) neutral_element(::typeof(Base.:(*)), T) = one(T) neutral_element(::typeof(Base.mul_prod), T) = one(T) neutral_element(::typeof(Base.min), T) = typemax(T) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index d575edfe..a1bdc83e 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -66,6 +66,11 @@ end (0,)=>[1]] @test compare(A->reduce(+, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) @test compare(A->reduce(*, A; dims=dims, init=one(ET)), AT, rand(range, sz)) + if ET <: Integer + @test compare(A->reduce(&, A; dims=dims, init=~zero(ET)), AT, rand(range, sz)) + @test compare(A->reduce(|, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) + @test compare(A->reduce(⊻, A; dims=dims, init=zero(ET)), AT, rand(range, sz)) + end end end end