Skip to content

Commit a10e23b

Browse files
committed
datadeps: Hide argument type instability
1 parent 82df8f5 commit a10e23b

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/datadeps.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,24 @@ const DAG_SPECS = Vector{Pair{DAGSpec, DAGSpecSchedule}}()
197197

198198
#const DAG_SCHEDULE_CACHE = Dict{DAGSpec, DAGSpecSchedule}()
199199

200+
_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h)
201+
_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h))))
202+
203+
struct ArgumentWrapper
204+
arg
205+
dep_mod
206+
hash::UInt
207+
208+
function ArgumentWrapper(arg, dep_mod)
209+
h = hash(dep_mod)
210+
h = _identity_hash(arg, h)
211+
return new(arg, dep_mod, h)
212+
end
213+
end
214+
Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash)
215+
Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) =
216+
aw1.hash == aw2.hash
217+
200218
struct DataDepsAliasingState
201219
# Track original and current data locations
202220
# We track data => space
@@ -214,7 +232,7 @@ struct DataDepsAliasingState
214232
task_to_id::IdDict{DTask,Int}
215233

216234
# Cache ainfo lookups
217-
ainfo_cache::Dict{Tuple{Any,Any},AbstractAliasing}
235+
ainfo_cache::Dict{ArgumentWrapper,AbstractAliasing}
218236

219237
function DataDepsAliasingState()
220238
data_origin = Dict{AbstractAliasing,MemorySpace}()
@@ -227,7 +245,7 @@ struct DataDepsAliasingState
227245
g = SimpleDiGraph()
228246
task_to_id = IdDict{DTask,Int}()
229247

230-
ainfo_cache = Dict{Tuple{Any,Any},AbstractAliasing}()
248+
ainfo_cache = Dict{ArgumentWrapper,AbstractAliasing}()
231249

232250
return new(data_origin, data_locality,
233251
ainfos_owner, ainfos_readers, ainfos_overlaps,
@@ -288,7 +306,8 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
288306
end
289307

290308
function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
291-
return get!(astate.ainfo_cache, (arg, dep_mod)) do
309+
aw = ArgumentWrapper(arg, dep_mod)
310+
get!(astate.ainfo_cache, aw) do
292311
return aliasing(arg, dep_mod)
293312
end
294313
end

0 commit comments

Comments
 (0)