Skip to content

Commit ee4183a

Browse files
authored
Merge pull request #547 from JuliaSymbolics/s/fix-metadata-prop
prevent canonicalization where arguments have metadata
2 parents 6505233 + ba61640 commit ee4183a

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

src/types.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,19 @@ function hasmetadata(s::Symbolic, ctx)
576576
metadata(s) isa AbstractDict && haskey(metadata(s), ctx)
577577
end
578578

579+
function issafecanon(f, s)
580+
if isnothing(metadata(s)) || issym(s)
581+
return true
582+
else
583+
_issafecanon(f, s)
584+
end
585+
end
586+
_issafecanon(::typeof(*), s) = !istree(s) || !(operation(s) in (+,*,^))
587+
_issafecanon(::typeof(+), s) = !istree(s) || !(operation(s) in (+,*))
588+
_issafecanon(::typeof(^), s) = !istree(s) || !(operation(s) in (*, ^))
589+
590+
issafecanon(f, ss...) = all(x->issafecanon(f, x), ss)
591+
579592
function getmetadata(s::Symbolic, ctx)
580593
md = metadata(s)
581594
if md isa AbstractDict
@@ -1016,6 +1029,7 @@ sub_t(a) = promote_symtype(-, symtype(a))
10161029

10171030
import Base: (+), (-), (*), (//), (/), (\), (^)
10181031
function +(a::SN, b::SN)
1032+
!issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata
10191033
if isadd(a) && isadd(b)
10201034
return Add(add_t(a,b),
10211035
a.coeff + b.coeff,
@@ -1031,6 +1045,7 @@ function +(a::SN, b::SN)
10311045
end
10321046

10331047
function +(a::Number, b::SN)
1048+
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
10341049
iszero(a) && return b
10351050
if isadd(b)
10361051
Add(add_t(a,b), a + b.coeff, b.dict)
@@ -1044,11 +1059,13 @@ end
10441059
+(a::SN) = a
10451060

10461061
function -(a::SN)
1062+
!issafecanon(*, a) && return term(-, a)
10471063
isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) :
10481064
Add(sub_t(a), makeadd(-1, 0, a)...)
10491065
end
10501066

10511067
function -(a::SN, b::SN)
1068+
(!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b)
10521069
isadd(a) && isadd(b) ? Add(sub_t(a,b),
10531070
a.coeff - b.coeff,
10541071
_merge(-, a.dict,
@@ -1067,6 +1084,7 @@ mul_t(a) = promote_symtype(*, symtype(a))
10671084

10681085
function *(a::SN, b::SN)
10691086
# Always make sure Div wraps Mul
1087+
!issafecanon(*, a, b) && return term(*, a, b)
10701088
if isdiv(a) && isdiv(b)
10711089
Div(a.num * b.num, a.den * b.den)
10721090
elseif isdiv(a)
@@ -1093,6 +1111,7 @@ function *(a::SN, b::SN)
10931111
end
10941112

10951113
function *(a::Number, b::SN)
1114+
!issafecanon(*, b) && return term(*, a, b)
10961115
if iszero(a)
10971116
a
10981117
elseif isone(a)
@@ -1132,6 +1151,7 @@ end
11321151
###
11331152

11341153
function ^(a::SN, b)
1154+
!issafecanon(^, a,b) && return Pow(a, b)
11351155
if b isa Number && iszero(b)
11361156
# fast path
11371157
1

0 commit comments

Comments
 (0)