Skip to content

Commit b670c47

Browse files
authored
Simplify argument passing in actor batch calls (#3098)
1 parent 505361b commit b670c47

File tree

9 files changed

+125
-63
lines changed

9 files changed

+125
-63
lines changed

mars/_resource.pyx

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,48 @@ cdef class Resource:
2424
self.mem_bytes = mem_bytes
2525

2626
def __eq__(self, Resource other):
27-
return self.mem_bytes == other.mem_bytes and \
28-
self.num_gpus == other.num_gpus and \
29-
self.num_cpus == other.num_cpus
27+
cdef bint ret = (
28+
self.mem_bytes == other.mem_bytes
29+
and self.num_gpus == other.num_gpus
30+
and self.num_cpus == other.num_cpus
31+
)
32+
return ret
33+
34+
cdef bint _le(self, Resource other) nogil:
35+
# memory first, then gpu, cpu last
36+
cdef bint ret = (
37+
self.mem_bytes <= other.mem_bytes
38+
and self.num_gpus <= other.num_gpus
39+
and self.num_cpus <= other.num_cpus
40+
)
41+
return ret
3042

3143
def __gt__(self, Resource other):
32-
return not self.__le__(other)
44+
return not self._le(other)
3345

3446
def __le__(self, Resource other):
35-
# memory first, then gpu, cpu last
36-
return self.mem_bytes <= other.mem_bytes and \
37-
self.num_gpus <= other.num_gpus and \
38-
self.num_cpus <= other.num_cpus
47+
return self._le(other)
3948

4049
def __add__(self, Resource other):
41-
return Resource(num_cpus=self.num_cpus + other.num_cpus,
42-
num_gpus=self.num_gpus + other.num_gpus,
43-
mem_bytes=self.mem_bytes + other.mem_bytes)
50+
return Resource(
51+
num_cpus=self.num_cpus + other.num_cpus,
52+
num_gpus=self.num_gpus + other.num_gpus,
53+
mem_bytes=self.mem_bytes + other.mem_bytes,
54+
)
55+
4456
def __sub__(self, Resource other):
45-
return Resource(num_cpus=self.num_cpus - other.num_cpus,
46-
num_gpus=self.num_gpus - other.num_gpus,
47-
mem_bytes=self.mem_bytes - other.mem_bytes)
57+
return Resource(
58+
num_cpus=self.num_cpus - other.num_cpus,
59+
num_gpus=self.num_gpus - other.num_gpus,
60+
mem_bytes=self.mem_bytes - other.mem_bytes,
61+
)
62+
4863
def __neg__(self):
49-
return Resource(num_cpus=-self.num_cpus, num_gpus=-self.num_gpus, mem_bytes=-self.mem_bytes)
64+
return Resource(
65+
num_cpus=-self.num_cpus,
66+
num_gpus=-self.num_gpus,
67+
mem_bytes=-self.mem_bytes,
68+
)
5069

5170
def __repr__(self):
5271
return f"Resource(num_cpus={self.num_cpus}, num_gpus={self.num_gpus}, mem_bytes={self.mem_bytes})"

mars/learn/metrics/pairwise/tests/test_haversine_distances.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ def test_haversine_distances_execution(setup, x, y, use_sklearn):
5454

5555
result = distance.execute().fetch()
5656
expected = sk_haversine_distances(raw_x, raw_y)
57-
np.testing.assert_array_equal(result, expected)
57+
np.testing.assert_almost_equal(result, expected)
5858

5959
# test x is y
6060
distance = haversine_distances(x)
6161
distance.op._use_sklearn = use_sklearn
6262

6363
result = distance.execute().fetch()
6464
expected = sk_haversine_distances(raw_x, raw_x)
65-
np.testing.assert_array_equal(result, expected)
65+
np.testing.assert_almost_equal(result, expected)

mars/oscar/batch.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,52 +133,54 @@ def delay(*args, **kwargs):
133133

134134
@staticmethod
135135
def _gen_args_kwargs_list(delays):
136-
args_list = list()
137-
kwargs_list = list()
138-
for delay in delays:
139-
args_list.append(delay.args)
140-
kwargs_list.append(delay.kwargs)
136+
args_list = [delay.args for delay in delays]
137+
kwargs_list = [delay.kwargs for delay in delays]
141138
return args_list, kwargs_list
142139

143-
async def _async_batch(self, *delays):
140+
async def _async_batch(self, args_list, kwargs_list):
144141
# when there is only one call in batch, calling one-pass method
145142
# will be more efficient
146-
if len(delays) == 0:
143+
if len(args_list) == 0:
147144
return []
148-
elif len(delays) == 1:
149-
d = delays[0]
150-
return [await self._async_call(*d.args, **d.kwargs)]
145+
elif len(args_list) == 1:
146+
return [await self._async_call(*args_list[0], **kwargs_list[0])]
151147
elif self.batch_func:
152-
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
153148
return await self.batch_func(args_list, kwargs_list)
154149
else:
155150
# this function has no batch implementation
156151
# call it separately
157152
tasks = [
158-
asyncio.create_task(self.func(*d.args, **d.kwargs)) for d in delays
153+
asyncio.create_task(self.func(*args, **kwargs))
154+
for args, kwargs in zip(args_list, kwargs_list)
159155
]
160156
try:
161157
return await asyncio.gather(*tasks)
162158
except asyncio.CancelledError:
163159
_ = [task.cancel() for task in tasks]
164160
return await asyncio.gather(*tasks)
165161

166-
def _sync_batch(self, *delays):
167-
if delays == 0:
162+
def _sync_batch(self, args_list, kwargs_list):
163+
if len(args_list) == 0:
168164
return []
169165
elif self.batch_func:
170-
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
171166
return self.batch_func(args_list, kwargs_list)
172167
else:
173168
# this function has no batch implementation
174169
# call it separately
175-
return [self.func(*d.args, **d.kwargs) for d in delays]
170+
return [
171+
self.func(*args, **kwargs)
172+
for args, kwargs in zip(args_list, kwargs_list)
173+
]
176174

177175
def batch(self, *delays):
176+
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
177+
return self.call_with_lists(args_list, kwargs_list)
178+
179+
def call_with_lists(self, args_list, kwargs_list):
178180
if self.is_async:
179-
return self._async_batch(*delays)
181+
return self._async_batch(args_list, kwargs_list)
180182
else:
181-
return self._sync_batch(*delays)
183+
return self._sync_batch(args_list, kwargs_list)
182184

183185
def bind(self, *args, **kwargs):
184186
if self.bind_func is None:

mars/oscar/core.pyx

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ import inspect
1717
import logging
1818
import sys
1919
import weakref
20-
cimport cython
2120
from typing import AsyncGenerator
2221

22+
cimport cython
23+
2324
from .context cimport get_context
2425
from .errors import Return, ActorNotExist
2526
from .utils cimport is_async_generator
@@ -33,7 +34,7 @@ logger = logging.getLogger(__name__)
3334
cdef:
3435
bint _log_unhandled_errors = False
3536
bint _log_cycle_send = False
36-
object _local_pool_map = weakref.WeakValueDictionary()
37+
dict _local_pool_map = dict()
3738
object _actor_method_wrapper
3839

3940

@@ -54,7 +55,8 @@ cdef _get_local_actor(address, uid):
5455
# disable the local actor proxy when the debug option is on.
5556
if _log_cycle_send:
5657
return None
57-
pool = _local_pool_map.get(address)
58+
pool_ref = _local_pool_map.get(address)
59+
pool = None if pool_ref is None else pool_ref()
5860
if pool is not None:
5961
actor = pool._actors.get(uid)
6062
if actor is not None:
@@ -66,7 +68,9 @@ def register_local_pool(address, pool):
6668
"""
6769
Register local actor pool for local actor lookup.
6870
"""
69-
_local_pool_map[address] = pool
71+
_local_pool_map[address] = weakref.ref(
72+
pool, lambda _: _local_pool_map.pop(address, None)
73+
)
7074

7175

7276
cpdef create_local_actor_ref(address, uid):
@@ -177,22 +181,34 @@ cdef class ActorRefMethod:
177181

178182
def batch(self, *delays, send=True):
179183
cdef:
180-
list args_list = list()
181-
list kwargs_list = list()
184+
long n_delays = len(delays)
185+
bint has_kw = False
186+
list args_list
187+
list kwargs_list
188+
_DelayedArgument delay
189+
190+
args_list = [None] * n_delays
191+
kwargs_list = [None] * n_delays
182192

183193
last_method = None
184-
for delay in delays:
194+
for idx in range(n_delays):
195+
delay = delays[idx]
185196
method, _call_method, args, kwargs = delay.arguments
186197
if last_method is not None and method != last_method:
187198
raise ValueError('Does not support calling multiple methods in batch')
188199
last_method = method
189200

190-
args_list.append(args)
191-
kwargs_list.append(kwargs)
201+
args_list[idx] = args
202+
kwargs_list[idx] = kwargs
203+
if kwargs:
204+
has_kw = True
205+
206+
if not has_kw:
207+
kwargs_list = None
192208
if last_method is None:
193209
last_method = self.method_name
194210

195-
message = (last_method, CALL_METHOD_BATCH, (args_list, kwargs_list), {})
211+
message = (last_method, CALL_METHOD_BATCH, (args_list, kwargs_list), None)
196212
return get_context().send(self.ref, message, wait_response=send, **self._options)
197213

198214
def tell_delay(self, *args, delay=None, ignore_conn_fail=True, **kwargs):
@@ -488,10 +504,10 @@ cdef class _BaseActor:
488504
with debug_async_timeout('actor_lock_timeout',
489505
"Batch method %s of actor %s hold lock timeout, batch size %s.",
490506
method, self.uid, len(args)):
491-
delays = []
492-
for s_args, s_kwargs in zip(*args):
493-
delays.append(func.delay(*s_args, **s_kwargs))
494-
result = func.batch(*delays)
507+
args_list, kwargs_list = args
508+
if kwargs_list is None:
509+
kwargs_list = [{}] * len(args_list)
510+
result = func.call_with_lists(args_list, kwargs_list)
495511
if asyncio.iscoroutine(result):
496512
result = await result
497513
else: # pragma: no cover

mars/serialization/core.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
cdef class Serializer:
16+
cdef int _serializer_id
17+
1618
cpdef serial(self, object obj, dict context)
1719
cpdef deserial(self, tuple serialized, dict context, list subs)
1820
cpdef on_deserial_error(

mars/serialization/core.pyx

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ cdef int32_t _SERIALIZER_ID_PRIME = 32749
6666
cdef class Serializer:
6767
serializer_id = None
6868

69+
def __cinit__(self):
70+
# copy the value into a C attribute so it can be referenced from C code
71+
self._serializer_id = self.serializer_id
72+
6973
cpdef serial(self, object obj, dict context):
7074
"""
7175
Returns intermediate serialization result of certain object.
@@ -158,14 +162,15 @@ cdef class Serializer:
158162

159163
@classmethod
160164
def register(cls, obj_type):
161-
inst = cls()
162165
if (
163166
cls.serializer_id is None
164167
or cls.serializer_id == getattr(super(cls, cls), "serializer_id", None)
165168
):
166169
# a class should have its own serializer_id
167170
# inherited serializer_id not acceptable
168171
cls.serializer_id = cls.calc_default_serializer_id()
172+
173+
inst = cls()
169174
_serial_dispatcher.register(obj_type, inst)
170175
if _deserializers.get(cls.serializer_id) is not None:
171176
assert type(_deserializers[cls.serializer_id]) is cls
@@ -321,10 +326,16 @@ cdef class StrSerializer(Serializer):
321326
cdef class CollectionSerializer(Serializer):
322327
obj_type = None
323328

329+
cdef object _obj_type
330+
331+
def __cinit__(self):
332+
# copy the value into a C attribute so it can be referenced from C code
333+
self._obj_type = self.obj_type
334+
324335
cpdef tuple _serial_iterable(self, obj: Any):
325336
cdef list idx_to_propagate = []
326337
cdef list obj_to_propagate = []
327-
cdef list obj_list = list(obj)
338+
cdef list obj_list = <list>obj if type(obj) is list else list(obj)
328339
cdef int64_t idx
329340
cdef object item
330341

@@ -340,11 +351,14 @@ cdef class CollectionSerializer(Serializer):
340351
elif type(item) in _primitive_types:
341352
continue
342353

354+
if obj is obj_list:
355+
obj_list = list(obj)
356+
343357
obj_list[idx] = None
344358
idx_to_propagate.append(idx)
345359
obj_to_propagate.append(item)
346360

347-
if self.obj_type is not None and type(obj) is not self.obj_type:
361+
if self._obj_type is not None and type(obj) is not self._obj_type:
348362
obj_type = type(obj)
349363
else:
350364
obj_type = None
@@ -420,13 +434,19 @@ def _dict_value_replacer(context, ret, key, real_value):
420434

421435
cdef class DictSerializer(CollectionSerializer):
422436
serializer_id = 6
423-
_inspected_inherits = set()
437+
cdef set _inspected_inherits
438+
439+
def __cinit__(self):
440+
self._inspected_inherits = set()
424441

425442
cpdef serial(self, obj: Any, dict context):
426443
cdef uint64_t obj_id
427444
cdef tuple key_obj, value_obj
428445
cdef list key_bufs, value_bufs
429446

447+
if type(obj) is dict and len(<dict>obj) == 0:
448+
return (), [], True
449+
430450
obj_id = _fast_id(obj)
431451
if obj_id in context:
432452
return Placeholder(obj_id)
@@ -460,6 +480,8 @@ cdef class DictSerializer(CollectionSerializer):
460480
cdef int64_t i, num_key_bufs
461481
cdef list key_subs, value_subs, keys, values
462482

483+
if not serialized:
484+
return {}
463485
if len(serialized) == 1:
464486
# serialized directly
465487
return serialized[0]
@@ -537,6 +559,7 @@ PickleSerializer.register(object)
537559
for _primitive in _primitive_types:
538560
PrimitiveSerializer.register(_primitive)
539561
BytesSerializer.register(bytes)
562+
BytesSerializer.register(memoryview)
540563
StrSerializer.register(str)
541564
ListSerializer.register(list)
542565
TupleSerializer.register(tuple)
@@ -588,7 +611,7 @@ cdef tuple _serial_single(
588611
# REMEMBER to change _COMMON_HEADER_LEN when content of
589612
# this header changes
590613
common_header = (
591-
serializer.serializer_id, ordered_id, len(subs), final
614+
serializer._serializer_id, ordered_id, len(subs), final
592615
)
593616
break
594617
else:

mars/serialization/tests/test_serial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class CustomList(list):
5555
"abcd",
5656
["uvw", ("mno", "sdaf"), 4, 6.7],
5757
CustomList([3, 4, CustomList([5, 6])]),
58-
{"abc": 5.6, "def": [3.4], "gh": None},
58+
{"abc": 5.6, "def": [3.4], "gh": None, "ijk": {}},
5959
OrderedDict([("abcd", 5.6)]),
6060
defaultdict(lambda: 0, [("abcd", 0)]),
6161
],

0 commit comments

Comments
 (0)