Skip to content

Commit fe18cf5

Browse files
committed
removed smrule macro and added commutativity checks to the rule macro
1 parent b675cc1 commit fe18cf5

File tree

4 files changed

+35
-45
lines changed

4 files changed

+35
-45
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export Rewriters
5050
# A library for composing together expr -> expr functions
5151

5252
using Combinatorics: permutations, combinations
53-
export @rule, @acrule, @smrule, RuleSet
53+
export @rule, @acrule, RuleSet
5454

5555
# Rule type and @rule macro
5656
include("rule.jl")

src/matchers.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88

9-
function matcher(val::Any; acSets = nothing)
9+
function matcher(val::Any, acSets)
1010
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
1111
if iscall(val)
1212
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
@@ -25,7 +25,7 @@ function matcher(val::Any; acSets = nothing)
2525
end
2626

2727
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
28-
function matcher(slot::Slot; acSets = nothing)
28+
function matcher(slot::Slot, acSets)
2929
function slot_matcher(next, data, bindings)
3030
!islist(data) && return nothing
3131
val = get(bindings, slot.name, nothing)
@@ -44,8 +44,8 @@ end
4444
# this is called only when defslot_term_matcher finds the operation and tries
4545
# to match it, so no default value used. So the same function as slot_matcher
4646
# can be used
47-
function matcher(defslot::DefSlot; acSets = nothing)
48-
matcher(Slot(defslot.name, defslot.predicate))
47+
function matcher(defslot::DefSlot, acSets)
48+
matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets
4949
end
5050

5151
# returns n == offset, 0 if failed
@@ -76,7 +76,7 @@ function trymatchexpr(data, value, n)
7676
end
7777
end
7878

79-
function matcher(segment::Segment; acSets=nothing)
79+
function matcher(segment::Segment, acSets)
8080
function segment_matcher(success, data, bindings)
8181
val = get(bindings, segment.name, nothing)
8282

@@ -105,7 +105,7 @@ function matcher(segment::Segment; acSets=nothing)
105105
end
106106

107107
function term_matcher_constructor(term, acSets)
108-
matchers = (matcher(operation(term); acSets=acSets), map(x->matcher(x;acSets=acSets), arguments(term))...,)
108+
matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,)
109109

110110
function loop(term, bindings′, matchers′) # Get it to compile faster
111111
if !islist(matchers′)
@@ -181,14 +181,21 @@ function term_matcher_constructor(term, acSets)
181181
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try
182182

183183
T = symtype(car(data))
184-
f = operation(car(data))
185-
data_args = arguments(car(data))
186-
187-
for inds in acSets(eachindex(data_args), length(arguments(term)))
188-
candidate = Term{T}(f, @views data_args[inds])
189-
190-
result = loop(candidate, bindings, matchers)
191-
result !== nothing && length(data_args) == length(inds) && return success(result,1)
184+
if T <: Number
185+
f = operation(car(data))
186+
data_args = arguments(car(data))
187+
188+
for inds in acSets(eachindex(data_args), length(arguments(term)))
189+
candidate = Term{T}(f, @views data_args[inds])
190+
191+
result = loop(candidate, bindings, matchers)
192+
result !== nothing && length(data_args) == length(inds) && return success(result,1)
193+
end
194+
# if car(data) does not subtype to number, it might not be commutative
195+
else
196+
# call the normal matcher
197+
result = loop(car(data), bindings, matchers)
198+
result !== nothing && return success(result, 1)
192199
end
193200
return nothing
194201
end
@@ -214,7 +221,7 @@ function defslot_term_matcher_constructor(term, acSets)
214221
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
215222
defslot = a[defslot_index]
216223
if length(a) == 2
217-
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1]; acSets = acSets)
224+
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
218225
else
219226
others = [a[i] for i in eachindex(a) if i != defslot_index]
220227
T = symtype(term)

src/rule.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -367,24 +367,6 @@ _In the consequent pattern_: Use `(@ctx)` to access the context object on the ri
367367
of an expression.
368368
"""
369369
macro rule(expr)
370-
@assert expr.head == :call && expr.args[1] == :(=>)
371-
lhs = expr.args[2]
372-
rhs = rewrite_rhs(expr.args[3])
373-
keys = Symbol[]
374-
lhs_term = makepattern(lhs, keys)
375-
unique!(keys)
376-
quote
377-
$(__source__)
378-
lhs_pattern = $(lhs_term)
379-
Rule($(QuoteNode(expr)),
380-
lhs_pattern,
381-
matcher(lhs_pattern),
382-
__MATCHES__ -> $(makeconsequent(rhs)),
383-
rule_depth($lhs_term))
384-
end
385-
end
386-
387-
macro smrule(expr)
388370
@assert expr.head == :call && expr.args[1] == :(=>)
389371
lhs = expr.args[2]
390372
rhs = rewrite_rhs(expr.args[3])
@@ -397,7 +379,7 @@ macro smrule(expr)
397379
Rule(
398380
$(QuoteNode(expr)),
399381
lhs_pattern,
400-
matcher(lhs_pattern; acSets = permutations),
382+
matcher(lhs_pattern, permutations),
401383
__MATCHES__ -> $(makeconsequent(rhs)),
402384
rule_depth($lhs_term)
403385
)
@@ -435,7 +417,7 @@ macro capture(ex, lhs)
435417
lhs_pattern = $(lhs_term)
436418
__MATCHES__ = Rule($(QuoteNode(lhs)),
437419
lhs_pattern,
438-
matcher(lhs_pattern),
420+
matcher(lhs_pattern, nothing),
439421
identity,
440422
rule_depth($lhs_term))($(esc(ex)))
441423
if __MATCHES__ !== nothing
@@ -474,7 +456,7 @@ macro acrule(expr)
474456
lhs_pattern = $(lhs_term)
475457
rule = Rule($(QuoteNode(expr)),
476458
lhs_pattern,
477-
matcher(lhs_pattern; acSets = permutations),
459+
matcher(lhs_pattern, permutations),
478460
__MATCHES__ -> $(makeconsequent(rhs)),
479461
rule_depth($lhs_term))
480462
ACRule(permutations, rule, $arity)
@@ -496,7 +478,7 @@ macro ordered_acrule(expr)
496478
lhs_pattern = $(lhs_term)
497479
rule = Rule($(QuoteNode(expr)),
498480
lhs_pattern,
499-
matcher(lhs_pattern; acSets = combinations),
481+
matcher(lhs_pattern, combinations),
500482
__MATCHES__ -> $(makeconsequent(rhs)),
501483
rule_depth($lhs_term))
502484
ACRule(combinations, rule, $arity)

test/rewrite.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,20 @@ end
4848
end
4949

5050
@testset "Commutative + and *" begin
51-
r1 = @acrule exp(sin(~x) + cos(~x)) => ~x
51+
r1 = @rule exp(sin(~x) + cos(~x)) => ~x
52+
# using a or x changes the order of the arguments in the call
5253
@test r1(exp(sin(a)+cos(a))) === a
5354
@test r1(exp(sin(x)+cos(x))) === x
54-
r2 = @acrule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
55-
r3 = @acrule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
55+
r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m)
56+
r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m)
5657
@test r2((a+b)*(x+c)^b) === (a, b, x, c, b)
5758
@test r3((a+b)*(x+c)^b) === (a, b, x, c, b)
58-
rPredicate1 = @acrule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
59-
rPredicate2 = @acrule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
59+
rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y)
60+
rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y)
6061
@test rPredicate1(2+x) === (2, x)
6162
@test rPredicate2(2+x) === (2, x)
62-
r5 = @acrule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w)
63-
r6 = @acrule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w)
63+
r5 = @rule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w)
64+
r6 = @rule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w)
6465
@test r5(c*(a+b)+d) === (d, c, a, b)
6566
@test r6(c*(a+b)+d) === (d, c, a, b)
6667
end

0 commit comments

Comments
 (0)