Skip to content

Commit 4d00726

Browse files
committed
Make BasicSymbolicImpl children BasicSymbolicImpl
1 parent 440c17b commit 4d00726

File tree

1 file changed

+69
-11
lines changed

1 file changed

+69
-11
lines changed

src/types.jl

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,45 @@ function TermInterface.arguments(x::BasicSymbolicImpl)
244244
return args
245245
end
246246

247+
"""
248+
$(TYPEDSIGNATURES)
249+
250+
For given `coeff` and `dict`, return arguments of type `BasicSymbolicImpl`, children's
251+
metadata and a new dictionary with `BasicSymbolicImpl` as key's type for preparation of the
252+
construction of either `Add` or `Mul`.
253+
"""
254+
function get_arguments_metadata(coeff, dict::AbstractDict, type::ExprType)
255+
siz = length(dict)
256+
idcoeff = type === ADD ? iszero(coeff) : isone(coeff)
257+
args = Vector()
258+
sizehint!(args, idcoeff ? siz : siz + 1)
259+
idcoeff || push!(args, coeff)
260+
if type === ADD
261+
for (k, v) in dict
262+
if k isa BasicSymbolicImpl
263+
k = BasicSymbolic(k, MetadataImpl())
264+
end
265+
push!(args, applicable(*, k, v) ? k * v : maketerm(k, *, [k, v], nothing))
266+
end
267+
else # MUL
268+
for (k, v) in dict
269+
if k isa BasicSymbolicImpl
270+
k = BasicSymbolic(k, MetadataImpl())
271+
end
272+
push!(args, unstable_pow(k, v))
273+
end
274+
end
275+
metadata_children = map(getmetaimpl, args)
276+
for i in 1:length(args)
277+
if args[i] isa BasicSymbolic
278+
args[i] = args[i].expr
279+
end
280+
end
281+
keys = idcoeff ? args : @view args[2:end]
282+
bsi_dict = Dict(zip(keys, values(dict)))
283+
return args, metadata_children, bsi_dict
284+
end
285+
247286
isexpr(s::BasicSymbolic) = isexpr(s.expr)
248287
isexpr(expr::BasicSymbolicImpl) = !issym(expr)
249288
iscall(s::BasicSymbolic) = iscall(s.expr)
@@ -599,10 +638,15 @@ function Term{T}(f, args; metadata = NO_METADATA, kw...) where T
599638
if eltype(args) !== Any
600639
args = convert(Vector{Any}, args)
601640
end
602-
641+
metadata_children = map(getmetaimpl, args)
642+
for i in 1:length(args)
643+
if args[i] isa BasicSymbolic
644+
args[i] = args[i].expr
645+
end
646+
end
603647
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
604648
bsi = BasicSymbolicImpl(s)
605-
mdi = MetadataImpl(metadata, getmetadata.(args))
649+
mdi = MetadataImpl(metadata, metadata_children)
606650
BasicSymbolic(bsi, mdi)
607651
end
608652

@@ -622,10 +666,10 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
622666
return Mul(T, coeff, dict)
623667
end
624668
end
625-
626-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...)
669+
arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, ADD)
670+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...)
627671
bsi = BasicSymbolicImpl(s)
628-
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
672+
mdi = MetadataImpl(metadata, metadata_children)
629673
BasicSymbolic(bsi, mdi)
630674
end
631675

@@ -641,9 +685,10 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
641685
else
642686
coeff = a
643687
dict = b
644-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...)
688+
arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, MUL)
689+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...)
645690
bsi = BasicSymbolicImpl(s)
646-
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
691+
mdi = MetadataImpl(metadata, metadata_children)
647692
BasicSymbolic(bsi, mdi)
648693
end
649694
end
@@ -708,10 +753,16 @@ function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where {
708753
end
709754
end
710755
end
711-
756+
metadata_children = [getmetaimpl(n), getmetaimpl(d)]
757+
if n isa BasicSymbolic
758+
n = n.expr
759+
end
760+
if d isa BasicSymbolic
761+
d = d.expr
762+
end
712763
s = Div{T}(; num=n, den=d, simplified, arguments=[])
713764
bsi = BasicSymbolicImpl(s)
714-
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
765+
mdi = MetadataImpl(metadata, metadata_children)
715766
BasicSymbolic(bsi, mdi)
716767
end
717768

@@ -728,10 +779,17 @@ end
728779

729780
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
730781
_iszero(b) && return 1
731-
_isone(b) && return a
782+
_isone(b) && return a
783+
metadata_children = [getmetaimpl(a), getmetaimpl(b)]
784+
if a isa BasicSymbolic
785+
a = a.expr
786+
end
787+
if b isa BasicSymbolic
788+
b = b.expr
789+
end
732790
s = Pow{T}(; base=a, exp=b, arguments=[])
733791
bsi = BasicSymbolicImpl(s)
734-
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
792+
mdi = MetadataImpl(metadata, metadata_children)
735793
BasicSymbolic(bsi, mdi)
736794
end
737795

0 commit comments

Comments
 (0)