diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 60519d5b..16adf00b 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -230,19 +230,19 @@ function find_term(a::Node) return term, constant end -function matching_terms(lchild, rchild) - if lchild === rchild - return (1, 1, lchild) - else - lterm, lconstant = find_term(lchild) - rterm, rconstant = find_term(rchild) - if lterm !== nothing && rterm !== nothing && lterm === rterm - return (lconstant, rconstant, lterm) - else - return nothing - end - end -end +# function matching_terms(lchild, rchild) +# if lchild === rchild +# return (1, 1, lchild) +# else +# lterm, lconstant = find_term(lchild) +# rterm, rconstant = find_term(rchild) +# if lterm !== nothing && rterm !== nothing && lterm === rterm +# return (lconstant, rconstant, lterm) +# else +# return nothing +# end +# end +# end function simplify_check_cache(::typeof(+), na, nb, cache)::Node a = Node(na) @@ -262,8 +262,8 @@ function simplify_check_cache(::typeof(+), na, nb, cache)::Node return Node(value(a) + value(children(b)[1])) + children(b)[2] elseif a === b return 2 * a - elseif (tmp = matching_terms(a, b)) !== nothing - return (tmp[1] + tmp[2]) * tmp[3] + # elseif (tmp = matching_terms(a, b)) !== nothing + # return (tmp[1] + tmp[2]) * tmp[3] else return check_cache((+, a, b), cache) end @@ -283,8 +283,8 @@ function simplify_check_cache(::typeof(-), na, nb, cache)::Node elseif is_constant(a) && is_constant(b) return Node(value(a) - value(b)) - elseif (tmp = matching_terms(a, b)) !== nothing - return (tmp[1] - tmp[2]) * tmp[3] + # elseif (tmp = matching_terms(a, b)) !== nothing + # return (tmp[1] - tmp[2]) * tmp[3] else return check_cache((-, a, b), cache) end diff --git a/src/Factoring.jl b/src/Factoring.jl index 1e7bd78e..e38f2805 100644 --- a/src/Factoring.jl +++ b/src/Factoring.jl @@ -103,23 +103,26 @@ Base.isless(::FactorOrder, a, b) = factor_order(a, b) """returns true if a should be sorted before b""" function factor_order(a::FactorableSubgraph, b::FactorableSubgraph) - if times_used(a) > times_used(b) #num_uses of contained subgraphs always ≥ num_uses of containing subgraphs. Contained subgraphs should always be factored first. It might be that a ⊄ b, but it's still correct to factor a before b. + + diffa = node_difference(a) + diffb = node_difference(b) + + + # if a ⊂ b then diff(a) < diff(b) where diff(x) = abs(dominating_node(a) - dominated_node(a)). Might be that a ⊄ b but it's safe to factor a first. + # This tests only guarantees that if a ⊂ b then a will be factored first. It could be that diff(a) < diff(b) but that a ⊄ b in which case most + # efficient option would be to factor whichever of a,b has highest times used. But determining subgraph containment precisely is time consuming. + # This ordering heuristic doesn't seem to affect efficiency of computed derivatives much but is significantly faster. + if diffa < diffb return true - elseif times_used(b) > times_used(a) + elseif times_used(a) > times_used(b) #If a is used more times than b then factor a first. More efficient. + return true + else return false - else # if a ⊂ b then diff(a) < diff(b) where diff(x) = abs(dominating_node(a) - dominated_node(a)). Might be that a ⊄ b but it's safe to factor a first and if a ⊂ b then it will always be factored first. - diffa = node_difference(a) - diffb = node_difference(b) - - if diffa < diffb - return true - else - return false #can factor a,b in either order - end end end + sort_in_factor_order!(a::AbstractVector{T}) where {T<:FactorableSubgraph} = sort!(a, lt=factor_order) @@ -625,3 +628,32 @@ function number_of_operations(jacobian::AbstractArray{T}) where {T<:Node} end return count end + +function new_factor_subgraph!(a::FactorableSubgraph{T}) where {T} + local new_edge::PathEdge{T} + if subgraph_exists(subgraph) + + if is_branching(subgraph) #handle the uncommon case of factorization creating new factorable subgraphs internal to subgraph + sum = evaluate_branching_subgraph(subgraph) + new_edge = make_factored_edge(subgraph, sum) + else + sum = evaluate_subgraph(subgraph) + # if value(sum) == 0 + # display(subgraph) + # write_dot("sph.svg", graph(subgraph), value_labels=true, reachability_labels=false, start_nodes=[24]) + # end + + # # @assert value(sum) != 0 + + new_edge = make_factored_edge(subgraph, sum) + end + add_non_dom_edges!(subgraph) + #reset roots in R, if possible. All edges earlier in the path than the first vertex with more than one child cannot be reset. + edges_to_delete = reset_edge_masks!(subgraph) + for edge in edges_to_delete + delete_edge!(graph(subgraph), edge) + end + + add_edge!(graph(subgraph), new_edge) + end +end