diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea53..e5ace3a1 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -5,29 +5,67 @@ # 2. Dictionary # 3. Callback: takes arguments Dictionary × Number of elements matched # + function matcher(val::Any) - iscall(val) && return term_matcher(val) + # if val is a call (like an operation) creates a term matcher or term matcher with defslot + if iscall(val) + # if has two arguments and one of them is a DefSlot, create a term matcher with defslot + # just two arguments bc defslot is only supported with operations with two args: *, ^, + + if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) + return defslot_term_matcher_constructor(val) + # else return a normal term matcher + else + return term_matcher_constructor(val) + end + end + function literal_matcher(next, data, bindings) + # car data is the first element of data islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing end end function matcher(slot::Slot) function slot_matcher(next, data, bindings) - !islist(data) && return + !islist(data) && return nothing val = get(bindings, slot.name, nothing) + # if slot name already is in bindings, check if it matches if val !== nothing if isequal(val, car(data)) return next(bindings, 1) end - else - if slot.predicate(car(data)) - next(assoc(bindings, slot.name, car(data)), 1) + # elseif the first element of data matches the slot predicate, add it to bindings and call next + elseif slot.predicate(car(data)) + next(assoc(bindings, slot.name, car(data)), 1) + end + end +end + +function opposite_sign_matcher(slot::Slot) + function slot_matcher(next, data, bindings) + !islist(data) && return nothing + val = get(bindings, slot.name, nothing) + if val !== nothing + if isequal(val, car(data)) + return next(bindings, 1) end + elseif slot.predicate(car(data)) + next(assoc(bindings, slot.name, -car(data)), 1) # this - is the only differenct wrt matcher(slot::Slot) end end end +function opposite_sign_matcher(defslot::DefSlot) + opposite_sign_matcher(Slot(defslot.name, defslot.predicate)) +end + +# this is called only when defslot_term_matcher finds the operation and tries +# to match it, so no default value used. So the same function as slot_matcher +# can be used +function matcher(defslot::DefSlot) + matcher(Slot(defslot.name, defslot.predicate)) +end + # returns n == offset, 0 if failed function trymatchexpr(data, value, n) if !islist(value) @@ -84,12 +122,12 @@ function matcher(segment::Segment) end end -function term_matcher(term) +function term_matcher_constructor(term) matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) - function term_matcher(success, data, bindings) - !islist(data) && return nothing - !iscall(car(data)) && return nothing + function term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing function loop(term, bindings′, matchers′) # Get it to compile faster if !islist(matchers′) @@ -101,8 +139,123 @@ function term_matcher(term) car(matchers′)(term, bindings′) do b, n loop(drop_n(term, n), b, cdr(matchers′)) end + # explenation of above 3 lines: + # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) + # <------ next(b,n) ----------------------------> + # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list + # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term + # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses + # the length of the list, is considered empty end - loop(car(data), bindings, matchers) # Try to eat exactly one term + result = loop(car(data), bindings, matchers) + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if operation(term)==^ + alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) + if result === nothing && alternative_form + denominator = arguments(car(data))[2] + # let's say data = a^b with a and b can be whatever + # if b is not a number then call the loop function with a^-b + if !isa(arguments(denominator)[2], Number) + frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) + result = loop(frankestein, bindings, matchers) + else + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `Slot`s and not for terms, because if b + # is a number and not a call, certainly doesn't match a term (I hope). + if isa(arguments(term)[2], Slot) + matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + result = loop(denominator, bindings, matchers2) + end + end + end + end + result + end +end + +# creates a matcher for a term containing a defslot, such as: +# (~x + ...complicated pattern...) * ~!y +# normal part (can bee a tree) operation defslot part + +# defslot_term_matcher works like this: +# checks wether data is a call. +# if yes (1): continues like term_matcher (if it finds a match returns (2)) +# if still no match found checks wether data (is just a symbol) or (is a tree not starting with the default operation) +# if no returns nothing, rule is not applied +# if yes (3): adds the pair (default value name, default value) to the found bindings and +# calls the success function like term_matcher would do + +function defslot_term_matcher_constructor(term) + a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments + matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term + + defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term + defslot = a[defslot_index] + + function defslot_term_matcher(success, data, bindings) + # if data is not a list, return nothing + !islist(data) && return nothing + result = nothing + if iscall(car(data)) + # (1) + function loop(term, bindings′, matchers′) # Get it to compile faster + if !islist(matchers′) + if !islist(term) + return success(bindings′, 1) + end + return nothing + end + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + end + + result = loop(car(data), bindings, matchers) # Try to eat exactly one term + # if data is of the alternative form 1/(...)^(...), it might match with negative exponent + if operation(term)==^ + alternative_form = (operation(car(data))==/) && arguments(car(data))[1]==1 && iscall(arguments(car(data))[2]) && (operation(arguments(car(data))[2])==^) + if result === nothing && alternative_form + denominator = arguments(car(data))[2] + # let's say data = a^b with a and b can be whatever + # if b is not a number then call the loop function with a^-b + if !isa(arguments(denominator)[2], Number) + frankestein = arguments(denominator)[1] ^ -(arguments(denominator)[2]) + result = loop(frankestein, bindings, matchers) + else + # if b is a number, like 3, we cant call loop with a^-3 bc it + # will automatically transform into 1/a^3. Therfore we need to + # create a matcher that flips the sign of the exponent. I created + # this matecher just for `DefSlot`s and not for terms, because if b + # is a number and not a call, certainly doesn't match a term (I hope). + if isa(arguments(term)[2], DefSlot) + matchers2 = (matcher(operation(term)), matcher(arguments(term)[1]), opposite_sign_matcher(arguments(term)[2])) # is this ok to be here or should it be outside neg_pow_term_matcher? + result = loop(denominator, bindings, matchers2) + end + end + end + end + # (2) + if result !== nothing + return result + end + end + + # if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation) + if ( !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation) ) + other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part + + # checks wether it matches the normal part + # <-----------------(3)-------------------------------> + bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings) + + if bindings === nothing + return nothing + end + result = success(bindings, 1) + end + result end end diff --git a/src/rule.jl b/src/rule.jl index 5de0aa79..d20598b2 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1,7 +1,7 @@ @inline alwaystrue(x) = true -# Matcher patterns with Slot and Segment +# Matcher patterns with Slot, DefSlot and Segment # matches one term # syntax: ~x @@ -16,6 +16,79 @@ Base.isequal(s1::Slot, s2::Slot) = s1.name == s2.name Base.show(io::IO, s::Slot) = (print(io, "~"); print(io, s.name)) +# for when the slot is a symbol, like `~x` +makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s)) + +# for when the slot is an expression, like `~x::predicate` +function makeslot(s::Expr, keys) + if !(s.head == :(::)) + error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function") + end + + name = s.args[1] + + push!(keys, name) + :(Slot($(QuoteNode(name)), $(esc(s.args[2])))) +end + + + + + + +# matches one term with built in default value. +# syntax: ~!x +# Example usage: +# (~!x + ~y) can match (a + b) but also just "a" and x takes default value of zero. +# (~!x)*(~y) can match a*b but also just "a", and x takes default value of one. +# (~x + ~y)^(~!z) can match (a + b)^c but also just "a + b", and z takes default value of one. +# only these three operations are supported for default values. + +struct DefSlot{P, O} + name::Symbol + predicate::P + operation::O + defaultValue::Real +end + +# operation | default +# + | 0 +# * | 1 +# ^ | 1 +function defaultValOfCall(call) + if call == :+ + return 0 + elseif call == :* + return 1 + elseif call == :^ + return 1 + end + # else no default value for this call + error("You can use default slots only with +, * and ^, but you tried with: $call") +end + +DefSlot(s) = DefSlot(s, alwaystrue, nothing, 0) +Base.isequal(s1::DefSlot, s2::DefSlot) = s1.name == s2.name +Base.show(io::IO, s::DefSlot) = (print(io, "~!"); print(io, s.name)) + +makeDefSlot(s::Symbol, keys, op) = (push!(keys, s); DefSlot(s, alwaystrue, op, defaultValOfCall(op))) + +function makeDefSlot(s::Expr, keys, op) + if !(s.head == :(::)) + error("Syntax for specifying a default slot is ~!x::\$predicate, where predicate is a boolean function") + end + + name = s.args[1] + + push!(keys, name) + tmp = defaultValOfCall(op) + :(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp))) +end + + + + + # matches zero or more terms # syntax: ~~x struct Segment{F} @@ -37,37 +110,29 @@ function makesegment(s::Expr, keys) end name = s.args[1] - + push!(keys, name) :(Segment($(QuoteNode(name)), $(esc(s.args[2])))) end -makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s)) - -function makeslot(s::Expr, keys) - if !(s.head == :(::)) - error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function") - end - - name = s.args[1] - - push!(keys, name) - :(Slot($(QuoteNode(name)), $(esc(s.args[2])))) -end - -function makepattern(expr, keys) +# parent call is needed to know which default value to give if any default slots are present +function makepattern(expr, keys, parentCall=nothing) if expr isa Expr if expr.head === :call if expr.args[1] === :(~) if expr.args[2] isa Expr && expr.args[2].args[1] == :(~) # matches ~~x::predicate makesegment(expr.args[2].args[2], keys) + elseif expr.args[2] isa Expr && expr.args[2].args[1] == :(!) + # matches ~!x::predicate + makeDefSlot(expr.args[2].args[2], keys, parentCall) else # matches ~x::predicate makeslot(expr.args[2], keys) end else - :(term($(map(x->makepattern(x, keys), expr.args)...); type=Any)) + # make a pattern for every argument of the expr. + :(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any)) end elseif expr.head === :ref :(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any)) diff --git a/test/rewrite.jl b/test/rewrite.jl index c2e920f9..3cd545f2 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -47,6 +47,58 @@ end @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) end +@testset "Slot matcher with default value" begin + r_sum = @rule (~x + ~!y)^2 => ~y + @test r_sum((a + b)^2) === b + @test r_sum(b^2) === 0 + + r_mult = @rule ~x * ~!y => ~y + @test r_mult(a * b) === b + @test r_mult(a) === 1 + + r_mult2 = @rule (~x * ~!y + ~z) => ~y + @test r_mult2(c + a*b) === b + @test r_mult2(c + b) === 1 + + # here the "normal part" in the defslot_term_matcher is not a symbol but a tree + r_mult3 = @rule (~!x)*(~y + ~z) => ~x + @test r_mult3(a*(c+2)) === a + @test r_mult3(2*(c+2)) === 2 + @test r_mult3(c+2) === 1 + + r_pow = @rule (~x)^(~!m) => ~m + @test r_pow(a^(b+1)) === b+1 + @test r_pow(a) === 1 + @test r_pow(a+1) === 1 + + # here the "normal part" in the defslot_term_matcher is not a symbol but a tree + r_pow2 = @rule (~x + ~y)^(~!m) => ~m + @test r_pow2((a+b)^c) === c + @test r_pow2(a+b) === 1 + + r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c + @test r_mix((a + b*c)^2) === (2, c) + @test r_mix((a + b*c)) === (1, c) + @test r_mix((a + b)) === (1, 1) +end + +@testset "1/power matches power with exponent of opposite sign" begin + r1 = @rule (~x)^(~y) => (~x, ~y) # rule with slot as exponent + @test r1(1/a^b) === (a, -b) # uses frankestein + @test r1(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1(1/a^2) === (a, -2) # uses opposite_sign_matcher + + r2 = @rule (~x)^(~y + ~z) => (~x, ~y, ~z) # rule with term as exponent + @test r2(1/a^(b+2c)) === (a, -b, -2c) # uses frankestein + @test r2(1/a^3) === nothing # should use a term_matcher that flips the sign, but is not implemented + + r1defslot = @rule (~x)^(~!y) => (~x, ~y) # rule with slot as exponent + @test r1defslot(1/a^b) === (a, -b) # uses frankestein + @test r1defslot(1/a^(b+2c)) === (a, -b-2c) # uses frankestein + @test r1defslot(1/a^2) === (a, -2) # uses opposite_sign_matcher + @test r1defslot(a) === (a, 1) +end + using SymbolicUtils: @capture @testset "Capture form" begin