Skip to content

Commit 226dd4f

Browse files
Merge pull request #671 from devmotion/dw/nanmath
Rewrite `^` with `NaNMath.pow` in nanmath-mode
2 parents 59e3aa6 + 862a4c5 commit 226dd4f

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "3.7.2"
4+
version = "3.7.3"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/code.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,20 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
138138
end
139139
end
140140

141-
function function_to_expr(::typeof(^), O, st)
141+
function function_to_expr(op::typeof(^), O, st)
142142
args = arguments(O)
143143
if length(args) == 2 && args[2] isa Real && args[2] < 0
144144
ex = args[1]
145145
if args[2] == -1
146146
return toexpr(Term(inv, Any[ex]), st)
147147
else
148-
return toexpr(Term(^, Any[Term(inv, Any[ex]), -args[2]]), st)
148+
args = Any[Term(inv, Any[ex]), -args[2]]
149+
op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow
150+
return toexpr(Term(op, args), st)
149151
end
150152
end
151-
return nothing
153+
get(st.rewrites, :nanmath, false) === true || return nothing
154+
return toexpr(Term(NaNMath.pow, args), st)
152155
end
153156

154157
function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)

test/code.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ nanmath_st.rewrites[:nanmath] = true
2020
@test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e))
2121
@test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e))
2222
@test toexpr(a+b) == :($(+)(a, b))
23-
@test toexpr(a^b) == :($(^)(a, b))
24-
@test toexpr(a^2) == :($(^)(a, 2))
25-
@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
2623
@test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t)))
2724
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t))))
2825
s = LazyState()
@@ -87,8 +84,35 @@ nanmath_st.rewrites[:nanmath] = true
8784
end)
8885
@test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall
8986

87+
for fname in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, :log1p, :sqrt)
88+
f = getproperty(Base, fname)
89+
@test toexpr(f(a)) == :($f(a))
90+
@test toexpr(f(a), nanmath_st) == :($(GlobalRef(NaNMath, fname))(a))
9091

92+
nanmath_f = getproperty(NaNMath, fname)
93+
@test toexpr(nanmath_f(a)) == :($nanmath_f(a))
94+
@test toexpr(nanmath_f(a), nanmath_st) == :($nanmath_f(a))
95+
end
96+
97+
@test toexpr(a^b) == :($(^)(a, b))
98+
@test toexpr(a^b, nanmath_st) == :($(NaNMath.pow)(a, b))
9199
@test toexpr(NaNMath.pow(a, b)) == :($(NaNMath.pow)(a, b))
100+
@test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b))
101+
102+
@test toexpr(a^2) == :($(^)(a, 2))
103+
@test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2))
104+
@test toexpr(NaNMath.pow(a, 2)) == :($(NaNMath.pow)(a, 2))
105+
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2))
106+
107+
@test toexpr(a^-1) == :($(/)(1, a))
108+
@test toexpr(a^-1, nanmath_st) == :($(/)(1, a))
109+
@test toexpr(NaNMath.pow(a, -1)) == :($(NaNMath.pow)(a, -1))
110+
@test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(NaNMath.pow)(a, -1))
111+
112+
@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
113+
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2)))
114+
@test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)(a, -2))
115+
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)(a, -2))
92116

93117
f = GlobalRef(NaNMath, :sin)
94118
test_repr(toexpr(LiteralExpr(:(let x=1, y=2

0 commit comments

Comments
 (0)