Skip to content

Commit b0c5de1

Browse files
committed
Define get_val
1 parent f7a77d6 commit b0c5de1

File tree

10 files changed

+38
-34
lines changed

10 files changed

+38
-34
lines changed

src/code.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
99
import ..SymbolicUtils
1010
import ..SymbolicUtils.Rewriters
1111
import SymbolicUtils: @matchable, BasicSymbolic, _Sym, Term, iscall, operation, arguments, issym,
12-
isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm
12+
isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm, get_val
1313
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1414

1515
##== state management ==##
@@ -183,7 +183,7 @@ function toexpr(O, st)
183183
O = substitute_name(O, st)
184184
return issym(O) ? nameof(O) : toexpr(O, st)
185185
elseif isconst(O)
186-
return toexpr(O.impl.val, st)
186+
return toexpr(get_val(O), st)
187187
end
188188
O = substitute_name(O, st)
189189

src/inspect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
99
str = if issym(x)
1010
string(exprtype(x), "(", x, ")")
1111
elseif isconst(x)
12-
string(x.impl.val)
12+
string(get_val(x))
1313
elseif isadd(x)
1414
string(exprtype(x),
1515
(scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x))))

src/matchers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88
function matcher(val::Any)
99
if isconst(val)
10-
slot = val.impl.val
10+
slot = get_val(val)
1111
return matcher(slot)
1212
elseif iscall(val)
1313
return term_matcher(val)
@@ -16,7 +16,7 @@ function matcher(val::Any)
1616
if islist(data)
1717
cd = car(data)
1818
if isconst(cd)
19-
cd = cd.impl.val
19+
cd = get_val(cd)
2020
end
2121
if isequal(cd, val)
2222
return next(bindings, 1)

src/methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ for f in [!, ~]
190190
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
191191
function (::$(typeof(f)))(s::Symbolic{Bool})
192192
if isconst(s)
193-
s = s.impl.val
193+
s = get_val(s)
194194
return !s
195195
end
196196
_Term(Bool, !, [s])

src/ordering.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function get_degrees(expr)
2727
elseif iscall(expr)
2828
op = operation(expr)
2929
args = sorted_arguments(expr)
30-
if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].impl.val isa Number))
30+
if op == (^) && (args[2] isa Number || (isconst(args[2]) && get_val(args[2]) isa Number))
3131
return map(get_degrees(args[1])) do (base, pow)
3232
(base => pow * args[2])
3333
end
@@ -81,11 +81,11 @@ end
8181
function <(a::BasicSymbolic, b::BasicSymbolic)
8282
aisconst = isconst(a)
8383
if aisconst
84-
a = a.impl.val
84+
a = get_val(a)
8585
end
8686
bisconst = isconst(b)
8787
if bisconst
88-
b = b.impl.val
88+
b = get_val(b)
8989
end
9090
if aisconst || bisconst
9191
return a <ₑ b

src/polyform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ _isone(p::PolyForm) = isone(p.p)
9696

9797
function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
9898
if isconst(x)
99-
x = x.impl.val
99+
x = get_val(x)
100100
end
101101
if x isa Number
102102
return x

src/substitute.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function substitute(expr, dict; fold=true)
2323
args = map(arguments(expr)) do x
2424
x′ = substitute(x, dict; fold=fold)
2525
if isconst(x)
26-
x′ = x′.impl.val
26+
x′ = get_val(x′)
2727
end
2828
canfold = canfold && !(x′ isa Symbolic)
2929
x′
@@ -58,7 +58,7 @@ function _occursin(needle, haystack)
5858
args = arguments(haystack)
5959
for arg in args
6060
if isconst(arg)
61-
arg = arg.impl.val
61+
arg = get_val(arg)
6262
end
6363
if needle isa Integer || needle isa AbstractFloat
6464
isequal(needle, arg) && return true

src/types.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ function get_exp(x::BasicSymbolic)
9292
x.impl.exp
9393
end
9494

95+
function get_val(x::BasicSymbolic)
96+
x.impl.val
97+
end
98+
9599
# Same but different error messages
96100
@noinline error_on_type() = error("Internal error: unreachable reached!")
97101
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -327,7 +331,7 @@ function _isequal(a, b, E)
327331
a2 = arguments(b)
328332
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
329333
elseif E === CONST
330-
isequal(a.impl.val, b.impl.val)
334+
isequal(get_val(a), get_val(b))
331335
else
332336
error_on_type()
333337
end
@@ -378,7 +382,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
378382
s.hash[] = h′
379383
return h′
380384
elseif E === CONST
381-
return hash(s.impl.val, salt COS_SALT)
385+
return hash(get_val(s), salt COS_SALT)
382386
else
383387
error_on_type()
384388
end
@@ -452,14 +456,14 @@ end
452456

453457
function _iszero(x::BasicSymbolic)
454458
@match x.impl begin
455-
Const(_...) => iszero(x.impl.val)
459+
Const(_...) => iszero(get_val(x))
456460
_ => false
457461
end
458462
end
459463

460464
function _isone(x::BasicSymbolic)
461465
@match x.impl begin
462-
Const(_...) => isone(x.impl.val)
466+
Const(_...) => isone(get_val(x))
463467
_ => false
464468
end
465469
end
@@ -833,7 +837,7 @@ const show_simplified = Ref(false)
833837
isnegative(t::Real) = t < 0
834838
function isnegative(t)
835839
if isconst(t)
836-
val = t.impl.val
840+
val = get_val(t)
837841
return isnegative(val)
838842
end
839843
if iscall(t) && operation(t) === (*)
@@ -872,7 +876,7 @@ function remove_minus(t)
872876
args = arguments(t)
873877
arg1 = args[1]
874878
if isconst(arg1)
875-
arg1 = arg1.impl.val
879+
arg1 = get_val(arg1)
876880
end
877881
@assert arg1 < 0
878882
Any[-arg1, args[2:end]...]
@@ -911,14 +915,14 @@ end
911915

912916
function show_mul(io, args)
913917
if isconst(args)
914-
print(io, args.impl.val)
918+
print(io, get_val(args))
915919
return
916920
end
917921
length(args) == 1 && return print_arg(io, *, args[1])
918922

919923
arg1 = args[1]
920924
if isconst(arg1)
921-
arg1 = arg1.impl.val
925+
arg1 = get_val(arg1)
922926
end
923927

924928
minus = arg1 isa Number && arg1 == -1
@@ -930,7 +934,7 @@ function show_mul(io, args)
930934

931935
nostar = minus || unit ||
932936
(!paren_scalar && arg1 isa Number &&
933-
!(isconst(args[2]) && args[2].impl.val isa Number))
937+
!(isconst(args[2]) && get_val(args[2]) isa Number))
934938

935939
for (i, t) in enumerate(args)
936940
if i != 1
@@ -1021,7 +1025,7 @@ showraw(t) = showraw(stdout, t)
10211025
function Base.show(io::IO, v::BasicSymbolic)
10221026
@match v.impl begin
10231027
Sym(_...) => Base.show_unquoted(io, get_name(v))
1024-
Const(_...) => print(io, v.impl.val)
1028+
Const(_...) => print(io, get_val(v))
10251029
_ => show_term(io, v)
10261030
end
10271031
end
@@ -1235,10 +1239,10 @@ sub_t(a) = promote_symtype(-, symtype(a))
12351239
import Base: (+), (-), (*), (//), (/), (\), (^)
12361240
function +(a::SN, b::SN)
12371241
if isconst(a)
1238-
return a.impl.val + b
1242+
return get_val(a) + b
12391243
end
12401244
if isconst(b)
1241-
return b.impl.val + a
1245+
return get_val(b) + a
12421246
end
12431247
!issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata
12441248
if isadd(a) && isadd(b)
@@ -1255,7 +1259,7 @@ function +(a::SN, b::SN)
12551259
end
12561260
function +(a::Number, b::SN)
12571261
if isconst(b)
1258-
return a + b.impl.val
1262+
return a + get_val(b)
12591263
end
12601264
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
12611265
iszero(a) && return b
@@ -1270,7 +1274,7 @@ end
12701274

12711275
function -(a::SN)
12721276
if isconst(a)
1273-
v = a.impl.val
1277+
v = get_val(a)
12741278
mv = -v
12751279
return _Const(mv)
12761280
end
@@ -1299,10 +1303,10 @@ mul_t(a) = promote_symtype(*, symtype(a))
12991303

13001304
function *(a::SN, b::SN)
13011305
if isconst(a)
1302-
return a.impl.val * b
1306+
return get_val(a) * b
13031307
end
13041308
if isconst(b)
1305-
return b.impl.val * a
1309+
return get_val(b) * a
13061310
end
13071311
# Always make sure Div wraps Mul
13081312
!issafecanon(*, a, b) && return term(*, a, b)
@@ -1333,7 +1337,7 @@ function *(a::SN, b::SN)
13331337
end
13341338
function *(a::Number, b::SN)
13351339
if isconst(b)
1336-
return a * b.impl.val
1340+
return a * get_val(b)
13371341
end
13381342
!issafecanon(*, b) && return term(*, a, b)
13391343
if iszero(a)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T
6666

6767
function is_literal_number(x)
6868
if isconst(x)
69-
x = x.impl.val
69+
x = get_val(x)
7070
end
7171
x isa Number
7272
end

test/types.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name
1+
using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name, get_val
22

33
@testset "Expronicon generated constructors" begin
44
s1 = Sym(:abc)
@@ -133,17 +133,17 @@ end
133133
@test typeof(c1) == BasicSymbolic{Float64}
134134
@test c1.metadata == SymbolicUtils.NO_METADATA
135135
@test c1.hash[] == SymbolicUtils.EMPTY_HASH
136-
@test c1.impl.val == 1.0
136+
@test get_val(c1) == 1.0
137137
c2 = _Const(big"123456789012345678901234567890")
138138
@test typeof(c2) == BasicSymbolic{BigInt}
139139
@test c2.metadata == SymbolicUtils.NO_METADATA
140140
@test c2.hash[] == SymbolicUtils.EMPTY_HASH
141-
@test c2.impl.val == big"123456789012345678901234567890"
141+
@test get_val(c2) == big"123456789012345678901234567890"
142142
c3 = _Const(big"1.23456789012345678901")
143143
@test typeof(c3) == BasicSymbolic{BigFloat}
144144
@test c3.metadata == SymbolicUtils.NO_METADATA
145145
@test c3.hash[] == SymbolicUtils.EMPTY_HASH
146-
@test c3.impl.val == big"1.23456789012345678901"
146+
@test get_val(c3) == big"1.23456789012345678901"
147147
end
148148
end
149149

0 commit comments

Comments
 (0)