Skip to content

Commit 7a0f514

Browse files
committed
moved the commutativiry checks to only acrule macro
1 parent c2e2696 commit 7a0f514

File tree

3 files changed

+78
-42
lines changed

3 files changed

+78
-42
lines changed

src/matchers.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88

9-
function matcher(val::Any)
9+
function matcher(val::Any; acSets = nothing)
1010
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
1111
if iscall(val)
1212
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
1313
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
1414
if any(x -> isa(x, DefSlot), arguments(val))
15-
return defslot_term_matcher_constructor(val)
15+
return defslot_term_matcher_constructor(val, acSets)
1616
end
1717
# else return a normal term matcher
18-
return term_matcher_constructor(val)
18+
return term_matcher_constructor(val, acSets)
1919
end
2020

2121
function literal_matcher(next, data, bindings)
@@ -24,7 +24,8 @@ function matcher(val::Any)
2424
end
2525
end
2626

27-
function matcher(slot::Slot)
27+
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
28+
function matcher(slot::Slot; acSets = nothing)
2829
function slot_matcher(next, data, bindings)
2930
!islist(data) && return nothing
3031
val = get(bindings, slot.name, nothing)
@@ -43,7 +44,7 @@ end
4344
# this is called only when defslot_term_matcher finds the operation and tries
4445
# to match it, so no default value used. So the same function as slot_matcher
4546
# can be used
46-
function matcher(defslot::DefSlot)
47+
function matcher(defslot::DefSlot; acSets = nothing)
4748
matcher(Slot(defslot.name, defslot.predicate))
4849
end
4950

@@ -59,7 +60,7 @@ function opposite_sign_matcher(slot::Slot)
5960
return next(bindings, 1)
6061
end
6162
elseif slot.predicate(car(data))
62-
next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only differenct wrt matcher(slot::Slot)
63+
next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only difference wrt matcher(slot::Slot)
6364
end
6465
end
6566
end
@@ -96,7 +97,7 @@ function trymatchexpr(data, value, n)
9697
end
9798
end
9899

99-
function matcher(segment::Segment)
100+
function matcher(segment::Segment; acSets=nothing)
100101
function segment_matcher(success, data, bindings)
101102
val = get(bindings, segment.name, nothing)
102103

@@ -124,8 +125,8 @@ function matcher(segment::Segment)
124125
end
125126
end
126127

127-
function term_matcher_constructor(term)
128-
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
128+
function term_matcher_constructor(term, acSets)
129+
matchers = (matcher(operation(term); acSets=acSets), map(x->matcher(x;acSets=acSets), arguments(term))...,)
129130

130131
function loop(term, bindings′, matchers′) # Get it to compile faster
131132
if !islist(matchers′)
@@ -137,7 +138,7 @@ function term_matcher_constructor(term)
137138
car(matchers′)(term, bindings′) do b, n
138139
loop(drop_n(term, n), b, cdr(matchers′))
139140
end
140-
# explenation of above 3 lines:
141+
# explanation of above 3 lines:
141142
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
142143
# <------ next(b,n) ---------------------------->
143144
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
@@ -171,7 +172,7 @@ function term_matcher_constructor(term)
171172
frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2])
172173
result = loop(frankestein, bindings, matchers)
173174
# if b is a number, like 3, we cant call loop with a^-3 bc it
174-
# will automatically transform into 1/a^3. Therfore we need to
175+
# will automatically transform into 1/a^3. Therefore we need to
175176
# create a matcher that flips the sign of the exponent. I created
176177
# this matecher just for `Slot`s and `DefSlot`s, but not for
177178
# terms or literals, because if b is a number and not a call,
@@ -181,21 +182,27 @@ function term_matcher_constructor(term)
181182
result = loop(denominator, bindings, matchers_modified)
182183
end
183184
end
184-
if result !== nothing
185-
return success(result, 1)
186-
end
185+
result !== nothing && return success(result, 1)
186+
return nothing
187187
end
188188
return term_matcher_pow
189189
# if the operation is commutative
190-
elseif operation(term) in [+, *]
190+
elseif acSets!==nothing && !isa(arguments(term)[1], Segment) && operation(term) in [+, *]
191191
function term_matcher_comm(success, data, bindings)
192192
!islist(data) && return nothing # if data is not a list, return nothing
193193
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
194194

195-
for m in all_matchers
196-
result = loop(car(data), bindings, m)
197-
result !== nothing && return success(result, 1)
195+
T = symtype(car(data))
196+
f = operation(car(data))
197+
data_args = arguments(car(data))
198+
199+
for inds in acSets(eachindex(data_args), length(arguments(term)))
200+
candidate = Term{T}(f, @views data_args[inds])
201+
202+
result = loop(candidate, bindings, matchers)
203+
result !== nothing && length(data_args) == length(inds) && return success(result,1)
198204
end
205+
return nothing
199206
end
200207
return term_matcher_comm
201208
else
@@ -204,9 +211,8 @@ function term_matcher_constructor(term)
204211
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
205212

206213
result = loop(car(data), bindings, matchers)
207-
if result !== nothing
208-
return success(result, 1)
209-
end
214+
result !== nothing && return success(result, 1)
215+
return nothing
210216
end
211217
return term_matcher
212218
end
@@ -215,35 +221,31 @@ end
215221
# creates a matcher for a term containing a defslot, such as:
216222
# (~x + ...complicated pattern...) * ~!y
217223
# normal part (can bee a tree) operation defslot part
218-
function defslot_term_matcher_constructor(term)
224+
function defslot_term_matcher_constructor(term, acSets)
219225
a = arguments(term)
220226
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
221227
defslot = a[defslot_index]
222228
if length(a) == 2
223-
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1])
229+
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]; acSets = acSets)
224230
else
225231
others = [a[i] for i in eachindex(a) if i != defslot_index]
226232
T = symtype(term)
227233
f = operation(term)
228-
other_part_matcher = term_matcher_constructor(Term{T}(f, others))
234+
other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets)
229235
end
230236

231-
normal_matcher = term_matcher_constructor(term)
237+
normal_matcher = term_matcher_constructor(term, acSets)
232238

233239
function defslot_term_matcher(success, data, bindings)
234240
!islist(data) && return nothing # if data is not a list, return nothing
235241
result = normal_matcher(success, data, bindings)
236242
result !== nothing && return result
237-
# if no match, try to match with a defslot
238-
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
239-
240-
# checks wether it matches the normal part if yes executes (foo)
243+
# if no match, try to match with a defslot.
244+
# checks whether it matches the normal part if yes executes (foo)
241245
# (foo): adds the pair (default value name, default value) to the found bindings
242246
# <------------------(foo)---------------------------->
243247
result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
244-
println(result)
245248
result !== nothing && return success(result, 1)
246-
247249
nothing
248250
end
249251
end

src/rule.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,16 +437,46 @@ Rule(acr::ACRule) = acr.rule
437437
getdepth(r::ACRule) = getdepth(r.rule)
438438

439439
macro acrule(expr)
440-
arity = length(expr.args[2].args[2:end])
440+
@assert expr.head == :call && expr.args[1] == :(=>)
441+
lhs = expr.args[2]
442+
rhs = rewrite_rhs(expr.args[3])
443+
keys = Symbol[]
444+
lhs_term = makepattern(lhs, keys)
445+
unique!(keys)
446+
447+
arity = length(lhs.args[2:end])
448+
441449
quote
442-
ACRule(permutations, $(esc(:(@rule($(expr))))), $arity)
450+
$(__source__)
451+
lhs_pattern = $(lhs_term)
452+
rule = Rule($(QuoteNode(expr)),
453+
lhs_pattern,
454+
matcher(lhs_pattern; acSets = permutations),
455+
__MATCHES__ -> $(makeconsequent(rhs)),
456+
rule_depth($lhs_term))
457+
ACRule(permutations, rule, $arity)
443458
end
444459
end
445460

446461
macro ordered_acrule(expr)
447-
arity = length(expr.args[2].args[2:end])
462+
@assert expr.head == :call && expr.args[1] == :(=>)
463+
lhs = expr.args[2]
464+
rhs = rewrite_rhs(expr.args[3])
465+
keys = Symbol[]
466+
lhs_term = makepattern(lhs, keys)
467+
unique!(keys)
468+
469+
arity = length(lhs.args[2:end])
470+
448471
quote
449-
ACRule(combinations, $(esc(:(@rule($(expr))))), $arity)
472+
$(__source__)
473+
lhs_pattern = $(lhs_term)
474+
rule = Rule($(QuoteNode(expr)),
475+
lhs_pattern,
476+
matcher(lhs_pattern; acSets = combinations),
477+
__MATCHES__ -> $(makeconsequent(rhs)),
478+
rule_depth($lhs_term))
479+
ACRule(combinations, rule, $arity)
450480
end
451481
end
452482

test/rewrite.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,21 @@ end
4848
end
4949

5050
@testset "Commutative + and *" begin
51-
r1 = @rule sin(~x) + cos(~x) => ~x
52-
@test r1(sin(a)+cos(a)) === a
53-
@test r1(sin(x)+cos(x)) === x
54-
r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
55-
r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
51+
r1 = @acrule exp(sin(~x) + cos(~x)) => ~x
52+
@test r1(exp(sin(a)+cos(a))) === a
53+
@test r1(exp(sin(x)+cos(x))) === x
54+
r2 = @acrule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
55+
r3 = @acrule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
5656
@test r2((a+b)*(x+c)^b) === (a, b, x, c, b)
5757
@test r3((a+b)*(x+c)^b) === (a, b, x, c, b)
58-
rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
59-
rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
58+
rPredicate1 = @acrule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
59+
rPredicate2 = @acrule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
6060
@test rPredicate1(2+x) === (2, x)
6161
@test rPredicate2(2+x) === (2, x)
62+
r5 = @acrule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w)
63+
r6 = @acrule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w)
64+
@test r5(c*(a+b)+d) === (d, c, a, b)
65+
@test r6(c*(a+b)+d) === (d, c, a, b)
6266
end
6367

6468
@testset "Slot matcher with default value" begin

0 commit comments

Comments
 (0)