Skip to content

[WIP] Commutative operations and negative exponent match in rules #752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
88d84c4
first version, really caothic, and doesn't work with defslot powers
Bumblebee00 Jun 14, 2025
105e1dc
second version, really caothic, but works with defslotpowers
Bumblebee00 Jun 14, 2025
c45b4a7
fix typo
Bumblebee00 Jun 18, 2025
fa22164
operation + and * are always commutative now
Bumblebee00 Jun 18, 2025
b578edd
added some tests of commutative operations
Bumblebee00 Jun 18, 2025
aee0635
fixed bug on defslot functionality
Bumblebee00 Jun 19, 2025
c223102
added defslot on operations with multiple arguments
Bumblebee00 Jun 19, 2025
4c2b475
moved the commutativiry checks to only acrule macro
Bumblebee00 Jun 19, 2025
cf9324b
negative exponent feature is done in a different way, more clean
Bumblebee00 Jun 20, 2025
a5b4d8f
fixed failing ci tests
Bumblebee00 Jun 20, 2025
66e8253
added tests with deflost in operation call with more than two arguments
Bumblebee00 Jun 20, 2025
41b30b4
now rationals can be used in rules
Bumblebee00 Jun 21, 2025
838c97c
created smrule (sum multiplication rule) macro
Bumblebee00 Jun 22, 2025
390b5b2
enhance commutative term matcher to validate operation type
Bumblebee00 Jun 22, 2025
b53ca01
fixed bug in defslot code and improved performance
Bumblebee00 Jun 22, 2025
8572290
improved negative exponent pattern matching. now it matches also for…
Bumblebee00 Jun 22, 2025
04dff37
changed order of checks in pow term matcher
Bumblebee00 Jun 24, 2025
f20375f
added match for exp and sqrt calls
Bumblebee00 Jun 27, 2025
9bcabec
removed smrule macro and added commutativity checks to the rule macro
Bumblebee00 Jun 30, 2025
08a7ae9
added commutativity checks also for segment matcher
Bumblebee00 Jul 7, 2025
1a46660
fixed predicates with defslots
Bumblebee00 Jul 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 139 additions & 81 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#

function matcher(val::Any)
function matcher(val::Any, acSets)
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
if iscall(val)
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val)
# else return a normal term matcher
else
return term_matcher_constructor(val)
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
if any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val, acSets)
end
# else return a normal term matcher
return term_matcher_constructor(val, acSets)
end

function literal_matcher(next, data, bindings)
Expand All @@ -24,7 +24,8 @@ function matcher(val::Any)
end
end

function matcher(slot::Slot)
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
function matcher(slot::Slot, acSets)
function slot_matcher(next, data, bindings)
!islist(data) && return nothing
val = get(bindings, slot.name, nothing)
Expand All @@ -43,8 +44,8 @@ end
# this is called only when defslot_term_matcher finds the operation and tries
# to match it, so no default value used. So the same function as slot_matcher
# can be used
function matcher(defslot::DefSlot)
matcher(Slot(defslot.name, defslot.predicate))
function matcher(defslot::DefSlot, acSets)
matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets
end

# returns n == offset, 0 if failed
Expand Down Expand Up @@ -75,7 +76,7 @@ function trymatchexpr(data, value, n)
end
end

function matcher(segment::Segment)
function matcher(segment::Segment, acSets)
function segment_matcher(success, data, bindings)
val = get(bindings, segment.name, nothing)

Expand All @@ -90,98 +91,155 @@ function matcher(segment::Segment)
for i=length(data):-1:0
subexpr = take_n(data, i)

if segment.predicate(subexpr)
res = success(assoc(bindings, segment.name, subexpr), i)
if res !== nothing
break
end
end
!segment.predicate(subexpr) && continue
res = success(assoc(bindings, segment.name, subexpr), i)
res !== nothing && break
end

return res
end
end
end

function term_matcher_constructor(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
function term_matcher_constructor(term, acSets)
matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,)

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return bindings′
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
# explanation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
# the length of the list, is considered empty
end

function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
# if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent
if operation(term) === ^
function pow_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
data = car(data) # from (..., ) to ...
!iscall(data) && return nothing # if first element is not a call, return nothing

# if data is of the alternative form (1/...)^(...), it might match with negative exponent
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1)
one_over_smth = arguments(data)[1]
T = symtype(one_over_smth)
frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
result = loop(data, bindings, matchers)
result !== nothing && return success(result, 1)

# if data is of the alternative form 1/(...)^(...), it might match with negative exponent
if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
denominator = arguments(data)[2]
T = symtype(denominator)
frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))

# if data is a exp call, it might match with base e
if operation(data)===exp
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[ℯ,arguments(data)[1]])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end
# explenation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
# the length of the list, is considered empty
end

loop(car(data), bindings, matchers) # Try to eat exactly one term
# if data is a sqrt call, it might match with exponent 1//2
if operation(data)===sqrt
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[arguments(data)[1], 1//2])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end

return nothing
end
return pow_term_matcher
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
elseif acSets!==nothing && operation(term) in [+, *]
function commutative_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try

T = symtype(car(data))
if T <: Number
f = operation(car(data))
data_args = arguments(car(data))

for inds in acSets(eachindex(data_args), length(data_args))
candidate = Term{T}(f, @views data_args[inds])

result = loop(candidate, bindings, matchers)
result !== nothing && return success(result,1)
end
# if car(data) does not subtype to number, it might not be commutative
else
# call the normal matcher
result = loop(car(data), bindings, matchers)
result !== nothing && return success(result, 1)
end
return nothing
end
return commutative_term_matcher
else
function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing

result = loop(car(data), bindings, matchers)
result !== nothing && return success(result, 1)
return nothing
end
return term_matcher
end
end

# creates a matcher for a term containing a defslot, such as:
# (~x + ...complicated pattern...) * ~!y
# normal part (can bee a tree) operation defslot part

# defslot_term_matcher works like this:
# checks wether data starts with the default operation.
# if yes (1): continues like term_matcher
# if no checks wether data matches the normal part
# if no returns nothing, rule is not applied
# if yes (2): adds the pair (default value name, default value) to the found bindings and
# calls the success function like term_matcher would do

function defslot_term_matcher_constructor(term)
a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term

function defslot_term_matcher_constructor(term, acSets)
a = arguments(term)
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
defslot = a[defslot_index]
if length(a) == 2
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
else
others = [a[i] for i in eachindex(a) if i != defslot_index]
T = symtype(term)
f = operation(term)
other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets)
end

function defslot_term_matcher(success, data, bindings)
# if data is not a list, return nothing
!islist(data) && return nothing
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation)
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part

# checks wether it matches the normal part
# <-----------------(2)------------------------------->
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)

if bindings === nothing
return nothing
end
return success(bindings, 1)
end

# (1)
function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
end
normal_matcher = term_matcher_constructor(term, acSets)

loop(car(data), bindings, matchers) # Try to eat exactly one term
function defslot_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
# call the normal matcher, with success function foo1 that simply returns the bindings
# <--foo1-->
result = normal_matcher((b,n) -> b, data, bindings)
result !== nothing && return success(result, 1)
# if no match, try to match with a defslot.
# checks whether it matches the normal part if yes executes foo2
# foo2: adds the pair (default value name, default value) to the found bindings
# <-------------------foo2---------------------------->
result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
result !== nothing && return success(result, 1)
nothing
end
end
67 changes: 49 additions & 18 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function makeDefSlot(s::Expr, keys, op)

push!(keys, name)
tmp = defaultValOfCall(op)
:(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp)))
:(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op)), $(esc(tmp))))
end


Expand Down Expand Up @@ -130,6 +130,9 @@ function makepattern(expr, keys, parentCall=nothing)
# matches ~x::predicate
makeslot(expr.args[2], keys)
end
elseif expr.args[1] === :(//)
# bc when the expression is not quoted, 3//2 is a Rational{Int64}, not a call
return esc(expr.args[2] // expr.args[3])
else
# make a pattern for every argument of the expr.
:(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any))
Expand Down Expand Up @@ -373,11 +376,13 @@ macro rule(expr)
quote
$(__source__)
lhs_pattern = $(lhs_term)
Rule($(QuoteNode(expr)),
lhs_pattern,
matcher(lhs_pattern),
__MATCHES__ -> $(makeconsequent(rhs)),
rule_depth($lhs_term))
Rule(
$(QuoteNode(expr)),
lhs_pattern,
matcher(lhs_pattern, permutations),
__MATCHES__ -> $(makeconsequent(rhs)),
rule_depth($lhs_term)
)
end
end

Expand Down Expand Up @@ -412,7 +417,7 @@ macro capture(ex, lhs)
lhs_pattern = $(lhs_term)
__MATCHES__ = Rule($(QuoteNode(lhs)),
lhs_pattern,
matcher(lhs_pattern),
matcher(lhs_pattern, nothing),
identity,
rule_depth($lhs_term))($(esc(ex)))
if __MATCHES__ !== nothing
Expand All @@ -437,32 +442,58 @@ Rule(acr::ACRule) = acr.rule
getdepth(r::ACRule) = getdepth(r.rule)

macro acrule(expr)
arity = length(expr.args[2].args[2:end])
@assert expr.head == :call && expr.args[1] == :(=>)
lhs = expr.args[2]
rhs = rewrite_rhs(expr.args[3])
keys = Symbol[]
lhs_term = makepattern(lhs, keys)
unique!(keys)

arity = length(lhs.args[2:end])

quote
ACRule(permutations, $(esc(:(@rule($(expr))))), $arity)
$(__source__)
lhs_pattern = $(lhs_term)
rule = Rule($(QuoteNode(expr)),
lhs_pattern,
matcher(lhs_pattern, permutations),
__MATCHES__ -> $(makeconsequent(rhs)),
rule_depth($lhs_term))
ACRule(permutations, rule, $arity)
end
end

macro ordered_acrule(expr)
arity = length(expr.args[2].args[2:end])
@assert expr.head == :call && expr.args[1] == :(=>)
lhs = expr.args[2]
rhs = rewrite_rhs(expr.args[3])
keys = Symbol[]
lhs_term = makepattern(lhs, keys)
unique!(keys)

arity = length(lhs.args[2:end])

quote
ACRule(combinations, $(esc(:(@rule($(expr))))), $arity)
$(__source__)
lhs_pattern = $(lhs_term)
rule = Rule($(QuoteNode(expr)),
lhs_pattern,
matcher(lhs_pattern, combinations),
__MATCHES__ -> $(makeconsequent(rhs)),
rule_depth($lhs_term))
ACRule(combinations, rule, $arity)
end
end

Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")")

function (acr::ACRule)(term)
r = Rule(acr)
if !iscall(term)
if !iscall(term) || operation(term) != operation(r.lhs)
# different operations -> try deflsot
r(term)
else
f = operation(term)
# Assume that the matcher was formed by closing over a term
if f != operation(r.lhs) # Maybe offer a fallback if m.term errors.
return nothing
end

f = operation(term)
T = symtype(term)
args = arguments(term)

Expand Down
Loading
Loading