Skip to content

Commit 608e23a

Browse files
committed
Adapt functions to new struct definition
1 parent 989907b commit 608e23a

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

src/types.jl

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...)
4646
end
4747

4848
Base.@kwdef struct BasicSymbolic{T} <: Symbolic{T}
49-
x::BasicSymbolicImpl
49+
impl::BasicSymbolicImpl
5050
metadata::Metadata = NO_METADATA
5151
hash::RefValue{UInt} = Ref(EMPTY_HASH)
5252
end
@@ -56,7 +56,7 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
5656
end
5757

5858
function exprtype(x::BasicSymbolic)
59-
@match x::BasicSymbolic begin
59+
@match x.impl begin
6060
Term => TERM
6161
Add => ADD
6262
Mul => MUL
@@ -71,6 +71,7 @@ end
7171
# Same but different error messages
7272
@noinline error_on_type() = error("Internal error: unreachable reached!")
7373
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
74+
@noinline error_const() = error("Const doesn't have a operation or arguments!")
7475
@noinline error_property(E, s) = error("$E doesn't have field $s")
7576

7677
# We can think about bits later
@@ -94,13 +95,14 @@ symtype(x::Number) = typeof(x)
9495

9596
# We're returning a function pointer
9697
@inline function operation(x::BasicSymbolic)
97-
@match x::BasicSymbolic begin
98+
@match x.impl begin
9899
Term => x.f
99100
Add => (+)
100101
Mul => (*)
101102
Div => (/)
102103
Pow => (^)
103104
Sym => error_sym()
105+
Const => error_const()
104106
_ => error_on_type()
105107
end
106108
end
@@ -109,7 +111,7 @@ end
109111

110112
function arguments(x::BasicSymbolic)
111113
args = unsorted_arguments(x)
112-
@match x::BasicSymbolic begin
114+
@match x.impl begin
113115
Add => @goto ADD
114116
Mul => @goto MUL
115117
_ => return args
@@ -132,50 +134,51 @@ end
132134
unsorted_arguments(x) = arguments(x)
133135
children(x::BasicSymbolic) = arguments(x)
134136
function unsorted_arguments(x::BasicSymbolic)
135-
@match x::BasicSymbolic begin
137+
@match x.impl begin
136138
Term => return x.arguments
137139
Add => @goto ADDMUL
138140
Mul => @goto ADDMUL
139141
Div => @goto DIV
140142
Pow => @goto POW
141143
Sym => error_sym()
144+
Const => error_const()
142145
_ => error_on_type()
143146
end
144147

145148
@label ADDMUL
146149
E = exprtype(x)
147-
args = x.arguments
150+
args = x.impl.arguments
148151
isempty(args) || return args
149-
siz = length(x.dict)
150-
idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff)
152+
siz = length(x.impl.dict)
153+
idcoeff = E === ADD ? iszero(x.impl.coeff) : isone(x.impl.coeff)
151154
sizehint!(args, idcoeff ? siz : siz + 1)
152-
idcoeff || push!(args, x.coeff)
155+
idcoeff || push!(args, x.impl.coeff)
153156
if isadd(x)
154-
for (k, v) in x.dict
157+
for (k, v) in x.impl.dict
155158
push!(args, applicable(*,k,v) ? k*v :
156159
maketerm(k, *, [k, v]))
157160
end
158161
else # MUL
159-
for (k, v) in x.dict
162+
for (k, v) in x.impl.dict
160163
push!(args, unstable_pow(k, v))
161164
end
162165
end
163166
return args
164167

165168
@label DIV
166-
args = x.arguments
169+
args = x.impl.arguments
167170
isempty(args) || return args
168171
sizehint!(args, 2)
169-
push!(args, x.num)
170-
push!(args, x.den)
172+
push!(args, x.impl.num)
173+
push!(args, x.impl.den)
171174
return args
172175

173176
@label POW
174-
args = x.arguments
177+
args = x.impl.arguments
175178
isempty(args) || return args
176179
sizehint!(args, 2)
177-
push!(args, x.base)
178-
push!(args, x.exp)
180+
push!(args, x.impl.base)
181+
push!(args, x.impl.exp)
179182
return args
180183
end
181184

@@ -220,15 +223,17 @@ function _isequal(a, b, E)
220223
if E === SYM
221224
nameof(a) === nameof(b)
222225
elseif E === ADD || E === MUL
223-
coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
226+
coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict)
224227
elseif E === DIV
225-
isequal(a.num, b.num) && isequal(a.den, b.den)
228+
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
226229
elseif E === POW
227-
isequal(a.exp, b.exp) && isequal(a.base, b.base)
230+
isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base)
228231
elseif E === TERM
229232
a1 = arguments(a)
230233
a2 = arguments(b)
231234
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
235+
elseif E === CONST
236+
isequal(a.impl.val, b.impl.val)
232237
else
233238
error_on_type()
234239
end
@@ -246,6 +251,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt
246251
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
247252
const DIV_SALT = 0x334b218e73bbba53 % UInt
248253
const POW_SALT = 0x2b55b97a6efb080c % UInt
254+
const COS_SALT = 0xdc3d6b8f18b75e3c % UInt
249255
function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
250256
E = exprtype(s)
251257
if E === SYM
@@ -255,13 +261,13 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
255261
h = s.hash[]
256262
!iszero(h) && return h
257263
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
258-
h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt)))
264+
h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt)))
259265
s.hash[] = h′
260266
return h′
261267
elseif E === DIV
262-
return hash(s.num, hash(s.den, salt DIV_SALT))
268+
return hash(s.impl.num, hash(s.impl.den, salt DIV_SALT))
263269
elseif E === POW
264-
hash(s.exp, hash(s.base, salt POW_SALT))
270+
hash(s.impl.exp, hash(s.impl.base, salt POW_SALT))
265271
elseif E === TERM
266272
!iszero(salt) && return hash(hash(s, zero(UInt)), salt)
267273
h = s.hash[]
@@ -271,6 +277,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
271277
h′ = hashvec(arguments(s), hash(oph, salt))
272278
s.hash[] = h′
273279
return h′
280+
elseif E === CONST
281+
return hash(s.impl.val, salt COS_SALT)
274282
else
275283
error_on_type()
276284
end

0 commit comments

Comments
 (0)