Skip to content

Commit a7f8825

Browse files
petvanaKristofferC
authored andcommitted
Fix sum() and prod() for tuples (#41510)
This PR aims to fix #39182 and #39183 by using the universal implementation of `prod` and `sum` from https://github.com/JuliaLang/julia/blob/97f817a379b0c3c5f9bb803427fe88a018ebfe18/base/reduce.jl#L588 However, the file `abstractarray.jl` is included way sooner, and it is crucial to have already a simplified version of `prod` function. We can specify a simplified version or `prod` only for a system-wide `Int` type that is sufficient to compile `Base`. ``` julia prod(x::Tuple{}) = 1 # This is consistent with the regular prod because there is no need for size promotion # if all elements in the tuple are of system size. prod(x::Tuple{Int, Vararg{Int}}) = *(x...) ``` Although the implementations are different, they lead to the same binary code for tuples containing ~~`UInt` and~~ `Int`. ``` julia julia> a = (1,2,3) (1, 2, 3) # Simplified version for tuples containing Int only julia> prod_simplified(x::Tuple{Int, Vararg{Int}}) = *(x...) julia> @code_native prod_simplified(a) .text ; ┌ @ REPL[1]:1 within `prod_simplified' ; │┌ @ operators.jl:560 within `*' @ int.jl:88 movq 8(%rdi), %rax imulq (%rdi), %rax imulq 16(%rdi), %rax ; │└ retq nop ; └ ``` ``` julia # Regular prod without the simplification julia> @code_native prod(a) .text ; ┌ @ reduce.jl:588 within `prod` ; │┌ @ reduce.jl:588 within `#prod#247` ; ││┌ @ reduce.jl:289 within `mapreduce` ; │││┌ @ reduce.jl:289 within `#mapreduce#240` ; ││││┌ @ reduce.jl:162 within `mapfoldl` ; │││││┌ @ reduce.jl:162 within `#mapfoldl#236` ; ││││││┌ @ reduce.jl:44 within `mapfoldl_impl` ; │││││││┌ @ reduce.jl:48 within `foldl_impl` ; ││││││││┌ @ tuple.jl:276 within `_foldl_impl` ; │││││││││┌ @ operators.jl:613 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` movq 8(%rdi), %rax imulq (%rdi), %rax ; │││││││││└└└└ ; │││││││││┌ @ operators.jl:614 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` imulq 16(%rdi), %rax ; │└└└└└└└└└└└└ retq nop ; └ ``` (cherry picked from commit bada80c)
1 parent 7cdfa10 commit a7f8825

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

base/tuple.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -467,17 +467,12 @@ reverse(t::Tuple) = revargs(t...)
467467

468468
## specialized reduction ##
469469

470-
# TODO: these definitions cannot yet be combined, since +(x...)
471-
# where x might be any tuple matches too many methods.
472-
# TODO: this is inconsistent with the regular sum in cases where the arguments
473-
# require size promotion to system size.
474-
sum(x::Tuple{Any, Vararg{Any}}) = +(x...)
475-
476-
# NOTE: should remove, but often used on array sizes
477-
# TODO: this is inconsistent with the regular prod in cases where the arguments
478-
# require size promotion to system size.
479470
prod(x::Tuple{}) = 1
480-
prod(x::Tuple{Any, Vararg{Any}}) = *(x...)
471+
# This is consistent with the regular prod because there is no need for size promotion
472+
# if all elements in the tuple are of system size.
473+
# It is defined here separately in order to support bootstrap, because it's needed earlier
474+
# than the general prod definition is available.
475+
prod(x::Tuple{Int, Vararg{Int}}) = *(x...)
481476

482477
all(x::Tuple{}) = true
483478
all(x::Tuple{Bool}) = x[1]

test/tuple.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,24 @@ end
347347
@test prod(()) === 1
348348
@test prod((1,2,3)) === 6
349349

350+
# issue 39182
351+
@test sum((0xe1, 0x1f)) === sum([0xe1, 0x1f])
352+
@test sum((Int8(3),)) === Int(3)
353+
@test sum((UInt8(3),)) === UInt(3)
354+
@test sum((3,)) === Int(3)
355+
@test sum((3.0,)) === 3.0
356+
@test sum(("a",)) == sum(["a"])
357+
@test sum((0xe1, 0x1f), init=0x0) == sum([0xe1, 0x1f], init=0x0)
358+
359+
# issue 39183
360+
@test prod((Int8(100), Int8(100))) === 10000
361+
@test prod((Int8(3),)) === Int(3)
362+
@test prod((UInt8(3),)) === UInt(3)
363+
@test prod((3,)) === Int(3)
364+
@test prod((3.0,)) === 3.0
365+
@test prod(("a",)) == prod(["a"])
366+
@test prod((0xe1, 0x1f), init=0x1) == prod([0xe1, 0x1f], init=0x1)
367+
350368
@testset "all" begin
351369
@test all(()) === true
352370
@test all((false,)) === false

0 commit comments

Comments
 (0)