diff --git a/src/matchers.jl b/src/matchers.jl index 99b76ea1..5b0ad5a1 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,16 +6,16 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # -function matcher(val::Any) +function matcher(val::Any, acSets) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot - if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) - return defslot_term_matcher_constructor(val) - # else return a normal term matcher - else - return term_matcher_constructor(val) + # just two arguments bc defslot is only supported with operations with two args: *, ^, + + if any(x -> isa(x, DefSlot), arguments(val)) + return defslot_term_matcher_constructor(val, acSets) end + # else return a normal term matcher + return term_matcher_constructor(val, acSets) end function literal_matcher(next, data, bindings) @@ -24,7 +24,8 @@ function matcher(val::Any) end end -function matcher(slot::Slot) +# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro +function matcher(slot::Slot, acSets) function slot_matcher(next, data, bindings) !islist(data) && return nothing val = get(bindings, slot.name, nothing) @@ -43,8 +44,8 @@ end # this is called only when defslot_term_matcher finds the operation and tries # to match it, so no default value used. So the same function as slot_matcher # can be used -function matcher(defslot::DefSlot) - matcher(Slot(defslot.name, defslot.predicate)) +function matcher(defslot::DefSlot, acSets) + matcher(Slot(defslot.name, defslot.predicate), nothing) # slot matcher doesnt use acsets end # returns n == offset, 0 if failed @@ -75,7 +76,7 @@ function trymatchexpr(data, value, n) end end -function matcher(segment::Segment) +function matcher(segment::Segment, acSets) function segment_matcher(success, data, bindings) val = get(bindings, segment.name, nothing) @@ -90,12 +91,9 @@ function matcher(segment::Segment) for i=length(data):-1:0 subexpr = take_n(data, i) - if segment.predicate(subexpr) - res = success(assoc(bindings, segment.name, subexpr), i) - if res !== nothing - break - end - end + !segment.predicate(subexpr) && continue + res = success(assoc(bindings, segment.name, subexpr), i) + res !== nothing && break end return res @@ -103,85 +101,145 @@ function matcher(segment::Segment) end end -function term_matcher_constructor(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) +function term_matcher_constructor(term, acSets) + matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,) + + function loop(term, bindings′, matchers′) # Get it to compile faster + if !islist(matchers′) + if !islist(term) + return bindings′ + end + return nothing + end + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + # explanation of above 3 lines: + # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) + # <------ next(b,n) ----------------------------> + # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list + # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term + # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses + # the length of the list, is considered empty + end - function term_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing + # if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent + if operation(term) === ^ + function pow_term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + data = car(data) # from (..., ) to ... + !iscall(data) && return nothing # if first element is not a call, return nothing + + # if data is of the alternative form (1/...)^(...), it might match with negative exponent + if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1) + one_over_smth = arguments(data)[1] + T = symtype(one_over_smth) + frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing + result = loop(data, bindings, matchers) + result !== nothing && return success(result, 1) + + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^) + denominator = arguments(data)[2] + T = symtype(denominator) + frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) + + # if data is a exp call, it might match with base e + if operation(data)===exp + T = symtype(arguments(data)[1]) + frankestein = Term{T}(^,[ℯ,arguments(data)[1]]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) end - # explenation of above 3 lines: - # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) - # <------ next(b,n) ----------------------------> - # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list - # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term - # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses - # the length of the list, is considered empty - end - loop(car(data), bindings, matchers) # Try to eat exactly one term + # if data is a sqrt call, it might match with exponent 1//2 + if operation(data)===sqrt + T = symtype(arguments(data)[1]) + frankestein = Term{T}(^,[arguments(data)[1], 1//2]) + result = loop(frankestein, bindings, matchers) + result !== nothing && return success(result, 1) + end + + return nothing + end + return pow_term_matcher + # if we want to do commutative checks, i.e. call matcher with different order of the arguments + elseif acSets!==nothing && operation(term) in [+, *] + function commutative_term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try + + T = symtype(car(data)) + if T <: Number + f = operation(car(data)) + data_args = arguments(car(data)) + + for inds in acSets(eachindex(data_args), length(data_args)) + candidate = Term{T}(f, @views data_args[inds]) + + result = loop(candidate, bindings, matchers) + result !== nothing && return success(result,1) + end + # if car(data) does not subtype to number, it might not be commutative + else + # call the normal matcher + result = loop(car(data), bindings, matchers) + result !== nothing && return success(result, 1) + end + return nothing + end + return commutative_term_matcher + else + function term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing + + result = loop(car(data), bindings, matchers) + result !== nothing && return success(result, 1) + return nothing + end + return term_matcher end end # creates a matcher for a term containing a defslot, such as: # (~x + ...complicated pattern...) * ~!y # normal part (can bee a tree) operation defslot part - -# defslot_term_matcher works like this: -# checks wether data starts with the default operation. -# if yes (1): continues like term_matcher -# if no checks wether data matches the normal part -# if no returns nothing, rule is not applied -# if yes (2): adds the pair (default value name, default value) to the found bindings and -# calls the success function like term_matcher would do - -function defslot_term_matcher_constructor(term) - a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments - matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term - +function defslot_term_matcher_constructor(term, acSets) + a = arguments(term) defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term defslot = a[defslot_index] + if length(a) == 2 + other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets) + else + others = [a[i] for i in eachindex(a) if i != defslot_index] + T = symtype(term) + f = operation(term) + other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets) + end - function defslot_term_matcher(success, data, bindings) - # if data is not a list, return nothing - !islist(data) && return nothing - # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) - if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) - other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part - - # checks wether it matches the normal part - # <-----------------(2)-------------------------------> - bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) - - if bindings === nothing - return nothing - end - return success(bindings, 1) - end - - # (1) - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) - end - end + normal_matcher = term_matcher_constructor(term, acSets) - loop(car(data), bindings, matchers) # Try to eat exactly one term + function defslot_term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + # call the normal matcher, with success function foo1 that simply returns the bindings + # <--foo1--> + result = normal_matcher((b,n) -> b, data, bindings) + result !== nothing && return success(result, 1) + # if no match, try to match with a defslot. + # checks whether it matches the normal part if yes executes foo2 + # foo2: adds the pair (default value name, default value) to the found bindings + # <-------------------foo2----------------------------> + result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) + result !== nothing && return success(result, 1) + nothing end end diff --git a/src/rule.jl b/src/rule.jl index d20598b2..8e376acf 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -82,7 +82,7 @@ function makeDefSlot(s::Expr, keys, op) push!(keys, name) tmp = defaultValOfCall(op) - :(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp))) + :(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op)), $(esc(tmp)))) end @@ -130,6 +130,9 @@ function makepattern(expr, keys, parentCall=nothing) # matches ~x::predicate makeslot(expr.args[2], keys) end + elseif expr.args[1] === :(//) + # bc when the expression is not quoted, 3//2 is a Rational{Int64}, not a call + return esc(expr.args[2] // expr.args[3]) else # make a pattern for every argument of the expr. :(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any)) @@ -373,11 +376,13 @@ macro rule(expr) quote $(__source__) lhs_pattern = $(lhs_term) - Rule($(QuoteNode(expr)), - lhs_pattern, - matcher(lhs_pattern), - __MATCHES__ -> $(makeconsequent(rhs)), - rule_depth($lhs_term)) + Rule( + $(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern, permutations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term) + ) end end @@ -412,7 +417,7 @@ macro capture(ex, lhs) lhs_pattern = $(lhs_term) __MATCHES__ = Rule($(QuoteNode(lhs)), lhs_pattern, - matcher(lhs_pattern), + matcher(lhs_pattern, nothing), identity, rule_depth($lhs_term))($(esc(ex))) if __MATCHES__ !== nothing @@ -437,16 +442,46 @@ Rule(acr::ACRule) = acr.rule getdepth(r::ACRule) = getdepth(r.rule) macro acrule(expr) - arity = length(expr.args[2].args[2:end]) + @assert expr.head == :call && expr.args[1] == :(=>) + lhs = expr.args[2] + rhs = rewrite_rhs(expr.args[3]) + keys = Symbol[] + lhs_term = makepattern(lhs, keys) + unique!(keys) + + arity = length(lhs.args[2:end]) + quote - ACRule(permutations, $(esc(:(@rule($(expr))))), $arity) + $(__source__) + lhs_pattern = $(lhs_term) + rule = Rule($(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern, permutations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term)) + ACRule(permutations, rule, $arity) end end macro ordered_acrule(expr) - arity = length(expr.args[2].args[2:end]) + @assert expr.head == :call && expr.args[1] == :(=>) + lhs = expr.args[2] + rhs = rewrite_rhs(expr.args[3]) + keys = Symbol[] + lhs_term = makepattern(lhs, keys) + unique!(keys) + + arity = length(lhs.args[2:end]) + quote - ACRule(combinations, $(esc(:(@rule($(expr))))), $arity) + $(__source__) + lhs_pattern = $(lhs_term) + rule = Rule($(QuoteNode(expr)), + lhs_pattern, + matcher(lhs_pattern, combinations), + __MATCHES__ -> $(makeconsequent(rhs)), + rule_depth($lhs_term)) + ACRule(combinations, rule, $arity) end end @@ -454,15 +489,11 @@ Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")") function (acr::ACRule)(term) r = Rule(acr) - if !iscall(term) + if !iscall(term) || operation(term) != operation(r.lhs) + # different operations -> try deflsot r(term) else - f = operation(term) - # Assume that the matcher was formed by closing over a term - if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. - return nothing - end - + f = operation(term) T = symtype(term) args = arguments(term) diff --git a/test/rewrite.jl b/test/rewrite.jl index b8996f79..ebd084e7 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -2,7 +2,7 @@ using SymbolicUtils include("utils.jl") -@syms a b c +@syms a b c d x @testset "Equality" begin @eqtest a == a @@ -47,6 +47,25 @@ end @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) end +@testset "Commutative + and *" begin + r1 = @rule exp(sin(~x) + cos(~x)) => ~x + # using a or x changes the order of the arguments in the call + @test r1(exp(sin(a)+cos(a))) === a + @test r1(exp(sin(x)+cos(x))) === x + r2 = @rule (~x+~y)*(~z+~w)^(~m) => (~x, ~y, ~z, ~w, ~m) + r3 = @rule (~z+~w)^(~m)*(~x+~y) => (~x, ~y, ~z, ~w, ~m) + @test r2((a+b)*(x+c)^b) === (a, b, x, c, b) + @test r3((a+b)*(x+c)^b) === (a, b, x, c, b) + rPredicate1 = @rule ~x::(x->isa(x,Number)) + ~y => (~x, ~y) + rPredicate2 = @rule ~y + ~x::(x->isa(x,Number)) => (~x, ~y) + @test rPredicate1(2+x) === (2, x) + @test rPredicate2(2+x) === (2, x) + r5 = @rule (~y*(~z+~w))+~x => (~x, ~y, ~z, ~w) + r6 = @rule ~x+((~z+~w)*~y) => (~x, ~y, ~z, ~w) + @test r5(c*(a+b)+d) === (d, c, a, b) + @test r6(c*(a+b)+d) === (d, c, a, b) +end + @testset "Slot matcher with default value" begin r_sum = @rule (~x + ~!y)^2 => ~y @test r_sum((a + b)^2) === b @@ -76,10 +95,45 @@ end @test r_pow2((a+b)^c) === c @test r_pow2(a+b) === 1 - r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c - @test r_mix((a + b*c)^2) === 2 + c - @test r_mix((a + b*c)) === 1 + c - @test r_mix((a + b)) === 2 #1+1 + r_mix = @rule (~x + (~y)*(~!c))^(~!m) => (~m, ~c) + @test r_mix((a + b*c)^2) === (2, c) + @test r_mix((a + b*c)) === (1, c) + @test r_mix((a + b)) === (1, 1) + + r_more_than_two_arguments = @rule (~!a)*exp(~x)*sin(~x) => (~a, ~x) + @test r_more_than_two_arguments(sin(x)*exp(x)) === (1, x) + @test r_more_than_two_arguments(sin(x)*exp(x)*a) === (a, x) + + r_mixmix = @rule (~!a)*exp(~x)*sin(~!b + (~x)^2 + ~x) => (~a, ~b, ~x) + @test r_mixmix(exp(x)*sin(1+x+x^2)*2) === (2, 1, x) + @test r_mixmix(exp(x)*sin(x+x^2)*2) === (2, 0, x) + @test r_mixmix(exp(x)*sin(x+x^2)) === (1, 0, x) + + r_predicate = @rule ~x + (~!m::(var->isa(var, Int))) => (~x, ~m) + @test r_predicate(x+2) === (x, 2) + @test r_predicate(x+2.0) !== (x, 2.0) + # Note: r_predicate(x+2.0) doesnt return nothing, but (x+2.0, 0) + # becasue of the defslot +end + +@testset "power matcher with negative exponent" begin + r1 = @rule (~x)^(~y) => (~x, ~y) # rule with slot as exponent + @test r1(1/a^b) === (a, -b) # uses frankestein + @test r1(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1(1/a^2) === (a, -2) # uses opposite_sign_matcher + + r2 = @rule (~x)^(~y + ~z) => (~x, ~y, ~z) # rule with term as exponent + @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein + @test r2(1/a^3) === nothing # should use a term_matcher that flips the sign, but is not implemented + + r1defslot = @rule (~x)^(~!y) => (~x, ~y) # rule with defslot as exponent + @test r1defslot(1/a^b) === (a, -b) # uses frankestein + @test r1defslot(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1defslot(1/a^2) === (a, -2) # uses opposite_sign_matcher + @test r1defslot(a) === (a, 1) + + r = @rule (~x + ~y)^(~m) => (~x, ~y, ~m) # rule to match (1/...)^(...) + @test r((1/(a+b))^3) === (a,b,-3) end using SymbolicUtils: @capture