Skip to content

Commit dfea2f5

Browse files
Merge pull request #697 from AayushSabharwal/as/cache-hash2
feat: cache hash2 in BasicSymbolic
2 parents 9b935fc + 051a431 commit dfea2f5

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ DynamicPolynomials = "0.5, 0.6"
4949
IfElse = "0.1"
5050
LabelledArrays = "1.5"
5151
MultivariatePolynomials = "0.5"
52-
NaNMath = "0.3, 1"
52+
NaNMath = "0.3, 1.1.2"
5353
ReverseDiff = "1"
5454
Setfield = "0.7, 0.8, 1"
5555
SpecialFunctions = "0.10, 1.0, 2"

src/types.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,21 @@ const ENABLE_HASHCONSING = Ref(true)
3434
f::Any = identity # base/num if Pow; issorted if Add/Dict
3535
arguments::Vector{Any} = EMPTY_ARGS
3636
hash::RefValue{UInt} = EMPTY_HASH
37+
hash2::RefValue{UInt} = EMPTY_HASH
3738
end
3839
mutable struct Mul{T} <: BasicSymbolic{T}
3940
coeff::Any = 0 # exp/den if Pow
4041
dict::EMPTY_DICT_T = EMPTY_DICT
4142
hash::RefValue{UInt} = EMPTY_HASH
43+
hash2::RefValue{UInt} = EMPTY_HASH
4244
arguments::Vector{Any} = EMPTY_ARGS
4345
issorted::RefValue{Bool} = NOT_SORTED
4446
end
4547
mutable struct Add{T} <: BasicSymbolic{T}
4648
coeff::Any = 0 # exp/den if Pow
4749
dict::EMPTY_DICT_T = EMPTY_DICT
4850
hash::RefValue{UInt} = EMPTY_HASH
51+
hash2::RefValue{UInt} = EMPTY_HASH
4952
arguments::Vector{Any} = EMPTY_ARGS
5053
issorted::RefValue{Bool} = NOT_SORTED
5154
end
@@ -98,11 +101,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
98101
# Call outer constructor because hash consing cannot be applied in inner constructor
99102
@compactified obj::BasicSymbolic begin
100103
Sym => Sym{T}(nt_new.name; nt_new...)
101-
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)))
102-
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)))
103-
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)))
104-
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)))
105-
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)))
104+
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
105+
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
106+
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
107+
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
108+
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
106109
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
107110
end
108111
end
@@ -461,6 +464,9 @@ function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
461464
if E === SYM
462465
h = hash(nameof(s), salt SYM_SALT)
463466
elseif E === ADD || E === MUL
467+
if !iszero(s.hash2[])
468+
return s.hash2[]
469+
end
464470
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
465471
hv = Base.hasha_seed
466472
for (k, v) in s.dict
@@ -473,13 +479,20 @@ function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
473479
elseif E === POW
474480
h = hash2(s.exp, hash2(s.base, salt POW_SALT))
475481
elseif E === TERM
482+
if !iszero(s.hash2[])
483+
return s.hash2[]
484+
end
476485
op = operation(s)
477486
oph = op isa Function ? nameof(op) : op
478487
h = hashvec2(arguments(s), hash(oph, salt))
479488
else
480489
error_on_type()
481490
end
482-
hash(metadata(s), hash(T, h))
491+
h = hash(metadata(s), hash(T, h))
492+
if hasproperty(s, :hash2)
493+
s.hash2[] = h
494+
end
495+
return h
483496
end
484497

485498
###
@@ -530,7 +543,7 @@ function Term{T}(f, args; kw...) where T
530543
args = convert(Vector{Any}, args)
531544
end
532545

533-
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
546+
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
534547
BasicSymbolic(s)
535548
end
536549

@@ -551,7 +564,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
551564
end
552565
end
553566

554-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
567+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
555568
BasicSymbolic(s)
556569
end
557570

@@ -567,7 +580,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
567580
else
568581
coeff = a
569582
dict = b
570-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
583+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
571584
BasicSymbolic(s)
572585
end
573586
end

test/hash_consing.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,13 @@ end
137137
@test x1 !== x2
138138
SymbolicUtils.ENABLE_HASHCONSING[] = true
139139
end
140+
141+
@testset "`hash2` is cached" begin
142+
@syms a b f(..)
143+
for ex in [a + b, a * b, f(a)]
144+
h = SymbolicUtils.hash2(ex)
145+
@test h == ex.hash2[]
146+
ex2 = setmetadata(ex, Int, 3)
147+
@test ex2.hash2[] != h
148+
end
149+
end

0 commit comments

Comments
 (0)