Skip to content

Commit 6dfd06b

Browse files
Merge pull request #749 from Bumblebee00/default_values_rules
first prototype of default value rules implemented
2 parents a95aea2 + 2f5ae40 commit 6dfd06b

File tree

3 files changed

+204
-25
lines changed

3 files changed

+204
-25
lines changed

src/matchers.jl

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,48 @@
55
# 2. Dictionary
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
8+
89
function matcher(val::Any)
9-
iscall(val) && return term_matcher(val)
10+
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
11+
if iscall(val)
12+
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
13+
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
14+
return defslot_term_matcher_constructor(val)
15+
# else return a normal term matcher
16+
else
17+
return term_matcher_constructor(val)
18+
end
19+
end
20+
1021
function literal_matcher(next, data, bindings)
22+
# car data is the first element of data
1123
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
1224
end
1325
end
1426

1527
function matcher(slot::Slot)
1628
function slot_matcher(next, data, bindings)
17-
!islist(data) && return
29+
!islist(data) && return nothing
1830
val = get(bindings, slot.name, nothing)
31+
# if slot name already is in bindings, check if it matches
1932
if val !== nothing
2033
if isequal(val, car(data))
2134
return next(bindings, 1)
2235
end
23-
else
24-
if slot.predicate(car(data))
25-
next(assoc(bindings, slot.name, car(data)), 1)
26-
end
36+
# elseif the first element of data matches the slot predicate, add it to bindings and call next
37+
elseif slot.predicate(car(data))
38+
next(assoc(bindings, slot.name, car(data)), 1)
2739
end
2840
end
2941
end
3042

43+
# this is called only when defslot_term_matcher finds the operation and tries
44+
# to match it, so no default value used. So the same function as slot_matcher
45+
# can be used
46+
function matcher(defslot::DefSlot)
47+
matcher(Slot(defslot.name, defslot.predicate))
48+
end
49+
3150
# returns n == offset, 0 if failed
3251
function trymatchexpr(data, value, n)
3352
if !islist(value)
@@ -84,13 +103,73 @@ function matcher(segment::Segment)
84103
end
85104
end
86105

87-
function term_matcher(term)
106+
function term_matcher_constructor(term)
88107
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
108+
89109
function term_matcher(success, data, bindings)
110+
!islist(data) && return nothing # if data is not a list, return nothing
111+
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
112+
113+
function loop(term, bindings′, matchers′) # Get it to compile faster
114+
if !islist(matchers′)
115+
if !islist(term)
116+
return success(bindings′, 1)
117+
end
118+
return nothing
119+
end
120+
car(matchers′)(term, bindings′) do b, n
121+
loop(drop_n(term, n), b, cdr(matchers′))
122+
end
123+
# explenation of above 3 lines:
124+
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
125+
# <------ next(b,n) ---------------------------->
126+
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
127+
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
128+
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
129+
# the length of the list, is considered empty
130+
end
131+
132+
loop(car(data), bindings, matchers) # Try to eat exactly one term
133+
end
134+
end
135+
136+
# creates a matcher for a term containing a defslot, such as:
137+
# (~x + ...complicated pattern...) * ~!y
138+
# normal part (can bee a tree) operation defslot part
90139

140+
# defslot_term_matcher works like this:
141+
# checks wether data starts with the default operation.
142+
# if yes (1): continues like term_matcher
143+
# if no checks wether data matches the normal part
144+
# if no returns nothing, rule is not applied
145+
# if yes (2): adds the pair (default value name, default value) to the found bindings and
146+
# calls the success function like term_matcher would do
147+
148+
function defslot_term_matcher_constructor(term)
149+
a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
150+
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term
151+
152+
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
153+
defslot = a[defslot_index]
154+
155+
function defslot_term_matcher(success, data, bindings)
156+
# if data is not a list, return nothing
91157
!islist(data) && return nothing
92-
!iscall(car(data)) && return nothing
158+
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
159+
if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation)
160+
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part
161+
162+
# checks wether it matches the normal part
163+
# <-----------------(2)------------------------------->
164+
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
165+
166+
if bindings === nothing
167+
return nothing
168+
end
169+
return success(bindings, 1)
170+
end
93171

172+
# (1)
94173
function loop(term, bindings′, matchers′) # Get it to compile faster
95174
if !islist(matchers′)
96175
if !islist(term)

src/rule.jl

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
@inline alwaystrue(x) = true
33

4-
# Matcher patterns with Slot and Segment
4+
# Matcher patterns with Slot, DefSlot and Segment
55

66
# matches one term
77
# syntax: ~x
@@ -16,6 +16,79 @@ Base.isequal(s1::Slot, s2::Slot) = s1.name == s2.name
1616

1717
Base.show(io::IO, s::Slot) = (print(io, "~"); print(io, s.name))
1818

19+
# for when the slot is a symbol, like `~x`
20+
makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s))
21+
22+
# for when the slot is an expression, like `~x::predicate`
23+
function makeslot(s::Expr, keys)
24+
if !(s.head == :(::))
25+
error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function")
26+
end
27+
28+
name = s.args[1]
29+
30+
push!(keys, name)
31+
:(Slot($(QuoteNode(name)), $(esc(s.args[2]))))
32+
end
33+
34+
35+
36+
37+
38+
39+
# matches one term with built in default value.
40+
# syntax: ~!x
41+
# Example usage:
42+
# (~!x + ~y) can match (a + b) but also just "a" and x takes default value of zero.
43+
# (~!x)*(~y) can match a*b but also just "a", and x takes default value of one.
44+
# (~x + ~y)^(~!z) can match (a + b)^c but also just "a + b", and z takes default value of one.
45+
# only these three operations are supported for default values.
46+
47+
struct DefSlot{P, O}
48+
name::Symbol
49+
predicate::P
50+
operation::O
51+
defaultValue::Real
52+
end
53+
54+
# operation | default
55+
# + | 0
56+
# * | 1
57+
# ^ | 1
58+
function defaultValOfCall(call)
59+
if call == :+
60+
return 0
61+
elseif call == :*
62+
return 1
63+
elseif call == :^
64+
return 1
65+
end
66+
# else no default value for this call
67+
error("You can use default slots only with +, * and ^, but you tried with: $call")
68+
end
69+
70+
DefSlot(s) = DefSlot(s, alwaystrue, nothing, 0)
71+
Base.isequal(s1::DefSlot, s2::DefSlot) = s1.name == s2.name
72+
Base.show(io::IO, s::DefSlot) = (print(io, "~!"); print(io, s.name))
73+
74+
makeDefSlot(s::Symbol, keys, op) = (push!(keys, s); DefSlot(s, alwaystrue, op, defaultValOfCall(op)))
75+
76+
function makeDefSlot(s::Expr, keys, op)
77+
if !(s.head == :(::))
78+
error("Syntax for specifying a default slot is ~!x::\$predicate, where predicate is a boolean function")
79+
end
80+
81+
name = s.args[1]
82+
83+
push!(keys, name)
84+
tmp = defaultValOfCall(op)
85+
:(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp)))
86+
end
87+
88+
89+
90+
91+
1992
# matches zero or more terms
2093
# syntax: ~~x
2194
struct Segment{F}
@@ -37,37 +110,29 @@ function makesegment(s::Expr, keys)
37110
end
38111

39112
name = s.args[1]
40-
113+
41114
push!(keys, name)
42115
:(Segment($(QuoteNode(name)), $(esc(s.args[2]))))
43116
end
44117

45-
makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s))
46-
47-
function makeslot(s::Expr, keys)
48-
if !(s.head == :(::))
49-
error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function")
50-
end
51-
52-
name = s.args[1]
53-
54-
push!(keys, name)
55-
:(Slot($(QuoteNode(name)), $(esc(s.args[2]))))
56-
end
57-
58-
function makepattern(expr, keys)
118+
# parent call is needed to know which default value to give if any default slots are present
119+
function makepattern(expr, keys, parentCall=nothing)
59120
if expr isa Expr
60121
if expr.head === :call
61122
if expr.args[1] === :(~)
62123
if expr.args[2] isa Expr && expr.args[2].args[1] == :(~)
63124
# matches ~~x::predicate
64125
makesegment(expr.args[2].args[2], keys)
126+
elseif expr.args[2] isa Expr && expr.args[2].args[1] == :(!)
127+
# matches ~!x::predicate
128+
makeDefSlot(expr.args[2].args[2], keys, parentCall)
65129
else
66130
# matches ~x::predicate
67131
makeslot(expr.args[2], keys)
68132
end
69133
else
70-
:(term($(map(x->makepattern(x, keys), expr.args)...); type=Any))
134+
# make a pattern for every argument of the expr.
135+
:(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any))
71136
end
72137
elseif expr.head === :ref
73138
:(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any))

test/rewrite.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,41 @@ end
4747
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])
4848
end
4949

50+
@testset "Slot matcher with default value" begin
51+
r_sum = @rule (~x + ~!y)^2 => ~y
52+
@test r_sum((a + b)^2) === b
53+
@test r_sum(b^2) === 0
54+
55+
r_mult = @rule ~x * ~!y => ~y
56+
@test r_mult(a * b) === b
57+
@test r_mult(a) === 1
58+
59+
r_mult2 = @rule (~x * ~!y + ~z) => ~y
60+
@test r_mult2(c + a*b) === b
61+
@test r_mult2(c + b) === 1
62+
63+
# here the "normal part" in the defslot_term_matcher is not a symbol but a tree
64+
r_mult3 = @rule (~!x)*(~y + ~z) => ~x
65+
@test r_mult3(a*(c+2)) === a
66+
@test r_mult3(2*(c+2)) === 2
67+
@test r_mult3(c+2) === 1
68+
69+
r_pow = @rule (~x)^(~!m) => ~m
70+
@test r_pow(a^(b+1)) === b+1
71+
@test r_pow(a) === 1
72+
@test r_pow(a+1) === 1
73+
74+
# here the "normal part" in the defslot_term_matcher is not a symbol but a tree
75+
r_pow2 = @rule (~x + ~y)^(~!m) => ~m
76+
@test r_pow2((a+b)^c) === c
77+
@test r_pow2(a+b) === 1
78+
79+
r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c
80+
@test r_mix((a + b*c)^2) === 2 + c
81+
@test r_mix((a + b*c)) === 1 + c
82+
@test r_mix((a + b)) === 2 #1+1
83+
end
84+
5085
using SymbolicUtils: @capture
5186

5287
@testset "Capture form" begin

0 commit comments

Comments
 (0)