Skip to content

Commit 0ea0671

Browse files
rewrite simplify
1 parent 4bcabfc commit 0ea0671

File tree

1 file changed

+163
-176
lines changed

1 file changed

+163
-176
lines changed

src/simplify_rules.jl

Lines changed: 163 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -6,180 +6,167 @@ the argument to the predicate satisfies `iscall` and `operation(x) == f`
66
"""
77
is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f)
88

9-
let
10-
CANONICALIZE_PLUS = [
11-
@rule(~x::isnotflat(+) => flatten_term(+, ~x))
12-
@rule(~x::needs_sorting(+) => sort_args(+, ~x))
13-
@ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b)
14-
15-
@acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))
16-
17-
@acrule(~x + *(~β, ~x) => *(1 + ~β, ~x))
18-
@acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x))
19-
@rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...))
20-
21-
@ordered_acrule((~z::_iszero + ~x) => ~x)
22-
@rule(+(~x) => ~x)
23-
]
24-
25-
PLUS_DISTRIBUTE = [
26-
@acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))
27-
@acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))
28-
]
29-
30-
CANONICALIZE_TIMES = [
31-
@rule(~x::isnotflat(*) => flatten_term(*, ~x))
32-
@rule(~x::needs_sorting(*) => sort_args(*, ~x))
33-
34-
@ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b)
35-
@rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...))
36-
37-
@acrule((~y)^(~n) * ~y => (~y)^(~n+1))
38-
39-
@ordered_acrule((~z::_isone * ~x) => ~x)
40-
@ordered_acrule((~z::_iszero * ~x) => ~z)
41-
@rule(*(~x) => ~x)
42-
]
43-
44-
MUL_DISTRIBUTE = @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m))
45-
46-
CANONICALIZE_POW = [
47-
@rule(^(*(~~x), ~y::_isinteger) => *(map(a->pow(a, ~y), ~~x)...))
48-
@rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q)))
49-
@rule(^(~x, ~z::_iszero) => 1)
50-
@rule(^(~x, ~z::_isone) => ~x)
51-
@rule(inv(~x) => 1/(~x))
52-
]
53-
54-
POW_RULES = [
55-
@rule(^(~x::_isone, ~z) => 1)
56-
]
57-
58-
ASSORTED_RULES = [
59-
@rule(identity(~x) => ~x)
60-
@rule(-(~x) => -1*~x)
61-
@rule(-(~x, ~y) => ~x + -1(~y))
62-
@rule(~x::_isone \ ~y => ~y)
63-
@rule(~x \ ~y => ~y / (~x))
64-
@rule(one(~x) => one(symtype(~x)))
65-
@rule(zero(~x) => zero(symtype(~x)))
66-
@rule(conj(~x::_isreal) => ~x)
67-
@rule(real(~x::_isreal) => ~x)
68-
@rule(imag(~x::_isreal) => zero(symtype(~x)))
69-
@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z)
70-
@rule(ifelse(~x, ~y, ~y) => ~y)
71-
]
72-
73-
TRIG_EXP_RULES = [
74-
@acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y))
75-
@acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y))
76-
@acrule(sin(~x)^2 + cos(~x)^2 => one(~x))
77-
@acrule(sin(~x)^2 + -1 => -1*cos(~x)^2)
78-
@acrule(cos(~x)^2 + -1 => -1*sin(~x)^2)
79-
80-
@acrule(cos(~x)^2 + -1*sin(~x)^2 => cos(2 * ~x))
81-
@acrule(sin(~x)^2 + -1*cos(~x)^2 => -cos(2 * ~x))
82-
@acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2)
83-
84-
@acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x))
85-
@acrule(-1*tan(~x)^2 + sec(~x)^2 => one(~x))
86-
@acrule(tan(~x)^2 + 1 => sec(~x)^2)
87-
@acrule(sec(~x)^2 + -1 => tan(~x)^2)
88-
89-
@acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x))
90-
@acrule(cot(~x)^2 + 1 => csc(~x)^2)
91-
@acrule(csc(~x)^2 + -1 => cot(~x)^2)
92-
93-
@acrule(cosh(~x)^2 + -1*sinh(~x)^2 => one(~x))
94-
@acrule(cosh(~x)^2 + -1 => sinh(~x)^2)
95-
@acrule(sinh(~x)^2 + 1 => cosh(~x)^2)
96-
97-
@acrule(cosh(~x)^2 + sinh(~x)^2 => cosh(2 * ~x))
98-
@acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2)
99-
100-
@acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y))
101-
@rule(exp(~x)^(~y) => exp(~x * ~y))
102-
]
103-
104-
BOOLEAN_RULES = [
105-
@rule((true | (~x)) => true)
106-
@rule(((~x) | true) => true)
107-
@rule((false | (~x)) => ~x)
108-
@rule(((~x) | false) => ~x)
109-
@rule((true & (~x)) => ~x)
110-
@rule(((~x) & true) => ~x)
111-
@rule((false & (~x)) => false)
112-
@rule(((~x) & false) => false)
113-
114-
@rule(!(~x) & ~x => false)
115-
@rule(~x & !(~x) => false)
116-
@rule(!(~x) | ~x => true)
117-
@rule(~x | !(~x) => true)
118-
@rule(xor(~x, !(~x)) => true)
119-
@rule(xor(~x, ~x) => false)
120-
121-
@rule(~x == ~x => true)
122-
@rule(~x != ~x => false)
123-
@rule(~x < ~x => false)
124-
@rule(~x > ~x => false)
125-
126-
# simplify terms with no symbolic arguments
127-
# e.g. this simplifies term(isodd, 3, type=Bool)
128-
# or term(!, false)
129-
@rule((~f)(~x::is_literal_number) => (~f)(~x))
130-
# and this simplifies any binary comparison operator
131-
@rule((~f)(~x::is_literal_number, ~y::is_literal_number) => (~f)(~x, ~y))
132-
]
133-
134-
function number_simplifier()
135-
rule_tree = [If(iscall, Chain(ASSORTED_RULES)),
136-
If(x -> !isadd(x) && is_operation(+)(x),
137-
Chain(CANONICALIZE_PLUS)),
138-
If(is_operation(+), Chain(PLUS_DISTRIBUTE)), # This would be useful even if isadd
139-
If(x -> !ismul(x) && is_operation(*)(x),
140-
Chain(CANONICALIZE_TIMES)),
141-
If(is_operation(*), MUL_DISTRIBUTE),
142-
If(x -> !ispow(x) && is_operation(^)(x),
143-
Chain(CANONICALIZE_POW)),
144-
If(is_operation(^), Chain(POW_RULES)),
145-
] |> RestartedChain
146-
147-
rule_tree
148-
end
149-
150-
trig_exp_simplifier(;kw...) = Chain(TRIG_EXP_RULES)
151-
152-
bool_simplifier() = Chain(BOOLEAN_RULES)
153-
154-
global default_simplifier
155-
global serial_simplifier
156-
global threaded_simplifier
157-
global serial_simplifier
158-
global serial_expand_simplifier
159-
160-
function default_simplifier(; kw...)
161-
IfElse(has_trig_exp,
162-
Postwalk(IfElse(x->symtype(x) <: Number,
163-
Chain((number_simplifier(),
164-
trig_exp_simplifier())),
165-
If(x->symtype(x) <: Bool,
166-
bool_simplifier()))
167-
; kw...),
168-
Postwalk(Chain((If(x->symtype(x) <: Number,
169-
number_simplifier()),
170-
If(x->symtype(x) <: Bool,
171-
bool_simplifier())))
172-
; kw...))
173-
end
174-
175-
# reduce overhead of simplify by defining these as constant
176-
serial_simplifier = If(iscall, Fixpoint(default_simplifier()))
177-
178-
threaded_simplifier(cutoff) = Fixpoint(default_simplifier(threaded=true,
179-
thread_cutoff=cutoff))
180-
181-
serial_expand_simplifier = If(iscall,
182-
Fixpoint(Chain((expand,
183-
Fixpoint(default_simplifier())))))
184-
9+
const CANONICALIZE_PLUS = (
10+
@rule(~x::isnotflat(+) => flatten_term(+, ~x)),
11+
@rule(~x::needs_sorting(+) => sort_args(+, ~x)),
12+
@ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b),
13+
14+
@acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)),
15+
16+
@acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)),
17+
@acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x)),
18+
@rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...)),
19+
20+
@ordered_acrule((~z::_iszero + ~x) => ~x),
21+
@rule(+(~x) => ~x),
22+
)
23+
24+
const PLUS_DISTRIBUTE = (
25+
@acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)),
26+
@acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)),
27+
)
28+
29+
const CANONICALIZE_TIMES = (
30+
@rule(~x::isnotflat(*) => flatten_term(*, ~x)),
31+
@rule(~x::needs_sorting(*) => sort_args(*, ~x)),
32+
33+
@ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b),
34+
@rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)),
35+
36+
@acrule((~y)^(~n) * ~y => (~y)^(~n+1)),
37+
38+
@ordered_acrule((~z::_isone * ~x) => ~x),
39+
@ordered_acrule((~z::_iszero * ~x) => ~z),
40+
@rule(*(~x) => ~x),
41+
)
42+
43+
const MUL_DISTRIBUTE = @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m))
44+
45+
const CANONICALIZE_POW = (
46+
@rule(^(*(~~x), ~y::_isinteger) => *(map(a->pow(a, ~y), ~~x)...)),
47+
@rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q))),
48+
@rule(^(~x, ~z::_iszero) => 1),
49+
@rule(^(~x, ~z::_isone) => ~x),
50+
@rule(inv(~x) => 1/(~x)),
51+
)
52+
53+
const POW_RULES = (
54+
@rule(^(~x::_isone, ~z) => 1),
55+
)
56+
57+
const ASSORTED_RULES = (
58+
@rule(identity(~x) => ~x),
59+
@rule(-(~x) => -1*~x),
60+
@rule(-(~x, ~y) => ~x + -1(~y)),
61+
@rule(~x::_isone \ ~y => ~y),
62+
@rule(~x \ ~y => ~y / (~x)),
63+
@rule(one(~x) => one(symtype(~x))),
64+
@rule(zero(~x) => zero(symtype(~x))),
65+
@rule(conj(~x::_isreal) => ~x),
66+
@rule(real(~x::_isreal) => ~x),
67+
@rule(imag(~x::_isreal) => zero(symtype(~x))),
68+
@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z),
69+
@rule(ifelse(~x, ~y, ~y) => ~y),
70+
)
71+
72+
const TRIG_EXP_RULES = (
73+
@acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)),
74+
@acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)),
75+
@acrule(sin(~x)^2 + cos(~x)^2 => one(~x)),
76+
@acrule(sin(~x)^2 + -1 => -1*cos(~x)^2),
77+
@acrule(cos(~x)^2 + -1 => -1*sin(~x)^2),
78+
79+
@acrule(cos(~x)^2 + -1*sin(~x)^2 => cos(2 * ~x)),
80+
@acrule(sin(~x)^2 + -1*cos(~x)^2 => -cos(2 * ~x)),
81+
@acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2),
82+
83+
@acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x)),
84+
@acrule(-1*tan(~x)^2 + sec(~x)^2 => one(~x)),
85+
@acrule(tan(~x)^2 + 1 => sec(~x)^2),
86+
@acrule(sec(~x)^2 + -1 => tan(~x)^2),
87+
88+
@acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x)),
89+
@acrule(cot(~x)^2 + 1 => csc(~x)^2),
90+
@acrule(csc(~x)^2 + -1 => cot(~x)^2),
91+
92+
@acrule(cosh(~x)^2 + -1*sinh(~x)^2 => one(~x)),
93+
@acrule(cosh(~x)^2 + -1 => sinh(~x)^2),
94+
@acrule(sinh(~x)^2 + 1 => cosh(~x)^2),
95+
96+
@acrule(cosh(~x)^2 + sinh(~x)^2 => cosh(2 * ~x)),
97+
@acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2),
98+
99+
@acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y)),
100+
@rule(exp(~x)^(~y) => exp(~x * ~y)),
101+
)
102+
103+
const BOOLEAN_RULES = (
104+
@rule((true | (~x)) => true),
105+
@rule(((~x) | true) => true),
106+
@rule((false | (~x)) => ~x),
107+
@rule(((~x) | false) => ~x),
108+
@rule((true & (~x)) => ~x),
109+
@rule(((~x) & true) => ~x),
110+
@rule((false & (~x)) => false),
111+
@rule(((~x) & false) => false),
112+
113+
@rule(!(~x) & ~x => false),
114+
@rule(~x & !(~x) => false),
115+
@rule(!(~x) | ~x => true),
116+
@rule(~x | !(~x) => true),
117+
@rule(xor(~x, !(~x)) => true),
118+
@rule(xor(~x, ~x) => false),
119+
120+
@rule(~x == ~x => true),
121+
@rule(~x != ~x => false),
122+
@rule(~x < ~x => false),
123+
@rule(~x > ~x => false),
124+
125+
# simplify terms with no symbolic arguments
126+
# e.g. this simplifies term(isodd, 3, type=Bool)
127+
# or term(!, false)
128+
@rule((~f)(~x::is_literal_number) => (~f)(~x)),
129+
# and this simplifies any binary comparison operator
130+
@rule((~f)(~x::is_literal_number, ~y::is_literal_number) => (~f)(~x, ~y)),
131+
)
132+
133+
const NUMBER_SIMPLIFIER = RestartedChain((
134+
If(iscall, Chain(ASSORTED_RULES)),
135+
If(x -> !isadd(x) && is_operation(+)(x),
136+
Chain(CANONICALIZE_PLUS)),
137+
If(is_operation(+), Chain(PLUS_DISTRIBUTE)), # This would be useful even if isadd
138+
If(x -> !ismul(x) && is_operation(*)(x),
139+
Chain(CANONICALIZE_TIMES)),
140+
If(is_operation(*), MUL_DISTRIBUTE),
141+
If(x -> !ispow(x) && is_operation(^)(x),
142+
Chain(CANONICALIZE_POW)),
143+
If(is_operation(^), Chain(POW_RULES)),
144+
))
145+
146+
const TRIG_EXP_SIMPLIFIER = Chain(TRIG_EXP_RULES)
147+
148+
const BOOLEAN_SIMPLIFIER = Chain(BOOLEAN_RULES)
149+
150+
151+
function get_default_simplifier(; kw...)
152+
IfElse(has_trig_exp,
153+
Postwalk(IfElse(x->symtype(x) <: Number,
154+
Chain((NUMBER_SIMPLIFIER, TRIG_EXP_SIMPLIFIER)),
155+
If(x->symtype(x) <: Bool, BOOLEAN_SIMPLIFIER))
156+
; kw...),
157+
Postwalk(Chain((If(x->symtype(x) <: Number,
158+
NUMBER_SIMPLIFIER),
159+
If(x->symtype(x) <: Bool,
160+
BOOLEAN_SIMPLIFIER)))
161+
; kw...))
185162
end
163+
164+
# reduce overhead of simplify by defining these as constant
165+
const serial_simplifier = If(iscall, Fixpoint(get_default_simplifier()))
166+
167+
threaded_simplifier(cutoff) = Fixpoint(get_default_simplifier(threaded=true,
168+
thread_cutoff=cutoff))
169+
170+
const serial_expand_simplifier = If(iscall,
171+
Fixpoint(Chain((expand,
172+
Fixpoint(get_default_simplifier())))))

0 commit comments

Comments
 (0)