Skip to content

Commit 8e73313

Browse files
committed
Define get_base
1 parent f044cec commit 8e73313

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

src/polyform.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ function polyform_factors(d, pvar2sym, sym2term)
268268
if ispow(x) && x.impl.exp isa Integer && x.impl.exp > 0
269269
# here we do want to recurse one level, that's why it's wrong to just
270270
# use Fs = Union{typeof(+), typeof(*)} here.
271-
_Pow(PolyForm(x.impl.base, pvar2sym, sym2term), x.impl.exp)
271+
_Pow(PolyForm(get_base(x), pvar2sym, sym2term), x.impl.exp)
272272
else
273273
PolyForm(x, pvar2sym, sym2term)
274274
end
@@ -416,8 +416,8 @@ But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`.
416416
Has optimized processes for `Mul` and `Pow` terms.
417417
"""
418418
function quick_cancel(d)
419-
if ispow(d) && isdiv(d.impl.base)
420-
return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (get_den(d.impl.base)^d.impl.exp))
419+
if ispow(d) && isdiv(get_base(d))
420+
return quick_cancel((get_num(get_base(d))^d.impl.exp) / (get_den(get_base(d))^d.impl.exp))
421421
elseif ismul(d) && any(isdiv, arguments(d))
422422
return prod(arguments(d))
423423
elseif isdiv(d)
@@ -502,17 +502,17 @@ end
502502
# mul, pow case
503503
function quick_mulpow(x, y)
504504
y.impl.exp isa Number || return (x, y)
505-
if haskey(get_dict(x), y.impl.base)
505+
if haskey(get_dict(x), get_base(y))
506506
d = copy(get_dict(x))
507-
if get_dict(x)[y.impl.base] > y.impl.exp
508-
d[y.impl.base] -= y.impl.exp
507+
if get_dict(x)[get_base(y)] > y.impl.exp
508+
d[get_base(y)] -= y.impl.exp
509509
den = 1
510-
elseif get_dict(x)[y.impl.base] == y.impl.exp
511-
delete!(d, y.impl.base)
510+
elseif get_dict(x)[get_base(y)] == y.impl.exp
511+
delete!(d, get_base(y))
512512
den = 1
513513
else
514-
den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base])
515-
delete!(d, y.impl.base)
514+
den = _Pow(symtype(y), get_base(y), y.impl.exp-d[get_base(y)])
515+
delete!(d, get_base(y))
516516
end
517517
return _Mul(symtype(x), get_coeff(x), d), den
518518
else

src/types.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ function get_den(x::BasicSymbolic)
8484
x.impl.den
8585
end
8686

87+
function get_base(x::BasicSymbolic)
88+
x.impl.base
89+
end
90+
8791
# Same but different error messages
8892
@noinline error_on_type() = error("Internal error: unreachable reached!")
8993
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -313,7 +317,7 @@ function _isequal(a, b, E)
313317
elseif E === DIV
314318
isequal(get_num(a), get_num(b)) && isequal(get_den(a), get_den(b))
315319
elseif E === POW
316-
isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base)
320+
isequal(a.impl.exp, b.impl.exp) && isequal(get_base(a), get_base(b))
317321
elseif E === TERM
318322
a1 = arguments(a)
319323
a2 = arguments(b)
@@ -359,7 +363,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
359363
elseif E === DIV
360364
return hash(get_num(s), hash(get_den(s), salt DIV_SALT))
361365
elseif E === POW
362-
hash(s.impl.exp, hash(s.impl.base, salt POW_SALT))
366+
hash(s.impl.exp, hash(get_base(s), salt POW_SALT))
363367
elseif E === TERM
364368
!iszero(salt) && return hash(hash(s, zero(UInt)), salt)
365369
h = s.hash[]
@@ -562,7 +566,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
562566
elseif E === DIV
563567
_Term(T, /, [get_num(t), get_den(t)])
564568
elseif E === POW
565-
_Term(T, ^, [t.impl.base, t.impl.exp])
569+
_Term(T, ^, [get_base(t), t.impl.exp])
566570
else
567571
error_on_type()
568572
end
@@ -605,7 +609,7 @@ end
605609
function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}())
606610
for x in xs
607611
if ispow(x) && x.impl.exp isa Number
608-
d[x.impl.base] = x.impl.exp + get(d, x.impl.base, 0)
612+
d[get_base(x)] = x.impl.exp + get(d, get_base(x), 0)
609613
elseif x isa Number
610614
coeff *= x
611615
elseif ismul(x)
@@ -629,7 +633,7 @@ function makepow(a, b)
629633
base = a
630634
exp = b
631635
if ispow(a)
632-
base = a.impl.base
636+
base = get_base(a)
633637
exp = a.impl.exp * b
634638
end
635639
base, exp
@@ -1311,7 +1315,7 @@ function *(a::SN, b::SN)
13111315
if b.impl.exp isa Number
13121316
_Mul(mul_t(a, b),
13131317
get_coeff(a),
1314-
_merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp),
1318+
_merge(+, get_dict(a), Base.ImmutableDict(get_base(b) => b.impl.exp),
13151319
filter = _iszero))
13161320
else
13171321
_Mul(mul_t(a, b), get_coeff(a),

0 commit comments

Comments
 (0)