@@ -297,11 +297,19 @@ class Distribution:
 
     Args:
         device_mesh: A `DeviceMesh` instance.
+        batch_dim_name: Optional string name for the batch dimension.
+            Defaults to None.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
-    def __init__(self, device_mesh, batch_dim_name=None):
+    def __init__(
+        self, device_mesh, batch_dim_name=None, auto_shard_dataset=True
+    ):
         self._device_mesh = device_mesh
         self._batch_dim_name = batch_dim_name
+        self._auto_shard_dataset = auto_shard_dataset
 
     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -358,16 +366,28 @@ def device_mesh(self):
     def batch_dim_name(self):
         return self._batch_dim_name
 
+    @property
+    def auto_shard_dataset(self):
+        return self._auto_shard_dataset
+
+    @auto_shard_dataset.setter
+    def auto_shard_dataset(self, auto_shard_dataset):
+        self._auto_shard_dataset = auto_shard_dataset
+
     def distribute_dataset(self, dataset):
-        """Create a distributed dataset instance from the original user dataset.
+        """Create a distributed dataset from the original global dataset.
 
         Args:
-            dataset: the original global dataset instance. Only
-                `tf.data.Dataset` is supported at the moment.
+            dataset: the original global dataset instance.
 
         Returns:
-            a sharded `tf.data.Dataset` instance, which will produce data for
-            the current local worker/process.
+            If `auto_shard_dataset` is `True`, returns a sharded dataset that
+            only produces data for the current local worker/process. Otherwise,
+            returns the original dataset.
+
+        Raises:
+            ValueError: if auto-sharding is requested in a multi-process
+                setting, but the dataset type is not supported.
         """
         raise NotImplementedError()
 
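The new property setter also lets callers opt out of sharding after construction. A minimal sketch, assuming `distribution` is any concrete subclass instance and `dataset` is already sharded per host:

    # Opt out of auto-sharding on an existing distribution instance.
    distribution.auto_shard_dataset = False
    # `distribute_dataset` now returns the dataset unchanged.
    dataset = distribution.distribute_dataset(dataset)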
@@ -400,31 +420,33 @@ class DataParallel(Distribution):
     Args:
         device_mesh: Optional `DeviceMesh` instance.
         devices: Optional list of devices.
-        auto_shard_dataset: Automatically shard the dataset amongst processes.
-            Defaults to true.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
     def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         if device_mesh:
-            self._initialize_with_device_mesh(device_mesh)
+            self._initialize_with_device_mesh(device_mesh, auto_shard_dataset)
         elif devices:
-            self._initialize_mesh_from_devices(devices)
+            self._initialize_mesh_from_devices(devices, auto_shard_dataset)
         else:
-            self._initialize_mesh_from_list_devices()
+            self._initialize_mesh_from_list_devices(auto_shard_dataset)
 
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
         self._is_multi_process = self._num_process > 1
-        self._auto_shard_dataset = auto_shard_dataset
 
-    def _initialize_with_device_mesh(self, device_mesh):
+    def _initialize_with_device_mesh(self, device_mesh, auto_shard_dataset):
         if not isinstance(device_mesh, DeviceMesh):
             raise ValueError(
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh, device_mesh.axis_names[0])
+        super().__init__(
+            device_mesh, device_mesh.axis_names[0], auto_shard_dataset
+        )
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -433,23 +455,27 @@ def _initialize_with_device_mesh(self, device_mesh):
                 device_mesh.devices.ndim,
             )
 
-    def _initialize_mesh_from_devices(self, devices):
+    def _initialize_mesh_from_devices(self, devices, auto_shard_dataset):
         devices = np.array(devices)
         device_mesh = DeviceMesh(
             shape=devices.shape,
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
+        super().__init__(
+            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset
+        )
 
-    def _initialize_mesh_from_list_devices(self):
+    def _initialize_mesh_from_list_devices(self, auto_shard_dataset):
         devices = np.array(list_devices())
         device_mesh = DeviceMesh(
             shape=devices.shape,
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
+        super().__init__(
+            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset
+        )
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
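Both initializers build the same kind of mesh: a 1D layout over all devices under the default batch axis name. Spelled out as a standalone sketch, using the module's own helpers:

    import numpy as np

    # What `_initialize_mesh_from_list_devices` effectively constructs:
    devices = np.array(list_devices())
    device_mesh = DeviceMesh(
        shape=devices.shape,
        axis_names=[DEFAULT_BATCH_DIM_NAME],
        devices=devices,
    )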
@@ -469,19 +495,21 @@ def get_tensor_layout(self, path):
         return None
 
     def distribute_dataset(self, dataset):
-        from tensorflow.python.data.experimental.ops import (
-            distribute as tf_data_distribute,
-        )
+        if not self._is_multi_process or not self.auto_shard_dataset:
+            return dataset
 
+        # Try to distribute a global tf.data.Dataset.
         from keras.src.utils.module_utils import tensorflow as tf
 
-        if not isinstance(dataset, tf.data.Dataset):
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
             raise ValueError(
-                "Only `tf.data.Dataset` is supported for "
-                f"sharding, got {type(dataset)}"
+                "Only `tf.data.Dataset` is supported for auto-sharding, "
+                f"got {type(dataset)}"
             )
-        if not self._is_multi_process or not self._auto_shard_dataset:
-            return dataset
+
+        from tensorflow.python.data.experimental.ops import (
+            distribute as tf_data_distribute,
+        )
 
         batch_size = tf_data_distribute.compute_batch_size(dataset)
         if batch_size.numpy() < 0:
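For intuition, the multi-process path behaves roughly like handing each process its own shard of the global dataset. A simplified sketch, not the actual implementation (which also rebatches to a per-worker batch size), assuming `dataset` is a `tf.data.Dataset` and `distribution_lib` is the backend module imported above:

    # Roughly equivalent per-process sharding (simplified sketch).
    per_worker_dataset = dataset.shard(
        num_shards=distribution_lib.num_processes(),
        index=distribution_lib.process_id(),
    )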
@@ -587,9 +615,19 @@ class ModelParallel(Distribution):
             (of the `layout_map` object)
             that will be used to distribute data. If unspecified, the
             first axis from the device mesh will be used.
+        auto_shard_dataset: Automatically shard the dataset amongst
+            processes in a multi-process setting. Set to `False` if the dataset
+            is already sharded across hosts. Defaults to `True`.
     """
 
-    def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
+    def __init__(
+        self,
+        *,
+        layout_map=None,
+        batch_dim_name=None,
+        auto_shard_dataset=True,
+        **kwargs,
+    ):
         kwargs.pop("device_mesh", None)
         if layout_map is None:
             raise ValueError("You must specify a layout_map argument.")
@@ -599,9 +637,9 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
                 f"Received: layout_map={layout_map}"
             )
         device_mesh = layout_map.device_mesh
-        super().__init__(device_mesh)
+        batch_dim_name = batch_dim_name or device_mesh.axis_names[0]
+        super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)
         self._layout_map = layout_map
-        self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]
 
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
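Putting the constructor changes together, a usage sketch; the 2x4 mesh shape, the `dense.*kernel` layout rule, and the variable names are illustrative:

    import numpy as np
    import keras
    from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

    devices = np.array(keras.distribution.list_devices()).reshape(2, 4)
    # Data-parallel over "batch", model-parallel over "model".
    mesh = DeviceMesh(
        shape=(2, 4), axis_names=["batch", "model"], devices=devices
    )
    layout_map = LayoutMap(mesh)
    layout_map["dense.*kernel"] = (None, "model")  # shard kernels on "model"
    distribution = ModelParallel(
        layout_map=layout_map,
        batch_dim_name="batch",
        auto_shard_dataset=False,  # dataset is already sharded across hosts
    )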
@@ -628,19 +666,21 @@ def get_tensor_layout(self, path):
         return self._layout_map[path]
 
     def distribute_dataset(self, dataset):
-        from tensorflow.python.data.experimental.ops import (
-            distribute as tf_data_distribute,
-        )
+        if not self._is_multi_process or not self.auto_shard_dataset:
+            return dataset
 
+        # Try to distribute a global tf.data.Dataset.
         from keras.src.utils.module_utils import tensorflow as tf
 
-        if not isinstance(dataset, tf.data.Dataset):
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
             raise ValueError(
-                "Only `tf.data.Dataset` is supported for "
-                f"sharding, got {type(dataset)}"
+                "Only `tf.data.Dataset` is supported for auto-sharding, "
+                f"got {type(dataset)}"
             )
-        if not self._is_multi_process:
-            return dataset
+
+        from tensorflow.python.data.experimental.ops import (
+            distribute as tf_data_distribute,
+        )
 
         global_batch_size = tf_data_distribute.compute_batch_size(dataset)
         if global_batch_size.numpy() < 0:
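Note that `compute_batch_size` returns a negative value when the batch size cannot be statically determined, so the global dataset must be batched before it is handed to `distribute_dataset`. A sketch, where `features`, `labels`, and the batch size of 64 are illustrative:

    import tensorflow as tf

    # The dataset must carry a computable global batch size.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(64)
    dataset = distribution.distribute_dataset(dataset)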