From da3bd6dc99e39b3b10e0abff248f3457a86cbbc9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Jan 2025 14:59:35 +0530 Subject: [PATCH 1/3] test: update codegen tests --- test/code.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/code.jl b/test/code.jl index 0e25437a1..3122bed15 100644 --- a/test/code.jl +++ b/test/code.jl @@ -101,18 +101,18 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(a^2) == :($(^)(a, 2)) @test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2)) - @test toexpr(NaNMath.pow(a, 2)) == :($(NaNMath.pow)(a, 2)) + @test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2)) @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2)) @test toexpr(a^-1) == :($(/)(1, a)) @test toexpr(a^-1, nanmath_st) == :($(/)(1, a)) - @test toexpr(NaNMath.pow(a, -1)) == :($(NaNMath.pow)(a, -1)) - @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(NaNMath.pow)(a, -1)) + @test toexpr(NaNMath.pow(a, -1)) == :($(inv)(a)) + @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a)) @test toexpr(a^-2) == :($(/)(1, $(^)(a, 2))) @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2))) - @test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)(a, -2)) - @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)(a, -2)) + @test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)($(inv)(a), 2)) + @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)($(inv)(a), 2)) f = GlobalRef(NaNMath, :sin) test_repr(toexpr(LiteralExpr(:(let x=1, y=2 From 194aa85208547c9b4a4dad124f5c08bc6bc60a25 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Jan 2025 18:14:59 +0530 Subject: [PATCH 2/3] refactor: don't use `NaNMath.pow` in codegen rewriters if integral exponent --- src/code.jl | 18 +++++++----------- test/code.jl | 10 +++++----- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/code.jl b/src/code.jl index 8b1953b20..754c0750d 100644 --- a/src/code.jl +++ b/src/code.jl @@ -140,18 +140,14 @@ end function function_to_expr(op::typeof(^), O, st) args = arguments(O) - if length(args) == 2 && args[2] isa Real && args[2] < 0 - ex = args[1] - if args[2] == -1 - return toexpr(Term(inv, Any[ex]), st) - else - args = Any[Term(inv, Any[ex]), -args[2]] - op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow - return toexpr(Term(op, args), st) - end + if args[2] isa Real && args[2] < 0 + args[1] = Term(inv, Any[args[1]]) + args[2] = -args[2] + end + if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer) + op = NaNMath.pow end - get(st.rewrites, :nanmath, false) === true || return nothing - return toexpr(Term(NaNMath.pow, args), st) + return toexpr(Term(op, args), st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) diff --git a/test/code.jl b/test/code.jl index 3122bed15..918ef9dcc 100644 --- a/test/code.jl +++ b/test/code.jl @@ -100,9 +100,9 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b)) @test toexpr(a^2) == :($(^)(a, 2)) - @test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2)) + @test toexpr(a^2, nanmath_st) == :($(^)(a, 2)) @test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2)) - @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2)) + @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(^)(a, 2)) @test toexpr(a^-1) == :($(/)(1, a)) @test toexpr(a^-1, nanmath_st) == :($(/)(1, a)) @@ -110,9 +110,9 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a)) @test toexpr(a^-2) == :($(/)(1, $(^)(a, 2))) - @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2))) - @test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)($(inv)(a), 2)) - @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)($(inv)(a), 2)) + @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(^)(a, 2))) + @test toexpr(NaNMath.pow(a, -2)) == :($(^)($(inv)(a), 2)) + @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(^)($(inv)(a), 2)) f = GlobalRef(NaNMath, :sin) test_repr(toexpr(LiteralExpr(:(let x=1, y=2 From f5611aa619b1a6d47caab6abd58bb2b65daef18a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Jan 2025 13:45:21 +0530 Subject: [PATCH 3/3] fix: fix stack overflow in `function_to_expr` --- src/code.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index 754c0750d..774c26931 100644 --- a/src/code.jl +++ b/src/code.jl @@ -144,10 +144,14 @@ function function_to_expr(op::typeof(^), O, st) args[1] = Term(inv, Any[args[1]]) args[2] = -args[2] end + if isequal(args[2], 1) + return toexpr(args[1], st) + end if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer) op = NaNMath.pow + return toexpr(Term(op, args), st) end - return toexpr(Term(op, args), st) + return nothing end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)