Skip to content

Commit 66cc801

Browse files
authored
Add TorchDataLoader to Train Benchmark (#51456)
## Why are these changes needed?

Add TorchDataLoader to Train Benchmark.

---------

Signed-off-by: Srinath Krishnamachari <srinath.krishnamachari@anyscale.com>
1 parent bf7f085 commit 66cc801

File tree

8 files changed

+1081
-144
lines changed

8 files changed

+1081
-144
lines changed

release/release_tests.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,11 +2118,21 @@
21182118
timeout: 2000
21192119
script: RAY_TRAIN_V2_ENABLED=1 python train_benchmark.py --task=image_classification --dataloader_type=ray_data --num_workers=16
21202120

2121+
- __suffix__: full_training_torch_dataloader
2122+
run:
2123+
timeout: 2000
2124+
script: RAY_TRAIN_V2_ENABLED=1 python train_benchmark.py --task=image_classification --dataloader_type=torch --num_workers=16
2125+
21212126
- __suffix__: skip_training
21222127
run:
21232128
timeout: 1200
21242129
script: RAY_TRAIN_V2_ENABLED=1 python train_benchmark.py --task=image_classification --dataloader_type=ray_data --num_workers=16 --skip_train_step --skip_validation_at_epoch_end
21252130

2131+
- __suffix__: skip_training_torch_dataloader
2132+
run:
2133+
timeout: 1200
2134+
script: RAY_TRAIN_V2_ENABLED=1 python train_benchmark.py --task=image_classification --dataloader_type=torch --num_workers=16 --skip_train_step --skip_validation_at_epoch_end
2135+
21262136
- __suffix__: skip_training.fault_tolerance
21272137
run:
21282138
timeout: 2700

release/train_tests/benchmark/config.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,31 @@
77
class DataloaderType(enum.Enum):
88
RAY_DATA = "ray_data"
99
MOCK = "mock"
10+
TORCH = "torch"
1011

1112

1213
class DataLoaderConfig(BaseModel):
1314
train_batch_size: int = 32
1415
validation_batch_size: int = 256
16+
prefetch_batches: int = 1
1517

1618

1719
class RayDataConfig(DataLoaderConfig):
1820
# NOTE: Optional[int] doesn't play well with argparse.
1921
local_buffer_shuffle_size: int = -1
2022

2123

24+
class TorchConfig(DataLoaderConfig):
25+
num_torch_workers: int = 8
26+
torch_dataloader_timeout_seconds: int = 300
27+
torch_pin_memory: bool = True
28+
torch_non_blocking: bool = True
29+
30+
2231
class BenchmarkConfig(BaseModel):
2332
# ScalingConfig
2433
num_workers: int = 1
34+
2535
# Run CPU training where train workers request a `MOCK_GPU` resource instead.
2636
mock_gpu: bool = False
2737

@@ -39,11 +49,14 @@ class BenchmarkConfig(BaseModel):
3949
# Training
4050
num_epochs: int = 1
4151
skip_train_step: bool = False
52+
train_step_anomaly_detection: bool = False
53+
limit_training_rows: int = 500000
4254

4355
# Validation
4456
validate_every_n_steps: int = -1
4557
skip_validation_step: bool = False
4658
skip_validation_at_epoch_end: bool = False
59+
limit_validation_rows: int = 50000
4760

4861
# Logging
4962
log_metrics_every_n_steps: int = 512
@@ -57,11 +70,10 @@ def _is_pydantic_model(field_type) -> bool:
5770
def _add_field_to_parser(parser: argparse.ArgumentParser, field: str, field_info):
5871
field_type = field_info.annotation
5972
if field_type is bool:
60-
assert (
61-
not field_info.default
62-
), "Only supports bool flags that are False by default."
6373
parser.add_argument(
64-
f"--{field}", action="store_true", default=field_info.default
74+
f"--{field}",
75+
type=lambda x: x.lower() == "true",
76+
default=field_info.default,
6577
)
6678
else:
6779
parser.add_argument(f"--{field}", type=field_type, default=field_info.default)
@@ -87,11 +99,11 @@ def cli_to_config() -> BenchmarkConfig:
8799
nested_parser = argparse.ArgumentParser()
88100
config_cls = BenchmarkConfig.model_fields[nested_field].annotation
89101

90-
if (
91-
config_cls == DataLoaderConfig
92-
and top_level_args.dataloader_type == DataloaderType.RAY_DATA
93-
):
94-
config_cls = RayDataConfig
102+
if config_cls == DataLoaderConfig:
103+
if top_level_args.dataloader_type == DataloaderType.RAY_DATA:
104+
config_cls = RayDataConfig
105+
elif top_level_args.dataloader_type == DataloaderType.TORCH:
106+
config_cls = TorchConfig
95107

96108
for field, field_info in config_cls.model_fields.items():
97109
_add_field_to_parser(nested_parser, field, field_info)
Lines changed: 4 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from abc import ABC, abstractmethod
22
from typing import Any, Dict, Iterator, Tuple
3+
import logging
34

45
import torch
5-
6-
import ray.data
7-
import ray.train
86
from ray.data import Dataset
97

10-
from config import BenchmarkConfig, DataLoaderConfig, RayDataConfig
8+
from config import BenchmarkConfig, DataLoaderConfig
9+
10+
logger = logging.getLogger(__name__)
1111

1212

1313
class BaseDataLoaderFactory(ABC):
@@ -34,124 +34,3 @@ def get_metrics(self) -> Dict[str, Any]:
3434
def get_ray_datasets(self) -> Dict[str, Dataset]:
3535
"""Get Ray datasets if this loader type uses Ray Data."""
3636
return {}
37-
38-
39-
class RayDataLoaderFactory(BaseDataLoaderFactory):
40-
def __init__(self, benchmark_config: BenchmarkConfig):
41-
super().__init__(benchmark_config)
42-
self._ray_ds_iterators = {}
43-
44-
assert isinstance(self.get_dataloader_config(), RayDataConfig), type(
45-
self.get_dataloader_config()
46-
)
47-
48-
# Configure Ray Data settings.
49-
data_context = ray.data.DataContext.get_current()
50-
data_context.enable_operator_progress_bars = False
51-
52-
@abstractmethod
53-
def get_ray_datasets(self) -> Dict[str, Dataset]:
54-
"""Get the Ray datasets for training and validation.
55-
56-
Returns:
57-
Dict with "train" and "val" Dataset objects
58-
"""
59-
pass
60-
61-
@abstractmethod
62-
def collate_fn(self) -> Dict[str, Dataset]:
63-
"""Get the collate function for the dataloader.
64-
65-
Returns:
66-
A function that takes a batch and returns a tuple of tensors.
67-
"""
68-
pass
69-
70-
def get_train_dataloader(self):
71-
ds_iterator = self._ray_ds_iterators["train"] = ray.train.get_dataset_shard(
72-
"train"
73-
)
74-
dataloader_config = self.get_dataloader_config()
75-
return iter(
76-
ds_iterator.iter_torch_batches(
77-
batch_size=dataloader_config.train_batch_size,
78-
local_shuffle_buffer_size=(
79-
dataloader_config.local_buffer_shuffle_size
80-
if dataloader_config.local_buffer_shuffle_size > 0
81-
else None
82-
),
83-
collate_fn=self.collate_fn,
84-
)
85-
)
86-
87-
def get_val_dataloader(self):
88-
ds_iterator = self._ray_ds_iterators["val"] = ray.train.get_dataset_shard("val")
89-
dataloader_config = self.get_dataloader_config()
90-
return iter(
91-
ds_iterator.iter_torch_batches(
92-
batch_size=dataloader_config.validation_batch_size,
93-
collate_fn=self.collate_fn,
94-
)
95-
)
96-
97-
def get_metrics(self) -> Dict[str, Any]:
98-
metrics = {}
99-
for ds_key, ds_iterator in self._ray_ds_iterators.items():
100-
stats = ray.get(ds_iterator._coord_actor.stats.remote())
101-
summary = stats.to_summary()
102-
summary.iter_stats = ds_iterator._iter_stats.to_summary().iter_stats
103-
summary.iter_stats.streaming_split_coord_time.add(
104-
stats.streaming_split_coordinator_s.get()
105-
)
106-
107-
if not summary.parents:
108-
continue
109-
110-
# The split() operator has no metrics, so pull the stats
111-
# from the final dataset stage.
112-
ds_output_summary = summary.parents[0]
113-
ds_throughput = (
114-
ds_output_summary.operators_stats[-1].output_num_rows["sum"]
115-
/ ds_output_summary.get_total_wall_time()
116-
)
117-
118-
iter_stats = summary.iter_stats
119-
120-
metrics[f"dataloader/{ds_key}"] = {
121-
"producer_throughput": ds_throughput,
122-
"iter_stats": {
123-
"prefetch_block-avg": iter_stats.wait_time.avg(),
124-
"prefetch_block-min": iter_stats.wait_time.min(),
125-
"prefetch_block-max": iter_stats.wait_time.max(),
126-
"prefetch_block-total": iter_stats.wait_time.get(),
127-
"fetch_block-avg": iter_stats.get_time.avg(),
128-
"fetch_block-min": iter_stats.get_time.min(),
129-
"fetch_block-max": iter_stats.get_time.max(),
130-
"fetch_block-total": iter_stats.get_time.get(),
131-
"block_to_batch-avg": iter_stats.next_time.avg(),
132-
"block_to_batch-min": iter_stats.next_time.min(),
133-
"block_to_batch-max": iter_stats.next_time.max(),
134-
"block_to_batch-total": iter_stats.next_time.get(),
135-
"format_batch-avg": iter_stats.format_time.avg(),
136-
"format_batch-min": iter_stats.format_time.min(),
137-
"format_batch-max": iter_stats.format_time.max(),
138-
"format_batch-total": iter_stats.format_time.get(),
139-
"collate-avg": iter_stats.collate_time.avg(),
140-
"collate-min": iter_stats.collate_time.min(),
141-
"collate-max": iter_stats.collate_time.max(),
142-
"collate-total": iter_stats.collate_time.get(),
143-
"finalize-avg": iter_stats.finalize_batch_time.avg(),
144-
"finalize-min": iter_stats.finalize_batch_time.min(),
145-
"finalize-max": iter_stats.finalize_batch_time.max(),
146-
"finalize-total": iter_stats.finalize_batch_time.get(),
147-
"time_spent_blocked-avg": iter_stats.block_time.avg(),
148-
"time_spent_blocked-min": iter_stats.block_time.min(),
149-
"time_spent_blocked-max": iter_stats.block_time.max(),
150-
"time_spent_blocked-total": iter_stats.block_time.get(),
151-
"time_spent_training-avg": iter_stats.user_time.avg(),
152-
"time_spent_training-min": iter_stats.user_time.min(),
153-
"time_spent_training-max": iter_stats.user_time.max(),
154-
"time_spent_training-total": iter_stats.user_time.get(),
155-
},
156-
}
157-
return metrics

0 commit comments

Comments
 (0)