Skip to content

Commit f34e774

Browse files
authored
Merge pull request #192 from cdsousa/master
Fix CSE for scalar and array cases
2 parents 5a9fabd + ea4ce47 commit f34e774

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/simplify.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
## from http://docs.sympy.org/0.7.2/modules/simplify/simplify.html
33

44
## simple methods (x, args) -> y (y coercion happens via PyCall)
5-
simplify_sympy_meths = (:collect, :rcollect, :separate,
5+
simplify_sympy_meths = (:collect, :rcollect, :separate,
66
:separatevars,
77
:radsimp, :ratsimp, :trigsimp, :besselsimp,
88
:powsimp, :combsimp, :hypersimp,
@@ -11,7 +11,7 @@ simplify_sympy_meths = (:collect, :rcollect, :separate,
1111
:posify, :powdenest, :sqrtdenest,
1212
:logcombine, :hyperexpand)
1313

14-
14+
1515
expand_sympy_meths = (:expand_trig,
1616
:expand_power_base, :expand_power_exp,
1717
:expand_log,
@@ -32,13 +32,9 @@ Example: (from man page)
3232
cse(((w + x + y + z)*(w + y + z))/(w + x)^3), ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1^3]) # tuple of replacements and reduced expressions.
3333
3434
""" ->
35-
cse{T<:SymbolicObject}(ex::T, args...; kwargs...) = sympy_meth(:cse, ex, args...; kwargs...)
36-
cse{T<:SymbolicObject}(ex::Vector{T}, args...; kwargs...) = sympy_meth(:cse, ex, args...; kwargs...)
37-
38-
function cse{T<:SymbolicObject, N}(ex::AbstractArray{T, N}, args...; kwargs...)
39-
a,b = cse(ex[:], args...; kwargs...)
40-
bb = convert(Array{Sym}, reshape(b, size(ex)))
41-
a, bb
35+
function cse{T<:SymbolicObject}(ex::Union{T, AbstractArray{T}}, args...; kwargs...)
36+
a, b = sympy_meth(:cse, ex, args...; kwargs...)
37+
a, oftype(ex, b[1])
4238
end
4339
export(cse)
4440

test/tests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ end
178178
@vars a b; eqs = (a*x+2y-3, 2b*x + 3y - 4)
179179
as = linsolve(eqs, x, y)
180180
@test length(elements(as)) == 1
181-
181+
182182
## limits
183183
@test limit(x -> sin(x)/x, 0) == 1
184184
@test limit(sin(x)/x, x, 0) |> float == 1
@@ -475,6 +475,12 @@ end
475475
@test round(y, 5) == 0
476476
@test round(y, 16) != 0
477477

478+
## test cse output
479+
@test cse(x) == (Any[], x)
480+
@test cse([x]) == (Any[], [x])
481+
@test cse([x, x]) == (Any[], [x, x])
482+
@test cse([x x; x x]) == (Any[], [x x; x x])
483+
478484
## sympy"..."(...)
479485
@vars x
480486
@test sympy"sin"(1) == sin(Sym(1))

0 commit comments

Comments
 (0)