Skip to content

Commit 53ca3b8

Browse files
committed
Define get_dict function
1 parent 206d21c commit 53ca3b8

File tree

4 files changed

+29
-25
lines changed

4 files changed

+29
-25
lines changed

src/inspect.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
1212
string(x.impl.val)
1313
elseif isadd(x)
1414
string(exprtype(x),
15-
(scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
15+
(scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x))))
1616
elseif ismul(x)
1717
string(exprtype(x),
18-
(scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict)))
18+
(scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in get_dict(x))))
1919
elseif isdiv(x) || ispow(x)
2020
string(exprtype(x))
2121
else

src/polyform.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,12 +502,12 @@ end
502502
# mul, pow case
503503
function quick_mulpow(x, y)
504504
y.impl.exp isa Number || return (x, y)
505-
if haskey(x.impl.dict, y.impl.base)
506-
d = copy(x.impl.dict)
507-
if x.impl.dict[y.impl.base] > y.impl.exp
505+
if haskey(get_dict(x), y.impl.base)
506+
d = copy(get_dict(x))
507+
if get_dict(x)[y.impl.base] > y.impl.exp
508508
d[y.impl.base] -= y.impl.exp
509509
den = 1
510-
elseif x.impl.dict[y.impl.base] == y.impl.exp
510+
elseif get_dict(x)[y.impl.base] == y.impl.exp
511511
delete!(d, y.impl.base)
512512
den = 1
513513
else

src/types.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function get_coeff(x::BasicSymbolic)
7272
x.impl.coeff
7373
end
7474

75+
function get_dict(x::BasicSymbolic)
76+
x.impl.dict
77+
end
78+
7579
# Same but different error messages
7680
@noinline error_on_type() = error("Internal error: unreachable reached!")
7781
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -297,7 +301,7 @@ function _isequal(a, b, E)
297301
if E === SYM
298302
nameof(a) === nameof(b)
299303
elseif E === ADD || E === MUL
300-
coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(a.impl.dict, b.impl.dict)
304+
coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b))
301305
elseif E === DIV
302306
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
303307
elseif E === POW
@@ -341,7 +345,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
341345
h = s.hash[]
342346
!iszero(h) && return h
343347
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
344-
h′ = hash(hashoffset, hash(get_coeff(s), hash(s.impl.dict, salt)))
348+
h′ = hash(hashoffset, hash(get_coeff(s), hash(get_dict(s), salt)))
345349
s.hash[] = h′
346350
return h′
347351
elseif E === DIV
@@ -461,7 +465,7 @@ function maybe_intcoeff(x)
461465
if ismul(x)
462466
coeff = get_coeff(x)
463467
if coeff isa Rational && isone(denominator(coeff))
464-
_Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata)
468+
_Mul(symtype(x), coeff.num, get_dict(x); metadata = x.metadata)
465469
else
466470
x
467471
end
@@ -542,7 +546,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
542546
elseif E === ADD || E === MUL
543547
args = BasicSymbolic[]
544548
push!(args, get_coeff(t))
545-
for (k, coeff) in t.impl.dict
549+
for (k, coeff) in get_dict(t)
546550
push!(
547551
args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k]))
548552
end
@@ -567,15 +571,15 @@ function makeadd(sign, coeff, xs...)
567571
for x in xs
568572
if isadd(x)
569573
coeff += get_coeff(x)
570-
_merge!(+, d, x.impl.dict, filter = _iszero)
574+
_merge!(+, d, get_dict(x), filter = _iszero)
571575
continue
572576
end
573577
if x isa Number
574578
coeff += x
575579
continue
576580
end
577581
if ismul(x)
578-
k = _Mul(symtype(x), 1, x.impl.dict)
582+
k = _Mul(symtype(x), 1, get_dict(x))
579583
v = sign * get_coeff(x) + get(d, k, 0)
580584
else
581585
k = x
@@ -598,7 +602,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}())
598602
coeff *= x
599603
elseif ismul(x)
600604
coeff *= get_coeff(x)
601-
_merge!(+, d, x.impl.dict, filter = _iszero)
605+
_merge!(+, d, get_dict(x), filter = _iszero)
602606
else
603607
v = 1 + get(d, x, 0)
604608
if _iszero(v)
@@ -1223,10 +1227,10 @@ function +(a::SN, b::SN)
12231227
!issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata
12241228
if isadd(a) && isadd(b)
12251229
return _Add(
1226-
add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
1230+
add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero))
12271231
elseif isadd(a)
12281232
coeff, dict = makeadd(1, 0, b)
1229-
return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
1233+
return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, get_dict(a), dict, filter = _iszero))
12301234
elseif isadd(b)
12311235
return b + a
12321236
end
@@ -1240,7 +1244,7 @@ function +(a::Number, b::SN)
12401244
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
12411245
iszero(a) && return b
12421246
if isadd(b)
1243-
_Add(add_t(a, b), a + get_coeff(b), b.impl.dict)
1247+
_Add(add_t(a, b), a + get_coeff(b), get_dict(b))
12441248
else
12451249
_Add(add_t(a, b), makeadd(1, a, b)...)
12461250
end
@@ -1258,15 +1262,15 @@ function -(a::SN)
12581262
return term(-, a)
12591263
end
12601264
if isadd(a)
1261-
_Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, a.impl.dict))
1265+
_Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, get_dict(a)))
12621266
else
12631267
_Add(sub_t(a), makeadd(-1, 0, a)...)
12641268
end
12651269
end
12661270
function -(a::SN, b::SN)
12671271
(!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b)
12681272
if isadd(a) && isadd(b)
1269-
_Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
1273+
_Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, get_dict(a), get_dict(b), filter = _iszero))
12701274
else
12711275
a + (-b)
12721276
end
@@ -1294,16 +1298,16 @@ function *(a::SN, b::SN)
12941298
_Div(a * b.impl.num, b.impl.den)
12951299
elseif ismul(a) && ismul(b)
12961300
_Mul(mul_t(a, b), get_coeff(a) * get_coeff(b),
1297-
_merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
1301+
_merge(+, get_dict(a), get_dict(b), filter = _iszero))
12981302
elseif ismul(a) && ispow(b)
12991303
if b.impl.exp isa Number
13001304
_Mul(mul_t(a, b),
13011305
get_coeff(a),
1302-
_merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp),
1306+
_merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp),
13031307
filter = _iszero))
13041308
else
13051309
_Mul(mul_t(a, b), get_coeff(a),
1306-
_merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero))
1310+
_merge(+, get_dict(a), Base.ImmutableDict(b => 1), filter = _iszero))
13071311
end
13081312
elseif ispow(a) && ismul(b)
13091313
b * a
@@ -1326,7 +1330,7 @@ function *(a::Number, b::SN)
13261330
# -1(a+b) -> -a - b
13271331
T = promote_symtype(+, typeof(a), symtype(b))
13281332
_Add(T, get_coeff(b) * a,
1329-
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict))
1333+
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in get_dict(b)))
13301334
else
13311335
_Mul(mul_t(a, b), makemul(a, b)...)
13321336
end
@@ -1352,7 +1356,7 @@ function ^(a::SN, b)
13521356
elseif ismul(a) && b isa Number
13531357
coeff = unstable_pow(get_coeff(a), b)
13541358
_Mul(promote_symtype(^, symtype(a), symtype(b)),
1355-
coeff, mapvalues((k, v) -> b * v, a.impl.dict))
1359+
coeff, mapvalues((k, v) -> b * v, get_dict(a)))
13561360
else
13571361
_Pow(a, b)
13581362
end

test/basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm,
2-
BasicSymbolic, term, get_name, get_coeff
2+
BasicSymbolic, term, get_name, get_coeff, get_dict
33
using SymbolicUtils
44
using IfElse: ifelse
55
using Setfield
@@ -234,7 +234,7 @@ end
234234

235235
@testset "maketerm" begin
236236
@syms a b c
237-
@test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1))
237+
@test isequal(get_dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1))
238238
@test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b)
239239

240240
# test that maketerm doesn't hard-code BasicSymbolic subtype

0 commit comments

Comments
 (0)