16
16
17
17
import abc
18
18
from collections .abc import Callable , Iterable , Iterator , Sequence
19
+ import dataclasses
19
20
import functools
20
21
from functools import partial
21
22
import itertools as it
22
23
import logging
23
24
import math
24
25
import operator
25
- from typing import (Any , Generic , SupportsIndex , TypeVar , overload , TYPE_CHECKING , cast )
26
+ from typing import (Any , Generic , SupportsIndex , Type , TypeVar , overload , TYPE_CHECKING , cast )
26
27
import weakref
27
28
28
29
import numpy as np
@@ -331,8 +332,8 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
331
332
Least recently used cache decorator with weakref support.
332
333
333
334
The cache will take a weakref to the first argument of the wrapped function
334
- and strong refs to all subsequent operations . In all other respects it should
335
- behave similar to `functools.lru_cache`.
335
+ and strong refs to all other arguments . In all other respects it should
336
+ behave similar to `functools.lru_cache`. The cache is thread local.
336
337
"""
337
338
cached_call = _weakref_lru_cache .weakref_lru_cache (
338
339
config .trace_context if trace_context_in_key else _ignore , call , maxsize
@@ -341,6 +342,135 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
341
342
return cached_call
342
343
343
344
345
+ @dataclasses .dataclass (frozen = True , slots = True , weakref_slot = True )
346
+ class MultiWeakRefCacheKey :
347
+ weakrefs : tuple [weakref .ref , ...] # Used only when len(weakrefs) >= 2
348
+
349
+
350
+ class MultiWeakRefPlaceholder :
351
+ # Stands for an arg/kwarg that was replaced with a weakref
352
+ pass
353
+ _multi_weakref_placeholder = MultiWeakRefPlaceholder ()
354
+
355
+ # The types of arguments for which `multi_weakref_lru_cache` should keep
356
+ # weak references.
357
+ weakref_cache_key_types : set [Type ] = set ()
358
+ def is_weakref_cache_key_type (v ):
359
+ return callable (v ) or (type (v ) in weakref_cache_key_types )
360
+
361
+
362
+ def multi_weakref_lru_cache (
363
+ call : Callable , * ,
364
+ maxsize = 2048 ,
365
+ trace_context_in_key : bool = True ):
366
+ """
367
+ Least recently used cache decorator with weakref support.
368
+
369
+ Similar to `weakref_lru_cache`, except that it keeps weak references
370
+ to all positional and keyword arguments for which
371
+ `is_weakref_cache_key_type()` is true, and strong references to
372
+ other arguments. The cache entry is removed if any of the weakref
373
+ arguments dies.
374
+ """
375
+ # Keep strong references to the MultiWeakRefCacheKeys that resulted in
376
+ # cache misses, and are cache keys. Indexed by id. Only keys with all
377
+ # included weakrefs live are present.
378
+ id_to_key : dict [int , MultiWeakRefCacheKey ] = {}
379
+ # For each `wr: weakref.ref` present in `key: MultiWeakRefCacheKey` we have
380
+ # `id(key) in weakref_to_key_ids[wr]`.
381
+ weakref_to_key_ids : dict [weakref .ref , set [int ]] = {}
382
+
383
+ def remove_weakref (wr : weakref .ref ):
384
+ key_ids = weakref_to_key_ids .get (wr , set ())
385
+ for key_id in key_ids :
386
+ try :
387
+ del id_to_key [key_id ]
388
+ except KeyError :
389
+ pass
390
+ try :
391
+ del weakref_to_key_ids [wr ]
392
+ except KeyError :
393
+ pass
394
+
395
+ def weakrefs_to_sentinel (v , acc : list [Any ]):
396
+ if type (v ) is tuple :
397
+ return tuple (weakrefs_to_sentinel (v1 , acc ) for v1 in v )
398
+ elif type (v ) is dict :
399
+ return {k : weakrefs_to_sentinel (v1 , acc ) for k , v1 in v .items ()}
400
+ elif is_weakref_cache_key_type (v ):
401
+ acc .append (v )
402
+ return _multi_weakref_placeholder
403
+ else :
404
+ return v
405
+
406
+ def sentinel_to_referrents (v ,
407
+ it : Iterator [weakref .ref ],
408
+ key_id : int | None ):
409
+ # key_id is not None iff we use a MultiWeakRefCacheKey (>= 2 weakrefs)
410
+ if type (v ) is tuple :
411
+ return tuple (sentinel_to_referrents (v1 , it , key_id ) for v1 in v )
412
+ elif type (v ) is dict :
413
+ return {k : sentinel_to_referrents (v1 , it , key_id )
414
+ for k , v1 in v .items ()}
415
+ elif v is _multi_weakref_placeholder :
416
+ wr = next (it )
417
+ if key_id is not None :
418
+ weakref_to_key_ids .setdefault (wr , set ()).add (key_id )
419
+ return wr ()
420
+ else :
421
+ return v
422
+
423
+ def cache_miss (key : MultiWeakRefCacheKey | MultiWeakRefPlaceholder | Any ,
424
+ * args , ** kwargs ):
425
+ if isinstance (key , MultiWeakRefCacheKey ): # had at least 2 weakrefs
426
+ # We know `key` is in `cached_call` cache, so store strong references
427
+ key_id = id (key )
428
+ id_to_key [key_id ] = key
429
+ orig_args , orig_kwargs = sentinel_to_referrents (
430
+ (args , kwargs ), iter (key .weakrefs ), key_id )
431
+ elif key is _multi_weakref_placeholder : # had 0 weakrefs
432
+ orig_args = args
433
+ orig_kwargs = kwargs
434
+ else : # had 1 weakref, we had put it first as the `key`
435
+ orig_args , orig_kwargs = sentinel_to_referrents (
436
+ (args , kwargs ), iter ([weakref .ref (key )]), None )
437
+ return call (* orig_args , ** orig_kwargs )
438
+
439
+
440
+ cached_call = _weakref_lru_cache .weakref_lru_cache (
441
+ config .trace_context if trace_context_in_key else _ignore ,
442
+ cache_miss , maxsize
443
+ )
444
+ register_cache (cached_call , str (call ))
445
+
446
+ @functools .wraps (call )
447
+ def wrapper (* orig_args , ** orig_kwargs ):
448
+ acc_weakrefs : list [Any ] = []
449
+ args , kwargs = weakrefs_to_sentinel ((orig_args , orig_kwargs ),
450
+ acc_weakrefs )
451
+ nr_weakrefs = len (acc_weakrefs )
452
+ if nr_weakrefs == 0 :
453
+ return cached_call (_multi_weakref_placeholder ,
454
+ * orig_args , ** orig_kwargs )
455
+ elif nr_weakrefs == 1 :
456
+ # Put the single weakref first, and skip the MultiWeakRefCacheKey
457
+ return cached_call (acc_weakrefs [0 ],
458
+ * args , ** kwargs )
459
+ else :
460
+ value_to_weakref = {v : weakref .ref (v , remove_weakref )
461
+ for v in set (acc_weakrefs )}
462
+ key = MultiWeakRefCacheKey (weakrefs = tuple (value_to_weakref [v ]
463
+ for v in acc_weakrefs ))
464
+ return cached_call (key , * args , ** kwargs )
465
+
466
+ wrapper .cache_info = cached_call .cache_info
467
+ wrapper .cache_clear = cached_call .cache_clear
468
+ wrapper .cache_keys = cached_call .cache_keys
469
+ wrapper ._multi_weakref_id_to_key = id_to_key # stays alive as long as wrapper
470
+ wrapper ._multi_weakref_to_key_ids = weakref_to_key_ids
471
+ return wrapper
472
+
473
+
344
474
class Unhashable :
345
475
__slots__ = ["val" ]
346
476
0 commit comments