@@ -46,7 +46,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...)
46
46
end
47
47
48
48
Base. @kwdef struct BasicSymbolic{T} <: Symbolic{T}
49
- x :: BasicSymbolicImpl
49
+ impl :: BasicSymbolicImpl
50
50
metadata:: Metadata = NO_METADATA
51
51
hash:: RefValue{UInt} = Ref (EMPTY_HASH)
52
52
end
@@ -56,7 +56,7 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
56
56
end
57
57
58
58
function exprtype (x:: BasicSymbolic )
59
- @match x:: BasicSymbolic begin
59
+ @match x. impl begin
60
60
Term => TERM
61
61
Add => ADD
62
62
Mul => MUL
71
71
# Same but different error messages
72
72
@noinline error_on_type () = error (" Internal error: unreachable reached!" )
73
73
@noinline error_sym () = error (" Sym doesn't have a operation or arguments!" )
74
+ @noinline error_const () = error (" Const doesn't have a operation or arguments!" )
74
75
@noinline error_property (E, s) = error (" $E doesn't have field $s " )
75
76
76
77
# We can think about bits later
@@ -94,13 +95,14 @@ symtype(x::Number) = typeof(x)
94
95
95
96
# We're returning a function pointer
96
97
@inline function operation (x:: BasicSymbolic )
97
- @match x:: BasicSymbolic begin
98
+ @match x. impl begin
98
99
Term => x. f
99
100
Add => (+ )
100
101
Mul => (* )
101
102
Div => (/ )
102
103
Pow => (^ )
103
104
Sym => error_sym ()
105
+ Const => error_const ()
104
106
_ => error_on_type ()
105
107
end
106
108
end
109
111
110
112
function arguments (x:: BasicSymbolic )
111
113
args = unsorted_arguments (x)
112
- @match x:: BasicSymbolic begin
114
+ @match x. impl begin
113
115
Add => @goto ADD
114
116
Mul => @goto MUL
115
117
_ => return args
@@ -132,50 +134,51 @@ end
132
134
unsorted_arguments (x) = arguments (x)
133
135
children (x:: BasicSymbolic ) = arguments (x)
134
136
function unsorted_arguments (x:: BasicSymbolic )
135
- @match x:: BasicSymbolic begin
137
+ @match x. impl begin
136
138
Term => return x. arguments
137
139
Add => @goto ADDMUL
138
140
Mul => @goto ADDMUL
139
141
Div => @goto DIV
140
142
Pow => @goto POW
141
143
Sym => error_sym ()
144
+ Const => error_const ()
142
145
_ => error_on_type ()
143
146
end
144
147
145
148
@label ADDMUL
146
149
E = exprtype (x)
147
- args = x. arguments
150
+ args = x. impl . arguments
148
151
isempty (args) || return args
149
- siz = length (x. dict)
150
- idcoeff = E === ADD ? iszero (x. coeff) : isone (x. coeff)
152
+ siz = length (x. impl . dict)
153
+ idcoeff = E === ADD ? iszero (x. impl . coeff) : isone (x. impl . coeff)
151
154
sizehint! (args, idcoeff ? siz : siz + 1 )
152
- idcoeff || push! (args, x. coeff)
155
+ idcoeff || push! (args, x. impl . coeff)
153
156
if isadd (x)
154
- for (k, v) in x. dict
157
+ for (k, v) in x. impl . dict
155
158
push! (args, applicable (* ,k,v) ? k* v :
156
159
maketerm (k, * , [k, v]))
157
160
end
158
161
else # MUL
159
- for (k, v) in x. dict
162
+ for (k, v) in x. impl . dict
160
163
push! (args, unstable_pow (k, v))
161
164
end
162
165
end
163
166
return args
164
167
165
168
@label DIV
166
- args = x. arguments
169
+ args = x. impl . arguments
167
170
isempty (args) || return args
168
171
sizehint! (args, 2 )
169
- push! (args, x. num)
170
- push! (args, x. den)
172
+ push! (args, x. impl . num)
173
+ push! (args, x. impl . den)
171
174
return args
172
175
173
176
@label POW
174
- args = x. arguments
177
+ args = x. impl . arguments
175
178
isempty (args) || return args
176
179
sizehint! (args, 2 )
177
- push! (args, x. base)
178
- push! (args, x. exp)
180
+ push! (args, x. impl . base)
181
+ push! (args, x. impl . exp)
179
182
return args
180
183
end
181
184
@@ -220,15 +223,17 @@ function _isequal(a, b, E)
220
223
if E === SYM
221
224
nameof (a) === nameof (b)
222
225
elseif E === ADD || E === MUL
223
- coeff_isequal (a. coeff, b. coeff) && isequal (a. dict, b. dict)
226
+ coeff_isequal (a. impl . coeff, b. impl . coeff) && isequal (a. impl . dict, b. impl . dict)
224
227
elseif E === DIV
225
- isequal (a. num, b. num) && isequal (a. den, b. den)
228
+ isequal (a. impl . num, b. impl . num) && isequal (a. impl . den, b. impl . den)
226
229
elseif E === POW
227
- isequal (a. exp, b. exp) && isequal (a. base, b. base)
230
+ isequal (a. impl . exp, b. impl . exp) && isequal (a. impl . base, b. impl . base)
228
231
elseif E === TERM
229
232
a1 = arguments (a)
230
233
a2 = arguments (b)
231
234
isequal (operation (a), operation (b)) && _allarequal (a1, a2)
235
+ elseif E === CONST
236
+ isequal (a. impl. val, b. impl. val)
232
237
else
233
238
error_on_type ()
234
239
end
@@ -246,6 +251,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt
246
251
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
247
252
const DIV_SALT = 0x334b218e73bbba53 % UInt
248
253
const POW_SALT = 0x2b55b97a6efb080c % UInt
254
+ const COS_SALT = 0xdc3d6b8f18b75e3c % UInt
249
255
function Base. hash (s:: BasicSymbolic , salt:: UInt ):: UInt
250
256
E = exprtype (s)
251
257
if E === SYM
@@ -255,13 +261,13 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
255
261
h = s. hash[]
256
262
! iszero (h) && return h
257
263
hashoffset = isadd (s) ? ADD_SALT : SUB_SALT
258
- h′ = hash (hashoffset, hash (s. coeff, hash (s. dict, salt)))
264
+ h′ = hash (hashoffset, hash (s. impl . coeff, hash (s. impl . dict, salt)))
259
265
s. hash[] = h′
260
266
return h′
261
267
elseif E === DIV
262
- return hash (s. num, hash (s. den, salt ⊻ DIV_SALT))
268
+ return hash (s. impl . num, hash (s. impl . den, salt ⊻ DIV_SALT))
263
269
elseif E === POW
264
- hash (s. exp, hash (s. base, salt ⊻ POW_SALT))
270
+ hash (s. impl . exp, hash (s. impl . base, salt ⊻ POW_SALT))
265
271
elseif E === TERM
266
272
! iszero (salt) && return hash (hash (s, zero (UInt)), salt)
267
273
h = s. hash[]
@@ -271,6 +277,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
271
277
h′ = hashvec (arguments (s), hash (oph, salt))
272
278
s. hash[] = h′
273
279
return h′
280
+ elseif E === CONST
281
+ return hash (s. impl. val, salt ⊻ COS_SALT)
274
282
else
275
283
error_on_type ()
276
284
end
0 commit comments