Commit cdcd5b3
Support non-tf.data.Dataset instances in multi-host setting. (#21430)
It is assumed that the data is already pre-sharded across hosts. If `distribution.auto_shard_dataset` is `True` but the input dataset isn't a `tf.data.Dataset`, an error is raised, as before. Setting `auto_shard_dataset` to `False` allows any dataset type to be used. Unfortunately, we can't add a test for this in OSS since we don't have a multi-process test fixture; however, this was tested internally.
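For context, a rough usage sketch of the new behavior (not part of this commit; the per-host arrays, model, and shapes below are hypothetical, and the multi-host runtime setup, e.g. `jax.distributed.initialize`, is omitted):

    import numpy as np
    import keras

    # Each host loads only its own shard of the data; Keras will not re-shard it.
    x_local = np.random.rand(128, 8).astype("float32")  # hypothetical per-host features
    y_local = np.random.rand(128, 1).astype("float32")  # hypothetical per-host labels

    # Disable auto-sharding so non-tf.data inputs are accepted in a
    # multi-host run; the data is assumed to be pre-sharded across hosts.
    distribution = keras.distribution.DataParallel(auto_shard_dataset=False)
    keras.distribution.set_distribution(distribution)

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
    model.fit(x_local, y_local, batch_size=32, epochs=1)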
1 parent 713172a commit cdcd5b3

File tree

2 files changed (+88 lines, -44 lines)


keras/src/distribution/distribution_lib.py

Lines changed: 77 additions & 37 deletions
@@ -297,11 +297,19 @@ class Distribution:
 
     Args:
         device_mesh: A `DeviceMesh` instance.
+        batch_dim_name: Optional string name for the batch dimension.
+            Defaults to None.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
-    def __init__(self, device_mesh, batch_dim_name=None):
+    def __init__(
+        self, device_mesh, batch_dim_name=None, auto_shard_dataset=True
+    ):
         self._device_mesh = device_mesh
         self._batch_dim_name = batch_dim_name
+        self._auto_shard_dataset = auto_shard_dataset
 
     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -358,16 +366,28 @@ def device_mesh(self):
     def batch_dim_name(self):
         return self._batch_dim_name
 
+    @property
+    def auto_shard_dataset(self):
+        return self._auto_shard_dataset
+
+    @auto_shard_dataset.setter
+    def auto_shard_dataset(self, auto_shard_dataset):
+        self._auto_shard_dataset = auto_shard_dataset
+
     def distribute_dataset(self, dataset):
-        """Create a distributed dataset instance from the original user dataset.
+        """Create a distributed dataset from the original global dataset.
 
         Args:
-            dataset: the original global dataset instance. Only
-                `tf.data.Dataset` is supported at the moment.
+            dataset: the original global dataset instance.
 
         Returns:
-            a sharded `tf.data.Dataset` instance, which will produce data for
-            the current local worker/process.
+            If `auto_shard_dataset` is `True`, returns a sharded dataset that
+            only produces data for the current local worker/process. Otherwise,
+            returns the original dataset.
+
+        Raises:
+            ValueError: if auto-sharding is requested in a multi-process
+                setting, but the dataset type is not supported.
         """
         raise NotImplementedError()
 
@@ -400,31 +420,33 @@ class DataParallel(Distribution):
     Args:
         device_mesh: Optional `DeviceMesh` instance.
         devices: Optional list of devices.
-        auto_shard_dataset: Automatically shard the dataset amongst processes.
-            Defaults to true.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
     def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         if device_mesh:
-            self._initialize_with_device_mesh(device_mesh)
+            self._initialize_with_device_mesh(device_mesh, auto_shard_dataset)
         elif devices:
-            self._initialize_mesh_from_devices(devices)
+            self._initialize_mesh_from_devices(devices, auto_shard_dataset)
         else:
-            self._initialize_mesh_from_list_devices()
+            self._initialize_mesh_from_list_devices(auto_shard_dataset)
 
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
         self._is_multi_process = self._num_process > 1
-        self._auto_shard_dataset = auto_shard_dataset
 
-    def _initialize_with_device_mesh(self, device_mesh):
+    def _initialize_with_device_mesh(self, device_mesh, auto_shard_dataset):
         if not isinstance(device_mesh, DeviceMesh):
             raise ValueError(
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh, device_mesh.axis_names[0])
+        super().__init__(
+            device_mesh, device_mesh.axis_names[0], auto_shard_dataset
+        )
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -433,23 +455,27 @@ def _initialize_with_device_mesh(self, device_mesh):
                 device_mesh.devices.ndim,
             )
 
-    def _initialize_mesh_from_devices(self, devices):
+    def _initialize_mesh_from_devices(self, devices, auto_shard_dataset):
         devices = np.array(devices)
         device_mesh = DeviceMesh(
             shape=devices.shape,
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
+        super().__init__(
+            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset
+        )
 
-    def _initialize_mesh_from_list_devices(self):
+    def _initialize_mesh_from_list_devices(self, auto_shard_dataset):
         devices = np.array(list_devices())
         device_mesh = DeviceMesh(
             shape=devices.shape,
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
+        super().__init__(
+            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset
+        )
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
@@ -469,19 +495,21 @@ def get_tensor_layout(self, path):
         return None
 
     def distribute_dataset(self, dataset):
-        from tensorflow.python.data.experimental.ops import (
-            distribute as tf_data_distribute,
-        )
+        if not self._is_multi_process or not self.auto_shard_dataset:
+            return dataset
 
+        # Try to distribute a global tf.data.Dataset.
         from keras.src.utils.module_utils import tensorflow as tf
 
-        if not isinstance(dataset, tf.data.Dataset):
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
             raise ValueError(
-                "Only `tf.data.Dataset` is supported for "
-                f"sharding, got {type(dataset)}"
+                "Only `tf.data.Dataset` is supported for auto-sharding, "
+                f"got {type(dataset)}"
             )
-        if not self._is_multi_process or not self._auto_shard_dataset:
-            return dataset
+
+        from tensorflow.python.data.experimental.ops import (
+            distribute as tf_data_distribute,
+        )
 
         batch_size = tf_data_distribute.compute_batch_size(dataset)
         if batch_size.numpy() < 0:
@@ -587,9 +615,19 @@ class ModelParallel(Distribution):
             (of the `layout_map` object)
             that will be used to distribute data. If unspecified, the
             first axis from the device mesh will be used.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
-    def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
+    def __init__(
+        self,
+        *,
+        layout_map=None,
+        batch_dim_name=None,
+        auto_shard_dataset=True,
+        **kwargs,
+    ):
         kwargs.pop("device_mesh", None)
         if layout_map is None:
             raise ValueError("You must specify a layout_map argument.")
@@ -599,9 +637,9 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
                 f"Received: layout_map={layout_map}"
             )
         device_mesh = layout_map.device_mesh
-        super().__init__(device_mesh)
+        batch_dim_name = batch_dim_name or device_mesh.axis_names[0]
+        super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)
         self._layout_map = layout_map
-        self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]
 
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
@@ -628,19 +666,21 @@ def get_tensor_layout(self, path):
         return self._layout_map[path]
 
     def distribute_dataset(self, dataset):
-        from tensorflow.python.data.experimental.ops import (
-            distribute as tf_data_distribute,
-        )
+        if not self._is_multi_process or not self.auto_shard_dataset:
+            return dataset
 
+        # Try to distribute a global tf.data.Dataset.
         from keras.src.utils.module_utils import tensorflow as tf
 
-        if not isinstance(dataset, tf.data.Dataset):
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
             raise ValueError(
-                "Only `tf.data.Dataset` is supported for "
-                f"sharding, got {type(dataset)}"
+                "Only `tf.data.Dataset` is supported for auto-sharding, "
+                f"got {type(dataset)}"
             )
-        if not self._is_multi_process:
-            return dataset
+
+        from tensorflow.python.data.experimental.ops import (
+            distribute as tf_data_distribute,
+        )
 
         global_batch_size = tf_data_distribute.compute_batch_size(dataset)
         if global_batch_size.numpy() < 0:

keras/src/trainers/data_adapters/__init__.py

Lines changed: 11 additions & 7 deletions
@@ -28,16 +28,20 @@ def get_data_adapter(
     if isinstance(x, data_adapter.DataAdapter):
         return x
 
-    # Check for multi-process/worker distribution. Since only tf.dataset
-    # is supported at the moment, we will raise error if the inputs fail
-    # the type check
+    # Check for multi-process/worker distribution.
     distribution = distribution_lib.distribution()
-    if getattr(distribution, "_is_multi_process", False) and not is_tf_dataset(
-        x
+    if (
+        distribution is not None
+        and getattr(distribution, "_is_multi_process", False)
+        and getattr(distribution, "auto_shard_dataset", False)
+        and not is_tf_dataset(x)
     ):
         raise ValueError(
-            "When using multi-worker distribution, the data must be provided "
-            f"as a `tf.data.Dataset` instance. Received: type(x)={type(x)}."
+            "When using a multi-worker distribution with auto-sharding enabled, "
+            "the data must be provided as a `tf.data.Dataset` instance. "
+            f"Received: type(x)={type(x)}. "
+            "If the dataset is already sharded across workers, then set "
+            "`distribution.auto_shard_dataset = False`."
         )
 
     if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
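For completeness, a hedged sketch of how a caller satisfies the check above when the data is already sharded per worker (uses the `auto_shard_dataset` setter added in this commit; the active distribution is whatever was registered via `keras.distribution.set_distribution`):

    import keras

    distribution = keras.distribution.distribution()  # currently active distribution
    if distribution is not None:
        # Declare the data as pre-sharded so non-tf.data inputs are allowed
        # in a multi-worker run.
        distribution.auto_shard_dataset = False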
