Skip to content

Commit e011b04

Browse files
committed
Define get_num
1 parent 546b000 commit e011b04

File tree

3 files changed

+29
-25
lines changed

3 files changed

+29
-25
lines changed

src/polyform.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,11 @@ end
296296
#add_divs(x, y) = x + y
297297
function add_divs(x, y)
298298
if isdiv(x) && isdiv(y)
299-
return (x.impl.num * y.impl.den + y.impl.num * x.impl.den) / (x.impl.den * y.impl.den)
299+
return (get_num(x) * y.impl.den + get_num(y) * x.impl.den) / (x.impl.den * y.impl.den)
300300
elseif isdiv(x)
301-
return (x.impl.num + y * x.impl.den) / x.impl.den
301+
return (get_num(x) + y * x.impl.den) / x.impl.den
302302
elseif isdiv(y)
303-
return (x * y.impl.den + y.impl.num) / y.impl.den
303+
return (x * y.impl.den + get_num(y)) / y.impl.den
304304
else
305305
x + y
306306
end
@@ -384,7 +384,7 @@ function fraction_isone(x)
384384
end
385385

386386
function needs_div_rules(x)
387-
(isdiv(x) && !(x.impl.num isa Number) && !(x.impl.den isa Number)) ||
387+
(isdiv(x) && !(get_num(x) isa Number) && !(x.impl.den isa Number)) ||
388388
(iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) ||
389389
(iscall(x) && any(needs_div_rules, arguments(x)))
390390
end
@@ -417,11 +417,11 @@ Has optimized processes for `Mul` and `Pow` terms.
417417
"""
418418
function quick_cancel(d)
419419
if ispow(d) && isdiv(d.impl.base)
420-
return quick_cancel((d.impl.base.impl.num^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp))
420+
return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp))
421421
elseif ismul(d) && any(isdiv, arguments(d))
422422
return prod(arguments(d))
423423
elseif isdiv(d)
424-
num, den = quick_cancel(d.impl.num, d.impl.den)
424+
num, den = quick_cancel(get_num(d), d.impl.den)
425425
return _Div(num, den)
426426
else
427427
return d

src/types.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ function get_dict(x::BasicSymbolic)
7676
x.impl.dict
7777
end
7878

79+
function get_num(x::BasicSymbolic)
80+
x.impl.num
81+
end
82+
7983
# Same but different error messages
8084
@noinline error_on_type() = error("Internal error: unreachable reached!")
8185
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -303,7 +307,7 @@ function _isequal(a, b, E)
303307
elseif E === ADD || E === MUL
304308
coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b))
305309
elseif E === DIV
306-
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
310+
isequal(get_num(a), get_num(b)) && isequal(a.impl.den, b.impl.den)
307311
elseif E === POW
308312
isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base)
309313
elseif E === TERM
@@ -349,7 +353,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
349353
s.hash[] = h′
350354
return h′
351355
elseif E === DIV
352-
return hash(s.impl.num, hash(s.impl.den, salt DIV_SALT))
356+
return hash(get_num(s), hash(s.impl.den, salt DIV_SALT))
353357
elseif E === POW
354358
hash(s.impl.exp, hash(s.impl.base, salt POW_SALT))
355359
elseif E === TERM
@@ -483,11 +487,11 @@ function _Div(::Type{T}, num, den; kwargs...) where {T}
483487
_iszero(num) && return zero(typeof(num))
484488
_isone(den) && return num
485489
if isdiv(num) && isdiv(den)
486-
return _Div(T, num.impl.num * den.impl.den, num.impl.den * den.impl.num)
490+
return _Div(T, get_num(num) * den.impl.den, num.impl.den * get_num(den))
487491
elseif isdiv(num)
488-
return _Div(T, num.impl.num, num.impl.den * den)
492+
return _Div(T, get_num(num), num.impl.den * den)
489493
elseif isdiv(den)
490-
return _Div(T, num * den.impl.den, den.impl.num)
494+
return _Div(T, num * den.impl.den, get_num(den))
491495
end
492496
if den isa Number && _isone(-den)
493497
return -1 * num
@@ -523,7 +527,7 @@ function _Div(num, den; kwargs...)
523527
end
524528

525529
@inline function numerators(x)
526-
isdiv(x) && return numerators(x.impl.num)
530+
isdiv(x) && return numerators(get_num(x))
527531
iscall(x) && operation(x) === (*) ? arguments(x) : Any[x]
528532
end
529533

@@ -552,7 +556,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
552556
end
553557
_Term(T, operation(t), args)
554558
elseif E === DIV
555-
_Term(T, /, [t.impl.num, t.impl.den])
559+
_Term(T, /, [get_num(t), t.impl.den])
556560
elseif E === POW
557561
_Term(T, ^, [t.impl.base, t.impl.exp])
558562
else
@@ -1291,11 +1295,11 @@ function *(a::SN, b::SN)
12911295
# Always make sure Div wraps Mul
12921296
!issafecanon(*, a, b) && return term(*, a, b)
12931297
if isdiv(a) && isdiv(b)
1294-
_Div(a.impl.num * b.impl.num, a.impl.den * b.impl.den)
1298+
_Div(get_num(a) * get_num(b), a.impl.den * b.impl.den)
12951299
elseif isdiv(a)
1296-
_Div(a.impl.num * b, a.impl.den)
1300+
_Div(get_num(a) * b, a.impl.den)
12971301
elseif isdiv(b)
1298-
_Div(a * b.impl.num, b.impl.den)
1302+
_Div(a * get_num(b), b.impl.den)
12991303
elseif ismul(a) && ismul(b)
13001304
_Mul(mul_t(a, b), get_coeff(a) * get_coeff(b),
13011305
_merge(+, get_dict(a), get_dict(b), filter = _iszero))
@@ -1325,7 +1329,7 @@ function *(a::Number, b::SN)
13251329
elseif isone(a)
13261330
b
13271331
elseif isdiv(b)
1328-
_Div(a * b.impl.num, b.impl.den)
1332+
_Div(a * get_num(b), b.impl.den)
13291333
elseif isone(-a) && isadd(b)
13301334
# -1(a+b) -> -a - b
13311335
T = promote_symtype(+, typeof(a), symtype(b))

test/basics.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm,
2-
BasicSymbolic, term, get_name, get_coeff, get_dict
2+
BasicSymbolic, term, get_name, get_coeff, get_dict, get_num
33
using SymbolicUtils
44
using IfElse: ifelse
55
using Setfield
@@ -346,21 +346,21 @@ end
346346

347347
@testset "div" begin
348348
@syms x::SafeReal y::Real
349-
@test issym((2x / 2y).impl.num)
350-
@test get_coeff((2x / 3y).impl.num) == 2
349+
@test issym(get_num(2x / 2y))
350+
@test get_coeff(get_num(2x / 3y)) == 2
351351
@test get_coeff((2x / 3y).impl.den) == 3
352-
@test get_coeff((2x / -3x).impl.num) == -2
352+
@test get_coeff(get_num(2x / -3x)) == -2
353353
@test get_coeff((2x / -3x).impl.den) == 3
354-
@test get_coeff((2.5x / 3x).impl.num) == 2.5
354+
@test get_coeff(get_num(2.5x / 3x)) == 2.5
355355
@test get_coeff((2.5x / 3x).impl.den) == 3
356356
@test get_coeff((x / 3x).impl.den) == 3
357357

358358
@syms x y
359-
@test issym((2x / 2y).impl.num)
360-
@test get_coeff((2x / 3y).impl.num) == 2
359+
@test issym(get_num(2x / 2y))
360+
@test get_coeff(get_num(2x / 3y)) == 2
361361
@test get_coeff((2x / 3y).impl.den) == 3
362362
@test (2x / -3x) == -2 // 3
363-
@test (2.5x / 3x).impl.num == 2.5
363+
@test get_num(2.5x / 3x) == 2.5
364364
@test (2.5x / 3x).impl.den == 3
365365
@test (x / 3x) == 1 // 3
366366
@test isequal(x / 1, x)

0 commit comments

Comments
 (0)