From faf35f80f567ed82c94a1d7baa9d6812e088e96d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 23 Jul 2024 09:40:55 -0400 Subject: [PATCH 1/2] Make Zero be a subtype of Number --- src/rewrite.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rewrite.jl b/src/rewrite.jl index 811570f..8590ebe 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -28,7 +28,7 @@ macro rewrite(args...) return rewrite_and_return(args[1]; move_factors_into_sums = args[2].args[2]) end -struct Zero end +struct Zero <: Number end # This method is called in various `promote_operation_fallback` methods if one # of the arguments is `::Zero`. @@ -62,7 +62,9 @@ Base.:*(z::Zero, ::Any) = z Base.:*(::Any, z::Zero) = z Base.:*(z::Zero, ::Zero) = z Base.:+(::Zero, x::Any) = x +Base.:+(::Zero, x::Number) = x Base.:+(x::Any, ::Zero) = x +Base.:+(x::Number, ::Zero) = x Base.:+(z::Zero, ::Zero) = z Base.:-(::Zero, x::Any) = -x Base.:-(x::Any, ::Zero) = x From 9848cfcdfee7cbbc50fc932130ff4d379d8a9d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 23 Jul 2024 10:11:14 -0400 Subject: [PATCH 2/2] Fixes --- src/rewrite.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/rewrite.jl b/src/rewrite.jl index 8590ebe..4b3b338 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -59,7 +59,9 @@ broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y # Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)` Base.:*(z::Zero, ::Any) = z +Base.:*(z::Zero, ::Number) = z Base.:*(::Any, z::Zero) = z +Base.:*(::Number, z::Zero) = z Base.:*(z::Zero, ::Zero) = z Base.:+(::Zero, x::Any) = x Base.:+(::Zero, x::Number) = x @@ -67,7 +69,9 @@ Base.:+(x::Any, ::Zero) = x Base.:+(x::Number, ::Zero) = x Base.:+(z::Zero, ::Zero) = z Base.:-(::Zero, x::Any) = -x +Base.:-(::Zero, x::Number) = -x Base.:-(x::Any, ::Zero) = x +Base.:-(x::Number, ::Zero) = x Base.:-(z::Zero, ::Zero) = z Base.:-(z::Zero) = z Base.:+(z::Zero) = z @@ -81,6 +85,14 @@ function Base.:/(z::Zero, x::Any) end end +function Base.:/(z::Zero, x::Number) + if iszero(x) + throw(DivideError()) + else + return z + end +end + # These methods are used to provide an efficient implementation for the common # case like `x^2 * sum(f for i in 1:0)`, which lowers to # `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed