
Commit 693f0b9

ENH: vendor SerializableLock from dask and use as default backend lock, adapt tests (#8571)
* vendor SerializableLock from dask, adapt tests
* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 92f79a0 commit 693f0b9

4 files changed (+80, -12 lines)

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions

@@ -44,6 +44,8 @@ Bug fixes
 - Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
+- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
+  By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 
 
 Documentation
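
As background for the entry above: before this change, the fallback in xarray/backends/locks.py (see the removed lines in the next file) aliased SerializableLock to a plain threading.Lock when dask was not installed, and a plain threading.Lock cannot be pickled. A minimal illustration, not part of the commit:

    import pickle
    import threading

    try:
        pickle.dumps(threading.Lock())
    except TypeError as err:
        # CPython refuses to pickle a raw lock, e.g.
        # "cannot pickle '_thread.lock' object"
        print(f"plain locks do not pickle: {err}")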

xarray/backends/locks.py

Lines changed: 76 additions & 8 deletions

@@ -2,15 +2,83 @@
 
 import multiprocessing
 import threading
+import uuid
 import weakref
-from collections.abc import MutableMapping
-from typing import Any
-
-try:
-    from dask.utils import SerializableLock
-except ImportError:
-    # no need to worry about serializing the lock
-    SerializableLock = threading.Lock  # type: ignore
+from collections.abc import Hashable, MutableMapping
+from typing import Any, ClassVar
+from weakref import WeakValueDictionary
+
+
+# SerializableLock is adapted from Dask:
+# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
+# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
+class SerializableLock:
+    """A Serializable per-process Lock
+
+    This wraps a normal ``threading.Lock`` object and satisfies the same
+    interface. However, this lock can also be serialized and sent to different
+    processes. It will not block concurrent operations between processes (for
+    this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file``)
+    but will consistently deserialize into the same lock.
+
+    So if we make a lock in one process::
+
+        lock = SerializableLock()
+
+    And then send it over to another process multiple times::
+
+        bytes = pickle.dumps(lock)
+        a = pickle.loads(bytes)
+        b = pickle.loads(bytes)
+
+    Then the deserialized objects will operate as though they were the same
+    lock, and collide as appropriate.
+
+    This is useful for consistently protecting resources on a per-process
+    level.
+
+    The creation of locks is itself not threadsafe.
+    """
+
+    _locks: ClassVar[
+        WeakValueDictionary[Hashable, threading.Lock]
+    ] = WeakValueDictionary()
+    token: Hashable
+    lock: threading.Lock
+
+    def __init__(self, token: Hashable | None = None):
+        self.token = token or str(uuid.uuid4())
+        if self.token in SerializableLock._locks:
+            self.lock = SerializableLock._locks[self.token]
+        else:
+            self.lock = threading.Lock()
+            SerializableLock._locks[self.token] = self.lock
+
+    def acquire(self, *args, **kwargs):
+        return self.lock.acquire(*args, **kwargs)
+
+    def release(self, *args, **kwargs):
+        return self.lock.release(*args, **kwargs)
+
+    def __enter__(self):
+        self.lock.__enter__()
+
+    def __exit__(self, *args):
+        self.lock.__exit__(*args)
+
+    def locked(self):
+        return self.lock.locked()
+
+    def __getstate__(self):
+        return self.token
+
+    def __setstate__(self, token):
+        self.__init__(token)
+
+    def __str__(self):
+        return f"<{self.__class__.__name__}: {self.token}>"
+
+    __repr__ = __str__
 
 
 # Locks used by multiple backends.
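
To make the behaviour described in the docstring concrete, here is a small usage sketch (not part of the diff): copies deserialized from the same pickled bytes share one underlying threading.Lock within a process, because they are keyed by the same token.

    import pickle

    from xarray.backends.locks import SerializableLock

    lock = SerializableLock()
    payload = pickle.dumps(lock)
    a = pickle.loads(payload)
    b = pickle.loads(payload)

    # Both copies carry the original token and therefore resolve to the
    # same underlying threading.Lock in this process.
    assert a.token == b.token == lock.token
    a.acquire()
    assert b.locked()  # b observes the lock held through a
    a.release()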

xarray/tests/test_backends.py

Lines changed: 0 additions & 2 deletions

@@ -432,8 +432,6 @@ def test_dataset_compute(self) -> None:
         assert_identical(expected, computed)
 
     def test_pickle(self) -> None:
-        if not has_dask:
-            pytest.xfail("pickling requires dask for SerializableLock")
         expected = Dataset({"foo": ("x", [42])})
         with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
             with roundtripped:
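
With the vendored lock, pickling a file-backed dataset no longer requires dask, which is why the xfail above can go. A rough sketch of what the test exercises, using a placeholder file name ("example.nc") and the netcdf4 engine:

    import pickle

    import xarray as xr

    # "example.nc" stands in for any netCDF file readable by the netcdf4 engine.
    with xr.open_dataset("example.nc", engine="netcdf4") as ds:
        # The backend's default lock is now a SerializableLock, so the
        # dataset round-trips through pickle without dask installed.
        roundtripped = pickle.loads(pickle.dumps(ds))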

xarray/tests/test_distributed.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
 )
 
 import xarray as xr
-from xarray.backends.locks import HDF5_LOCK, CombinedLock
+from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
 from xarray.tests import (
     assert_allclose,
     assert_identical,

@@ -273,7 +273,7 @@ async def test_async(c, s, a, b) -> None:
 
 
 def test_hdf5_lock() -> None:
-    assert isinstance(HDF5_LOCK, dask.utils.SerializableLock)
+    assert isinstance(HDF5_LOCK, SerializableLock)
 
 
 @gen_cluster(client=True)
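
Beyond the isinstance check above, the point of keeping HDF5_LOCK serializable is that it can be shipped to dask.distributed workers. A minimal sketch, assuming dask.distributed is installed; the Client setup and the "hdf5" token here are illustrative, not taken from the test suite:

    from dask.distributed import Client

    from xarray.backends.locks import SerializableLock

    if __name__ == "__main__":
        client = Client(n_workers=1, processes=True)

        lock = SerializableLock("hdf5")
        # The lock pickles by token, so the worker reconstructs a lock with
        # the same token instead of choking on a raw threading.Lock.
        token = client.submit(lambda remote_lock: remote_lock.token, lock).result()
        assert token == "hdf5"

        client.close()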
