Commit 8c0c45a

Add support for Dataflux Iterable Dataset (#17)
* add Dataflux Iterable Dataset
* rename to DataFluxIterableDataset
* add test case for multi-worker setup
* update license header year
* update license header year #2
* Address comments
1 parent a0de75f commit 8c0c45a

File tree

5 files changed: +467 -3 lines changed

dataflux_pytorch/dataflux_iterable_dataset.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import math
import logging

from torch.utils import data
from google.cloud import storage
from google.api_core.client_info import ClientInfo

import dataflux_core


class Config:
    """Customizable configuration for the DataFluxIterableDataset.

    Attributes:
        sort_listing_results: A boolean flag indicating if data listing results
            will be alphabetically sorted. Defaults to False.

        max_composite_object_size: An integer cap on the maximum size of a
            composite object in bytes. Defaults to 100000000 (100 MB).

        num_processes: The number of processes to be used in the Dataflux
            algorithms. Defaults to the number of CPUs from the running
            environment.

        prefix: The prefix used to list the objects in the bucket. Defaults to
            None, which lists all the objects in the bucket.

        max_listing_retries: An integer indicating the maximum number of
            retries to attempt in case of any Python multiprocessing errors
            during GCS object listing. Defaults to 3.
    """

    def __init__(
        self,
        sort_listing_results: bool = False,
        max_composite_object_size: int = 100000000,
        num_processes: int = os.cpu_count(),
        prefix: str = None,
        max_listing_retries: int = 3,
    ):
        self.sort_listing_results = sort_listing_results
        self.max_composite_object_size = max_composite_object_size
        self.num_processes = num_processes
        self.prefix = prefix
        self.max_listing_retries = max_listing_retries


class DataFluxIterableDataset(data.IterableDataset):
    def __init__(
        self,
        project_name,
        bucket_name,
        config=Config(),
        data_format_fn=lambda data: data,
        storage_client=None,
    ):
        """Initializes the DataFluxIterableDataset.

        The initialization sets up the needed configuration and runs data
        listing using the Dataflux algorithm.

        Args:
            project_name: The name of the GCP project.
            bucket_name: The name of the GCS bucket that holds the objects to
                compose. The Dataflux download algorithm uploads the composed
                object to this bucket too.
            config: A dataflux_iterable_dataset.Config object that includes
                configuration customizations. If not specified, a default
                config with default parameters is created.
            data_format_fn: A function that formats the downloaded bytes to
                the desired format. If not specified, the default formatting
                function leaves the data as-is.
            storage_client: The google.cloud.storage.Client object initiated
                with sufficient permission to access the project and the
                bucket. If not specified, it will be created during
                initialization.
        """
        super().__init__()
        self.storage_client = storage_client
        if not storage_client:
            self.storage_client = storage.Client(
                project=project_name,
                client_info=ClientInfo(user_agent="dataflux/0.0"),
            )
        self.project_name = project_name
        self.bucket_name = bucket_name
        self.data_format_fn = data_format_fn
        self.config = config
        self.dataflux_download_optimization_params = (
            dataflux_core.download.DataFluxDownloadOptimizationParams(
                max_composite_object_size=self.config.max_composite_object_size
            )
        )

        self.objects = self._list_GCS_blobs_with_retry()

    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:
            # Single-process data loading.
            yield from [
                self.data_format_fn(bytes_content)
                for bytes_content in dataflux_core.download.dataflux_download_lazy(
                    project_name=self.project_name,
                    bucket_name=self.bucket_name,
                    objects=self.objects,
                    storage_client=self.storage_client,
                    dataflux_download_optimization_params=self.dataflux_download_optimization_params,
                )
            ]
        else:
            # Multi-process data loading. Split the workload among workers.
            # Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset.
            per_worker = int(
                math.ceil(len(self.objects) / float(worker_info.num_workers))
            )
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = min(start + per_worker, len(self.objects))
            yield from [
                self.data_format_fn(bytes_content)
                for bytes_content in dataflux_core.download.dataflux_download_lazy(
                    project_name=self.project_name,
                    bucket_name=self.bucket_name,
                    objects=self.objects[start:end],
                    storage_client=self.storage_client,
                    dataflux_download_optimization_params=self.dataflux_download_optimization_params,
                )
            ]
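
    # Editor's illustration (not part of this commit) of the sharding math
    # above: with len(self.objects) == 10 and worker_info.num_workers == 4,
    # per_worker = ceil(10 / 4) = 3, so the workers iterate over
    # objects[0:3], objects[3:6], objects[6:9], and objects[9:10] --
    # each object is downloaded by exactly one worker.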

    def _list_GCS_blobs_with_retry(self):
        """Retries Dataflux listing upon exceptions, up to the number of retries defined in self.config."""
        error = None
        listed_objects = []
        for _ in range(self.config.max_listing_retries):
            try:
                listed_objects = dataflux_core.fast_list.ListingController(
                    max_parallelism=self.config.num_processes,
                    project=self.project_name,
                    bucket=self.bucket_name,
                    sort_results=self.config.sort_listing_results,
                    prefix=self.config.prefix,
                ).run()
            except Exception as e:
                logging.error(
                    f"exception {str(e)} caught running Dataflux fast listing."
                )
                error = e
                continue
            # No exception -- we can immediately return the listed objects.
            else:
                return listed_objects
        # The loop completed without returning, therefore every attempt
        # raised an exception.
        else:
            raise error
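
For context, a minimal usage sketch of the dataset added above. The project, bucket, prefix, and format function are hypothetical placeholders, and the import path assumes the new module sits alongside dataflux_mapstyle_dataset.py in the dataflux_pytorch package; DataLoader's num_workers exercises the worker sharding in __iter__.

    from torch.utils import data

    from dataflux_pytorch.dataflux_iterable_dataset import (
        Config,
        DataFluxIterableDataset,
    )

    # Hypothetical values -- substitute your own project, bucket, and prefix.
    dataset = DataFluxIterableDataset(
        project_name="my-project",
        bucket_name="my-training-bucket",
        config=Config(prefix="train/", max_listing_retries=5),
        # Placeholder transform: yield each object's size instead of raw bytes.
        data_format_fn=lambda raw_bytes: len(raw_bytes),
    )

    # With num_workers=2, each worker process runs __iter__ over its own
    # slice of the listed objects, so no object is downloaded twice.
    loader = data.DataLoader(dataset, batch_size=4, num_workers=2)
    for batch in loader:
        print(batch)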

dataflux_pytorch/dataflux_mapstyle_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __init__(
         Args:
             project_name: The name of the GCP project.
             bucket_name: The name of the GCS bucket that holds the objects to compose.
-                The function uploads the the composed object to this bucket too.
+                The Dataflux download algorithm uploads the composed object to this bucket too.
             destination_blob_name: The name of the composite object to be created.
             config: A dataflux_mapstyle_dataset.Config object that includes configuration
                 customizations. If not specified, a default config with default parameters is created.
