Skip to content

Commit 194aa85

Browse files
refactor: don't use NaNMath.pow in codegen rewriters if integral exponent
1 parent da3bd6d commit 194aa85

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

src/code.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,14 @@ end
140140

141141
function function_to_expr(op::typeof(^), O, st)
142142
args = arguments(O)
143-
if length(args) == 2 && args[2] isa Real && args[2] < 0
144-
ex = args[1]
145-
if args[2] == -1
146-
return toexpr(Term(inv, Any[ex]), st)
147-
else
148-
args = Any[Term(inv, Any[ex]), -args[2]]
149-
op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow
150-
return toexpr(Term(op, args), st)
151-
end
143+
if args[2] isa Real && args[2] < 0
144+
args[1] = Term(inv, Any[args[1]])
145+
args[2] = -args[2]
146+
end
147+
if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer)
148+
op = NaNMath.pow
152149
end
153-
get(st.rewrites, :nanmath, false) === true || return nothing
154-
return toexpr(Term(NaNMath.pow, args), st)
150+
return toexpr(Term(op, args), st)
155151
end
156152

157153
function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)

test/code.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,19 @@ nanmath_st.rewrites[:nanmath] = true
100100
@test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b))
101101

102102
@test toexpr(a^2) == :($(^)(a, 2))
103-
@test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2))
103+
@test toexpr(a^2, nanmath_st) == :($(^)(a, 2))
104104
@test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2))
105-
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2))
105+
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(^)(a, 2))
106106

107107
@test toexpr(a^-1) == :($(/)(1, a))
108108
@test toexpr(a^-1, nanmath_st) == :($(/)(1, a))
109109
@test toexpr(NaNMath.pow(a, -1)) == :($(inv)(a))
110110
@test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a))
111111

112112
@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
113-
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2)))
114-
@test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)($(inv)(a), 2))
115-
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)($(inv)(a), 2))
113+
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(^)(a, 2)))
114+
@test toexpr(NaNMath.pow(a, -2)) == :($(^)($(inv)(a), 2))
115+
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(^)($(inv)(a), 2))
116116

117117
f = GlobalRef(NaNMath, :sin)
118118
test_repr(toexpr(LiteralExpr(:(let x=1, y=2

0 commit comments

Comments
 (0)