from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, Tuple
+import logging

import torch
-
-import ray.data
-import ray.train
from ray.data import Dataset

-from config import BenchmarkConfig, DataLoaderConfig, RayDataConfig
+from config import BenchmarkConfig, DataLoaderConfig
+
+logger = logging.getLogger(__name__)


class BaseDataLoaderFactory(ABC):
@@ -34,124 +34,3 @@ def get_metrics(self) -> Dict[str, Any]:
    def get_ray_datasets(self) -> Dict[str, Dataset]:
        """Get Ray datasets if this loader type uses Ray Data."""
        return {}
-
-
-class RayDataLoaderFactory(BaseDataLoaderFactory):
-    def __init__(self, benchmark_config: BenchmarkConfig):
-        super().__init__(benchmark_config)
-        self._ray_ds_iterators = {}
-
-        assert isinstance(self.get_dataloader_config(), RayDataConfig), type(
-            self.get_dataloader_config()
-        )
-
-        # Configure Ray Data settings.
-        data_context = ray.data.DataContext.get_current()
-        data_context.enable_operator_progress_bars = False
-
-    @abstractmethod
-    def get_ray_datasets(self) -> Dict[str, Dataset]:
-        """Get the Ray datasets for training and validation.
-
-        Returns:
-            Dict with "train" and "val" Dataset objects
-        """
-        pass
-
-    @abstractmethod
-    def collate_fn(self) -> Dict[str, Dataset]:
-        """Get the collate function for the dataloader.
-
-        Returns:
-            A function that takes a batch and returns a tuple of tensors.
-        """
-        pass
-
-    def get_train_dataloader(self):
-        ds_iterator = self._ray_ds_iterators["train"] = ray.train.get_dataset_shard(
-            "train"
-        )
-        dataloader_config = self.get_dataloader_config()
-        return iter(
-            ds_iterator.iter_torch_batches(
-                batch_size=dataloader_config.train_batch_size,
-                local_shuffle_buffer_size=(
-                    dataloader_config.local_buffer_shuffle_size
-                    if dataloader_config.local_buffer_shuffle_size > 0
-                    else None
-                ),
-                collate_fn=self.collate_fn,
-            )
-        )
-
-    def get_val_dataloader(self):
-        ds_iterator = self._ray_ds_iterators["val"] = ray.train.get_dataset_shard("val")
-        dataloader_config = self.get_dataloader_config()
-        return iter(
-            ds_iterator.iter_torch_batches(
-                batch_size=dataloader_config.validation_batch_size,
-                collate_fn=self.collate_fn,
-            )
-        )
-
-    def get_metrics(self) -> Dict[str, Any]:
-        metrics = {}
-        for ds_key, ds_iterator in self._ray_ds_iterators.items():
-            stats = ray.get(ds_iterator._coord_actor.stats.remote())
-            summary = stats.to_summary()
-            summary.iter_stats = ds_iterator._iter_stats.to_summary().iter_stats
-            summary.iter_stats.streaming_split_coord_time.add(
-                stats.streaming_split_coordinator_s.get()
-            )
-
-            if not summary.parents:
-                continue
-
-            # The split() operator has no metrics, so pull the stats
-            # from the final dataset stage.
-            ds_output_summary = summary.parents[0]
-            ds_throughput = (
-                ds_output_summary.operators_stats[-1].output_num_rows["sum"]
-                / ds_output_summary.get_total_wall_time()
-            )
-
-            iter_stats = summary.iter_stats
-
-            metrics[f"dataloader/{ds_key}"] = {
-                "producer_throughput": ds_throughput,
-                "iter_stats": {
-                    "prefetch_block-avg": iter_stats.wait_time.avg(),
-                    "prefetch_block-min": iter_stats.wait_time.min(),
-                    "prefetch_block-max": iter_stats.wait_time.max(),
-                    "prefetch_block-total": iter_stats.wait_time.get(),
-                    "fetch_block-avg": iter_stats.get_time.avg(),
-                    "fetch_block-min": iter_stats.get_time.min(),
-                    "fetch_block-max": iter_stats.get_time.max(),
-                    "fetch_block-total": iter_stats.get_time.get(),
-                    "block_to_batch-avg": iter_stats.next_time.avg(),
-                    "block_to_batch-min": iter_stats.next_time.min(),
-                    "block_to_batch-max": iter_stats.next_time.max(),
-                    "block_to_batch-total": iter_stats.next_time.get(),
-                    "format_batch-avg": iter_stats.format_time.avg(),
-                    "format_batch-min": iter_stats.format_time.min(),
-                    "format_batch-max": iter_stats.format_time.max(),
-                    "format_batch-total": iter_stats.format_time.get(),
-                    "collate-avg": iter_stats.collate_time.avg(),
-                    "collate-min": iter_stats.collate_time.min(),
-                    "collate-max": iter_stats.collate_time.max(),
-                    "collate-total": iter_stats.collate_time.get(),
-                    "finalize-avg": iter_stats.finalize_batch_time.avg(),
-                    "finalize-min": iter_stats.finalize_batch_time.min(),
-                    "finalize-max": iter_stats.finalize_batch_time.max(),
-                    "finalize-total": iter_stats.finalize_batch_time.get(),
-                    "time_spent_blocked-avg": iter_stats.block_time.avg(),
-                    "time_spent_blocked-min": iter_stats.block_time.min(),
-                    "time_spent_blocked-max": iter_stats.block_time.max(),
-                    "time_spent_blocked-total": iter_stats.block_time.get(),
-                    "time_spent_training-avg": iter_stats.user_time.avg(),
-                    "time_spent_training-min": iter_stats.user_time.min(),
-                    "time_spent_training-max": iter_stats.user_time.max(),
-                    "time_spent_training-total": iter_stats.user_time.get(),
-                },
-            }
-        return metrics
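For reference, a concrete loader only had to fill in the two abstract hooks on the removed `RayDataLoaderFactory`: `get_ray_datasets()` builds the `"train"`/`"val"` datasets that Ray Train shards across workers, and `collate_fn` is passed bound into `iter_torch_batches(collate_fn=self.collate_fn)`, so despite its declared zero-argument signature it is actually invoked with a batch (a dict of numpy arrays). Below is a minimal sketch under those assumptions; `SyntheticDataLoaderFactory`, the `ray.data.range` sources, and the `"id"` column are illustrative only and not part of this commit:

```python
from typing import Dict, Tuple

import torch
import ray.data
from ray.data import Dataset

# Import of RayDataLoaderFactory omitted: this commit removes it from
# this file, and its new location is not shown here.


class SyntheticDataLoaderFactory(RayDataLoaderFactory):  # hypothetical example
    def get_ray_datasets(self) -> Dict[str, Dataset]:
        # Keys must match the names passed to ray.train.get_dataset_shard()
        # in get_train_dataloader() / get_val_dataloader().
        return {
            "train": ray.data.range(10_000),  # rows like {"id": 0..9999}
            "val": ray.data.range(1_000),
        }

    def collate_fn(self, batch) -> Tuple[torch.Tensor, ...]:
        # Called once per batch by iter_torch_batches(); convert the numpy
        # column dict into the tensor tuple the training loop consumes.
        return (torch.as_tensor(batch["id"]),)
```

Ray Train then splits each returned dataset across workers, and `get_train_dataloader()` pulls the per-worker shard via `ray.train.get_dataset_shard("train")`.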