@@ -197,6 +197,24 @@ const DAG_SPECS = Vector{Pair{DAGSpec, DAGSpecSchedule}}()
197
197
198
198
# const DAG_SCHEDULE_CACHE = Dict{DAGSpec, DAGSpecSchedule}()
199
199
200
+ _identity_hash (arg, h:: UInt = UInt (0 )) = ismutable (arg) ? objectid (arg) : hash (arg, h)
201
+ _identity_hash (arg:: SubArray , h:: UInt = UInt (0 )) = hash (arg. indices, hash (arg. offset1, hash (arg. stride1, _identity_hash (arg. parent, h))))
202
+
203
+ struct ArgumentWrapper
204
+ arg
205
+ dep_mod
206
+ hash:: UInt
207
+
208
+ function ArgumentWrapper (arg, dep_mod)
209
+ h = hash (dep_mod)
210
+ h = _identity_hash (arg, h)
211
+ return new (arg, dep_mod, h)
212
+ end
213
+ end
214
+ Base. hash (aw:: ArgumentWrapper ) = hash (ArgumentWrapper, aw. hash)
215
+ Base.:(== )(aw1:: ArgumentWrapper , aw2:: ArgumentWrapper ) =
216
+ aw1. hash == aw2. hash
217
+
200
218
struct DataDepsAliasingState
201
219
# Track original and current data locations
202
220
# We track data => space
@@ -214,7 +232,7 @@ struct DataDepsAliasingState
214
232
task_to_id:: IdDict{DTask,Int}
215
233
216
234
# Cache ainfo lookups
217
- ainfo_cache:: Dict{Tuple{Any,Any} ,AbstractAliasing}
235
+ ainfo_cache:: Dict{ArgumentWrapper ,AbstractAliasing}
218
236
219
237
function DataDepsAliasingState ()
220
238
data_origin = Dict {AbstractAliasing,MemorySpace} ()
@@ -227,7 +245,7 @@ struct DataDepsAliasingState
227
245
g = SimpleDiGraph ()
228
246
task_to_id = IdDict {DTask,Int} ()
229
247
230
- ainfo_cache = Dict {Tuple{Any,Any} ,AbstractAliasing} ()
248
+ ainfo_cache = Dict {ArgumentWrapper ,AbstractAliasing} ()
231
249
232
250
return new (data_origin, data_locality,
233
251
ainfos_owner, ainfos_readers, ainfos_overlaps,
@@ -288,7 +306,8 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
288
306
end
289
307
290
308
function aliasing (astate:: DataDepsAliasingState , arg, dep_mod)
291
- return get! (astate. ainfo_cache, (arg, dep_mod)) do
309
+ aw = ArgumentWrapper (arg, dep_mod)
310
+ get! (astate. ainfo_cache, aw) do
292
311
return aliasing (arg, dep_mod)
293
312
end
294
313
end
0 commit comments