diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 047bef71a..f65c94f59 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -68,13 +68,15 @@ If you want to match a variable number of subexpressions at once, you will need ```jldoctest rewrite @syms x y z -@rule(+(~~xs) => ~~xs)(x + y + z) +r = @rule(+(~~xs) => sort!(~~xs, by=get_degrees)) +expr = x + y + z +r(expr) # output 3-element view(::Vector{Any}, 1:3) with eltype Any: - z - y x + y + z ``` `~~xs` is a vector of subexpressions matched. You can use it to construct something more useful: diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..8d628a589 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -41,7 +41,7 @@ function get_degrees(expr) elseif op == (getindex) return ((Symbol.(args)...,) => 1,) else - return ((Symbol("zzzzzzz", hash(expr)),) => 1,) + return ((Symbol("zzzzzzz", hash2(expr)),) => 1,) end else return () diff --git a/src/types.jl b/src/types.jl index 683f58d44..828602ffa 100644 --- a/src/types.jl +++ b/src/types.jl @@ -271,30 +271,30 @@ Base.zero(s::Symbolic) = zero(symtype(s)) Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") ## This is much faster than hash of an array of Any -hashvec(xs, z) = foldr(hash, xs, init=z) +hashvec(xs, z) = foldr(hash2, xs, init=z) const SYM_SALT = 0x4de7d7c66d41da43 % UInt const ADD_SALT = 0xaddaddaddaddadda % UInt const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt -function Base.hash(s::BasicSymbolic, salt::UInt)::UInt +function hash2(s::BasicSymbolic, salt::UInt)::UInt E = exprtype(s) if E === SYM hash(nameof(s), salt ⊻ SYM_SALT) elseif E === ADD || E === MUL - !iszero(salt) && return hash(hash(s, zero(UInt)), salt) + !iszero(salt) && return hash(hash2(s, zero(UInt)), salt) h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt))) + h′ = hash(hashoffset, hash2(s.coeff, hash2(s.dict, salt))) s.hash[] = h′ return h′ elseif E === DIV - return hash(s.num, hash(s.den, salt ⊻ DIV_SALT)) + return hash2(s.num, hash2(s.den, salt ⊻ DIV_SALT)) elseif E === POW - hash(s.exp, hash(s.base, salt ⊻ POW_SALT)) + hash2(s.exp, hash2(s.base, salt ⊻ POW_SALT)) elseif E === TERM - !iszero(salt) && return hash(hash(s, zero(UInt)), salt) + !iszero(salt) && return hash(hash2(s, zero(UInt)), salt) h = s.hash[] !iszero(h) && return h op = operation(s) @@ -306,6 +306,19 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt error_on_type() end end +hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) +hash2(s, salt::UInt) = hash(s, salt) +function hash2(a::EMPTY_DICT_T, h::UInt) + hv = Base.hasha_seed + for (k, v) in a + hv ⊻= hash2(k, hash(v)) + end + hash(hv, h) +end + +function Base.hash(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} + hash(metadata(s), hash(T, hash2(s, salt))) +end ### ### Constructors