Skip to content

Commit d9d1645

Browse files
authored
Merge pull request #3 from saturncloud/skirmer/linting
fixing unit tests and linting for deployment
2 parents b790537 + 0da5996 commit d9d1645

File tree

5 files changed

+81
-26
lines changed

5 files changed

+81
-26
lines changed

dask_pytorch/data.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,13 @@ class S3ImageFolder(Dataset):
5959
An image folder that lives in S3. Directories containing the image are classes.
6060
"""
6161

62-
def __init__(self, s3_bucket: str, s3_prefix: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
62+
def __init__(
63+
self,
64+
s3_bucket: str,
65+
s3_prefix: str,
66+
transform: Optional[Callable] = None,
67+
target_transform: Optional[Callable] = None,
68+
):
6369
self.s3_bucket = s3_bucket
6470
self.s3_prefix = s3_prefix
6571
self.all_files = _list_all_files(s3_bucket, s3_prefix)

dask_pytorch/dispatch.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -31,23 +31,32 @@ def run(client: Client, pytorch_function: Callable, *args, **kwargs):
3131
futures = [
3232
client.submit(
3333
dispatch_with_ddp,
34-
pytorch_function = pytorch_function,
35-
master_addr = host,
36-
master_port = port,
37-
rank = idx,
38-
world_size = world_size,
39-
backend = "nccl",
40-
*args,
34+
pytorch_function=pytorch_function,
35+
master_addr=host,
36+
master_port=port,
37+
rank=idx,
38+
world_size=world_size,
39+
backend="nccl",
40+
*args,
4141
**kwargs
4242
)
4343
for idx, w in enumerate(worker_keys)
4444
]
45-
45+
4646
return futures
4747

4848

49+
# pylint: disable=keyword-arg-before-vararg
50+
# pylint: disable=too-many-arguments
4951
def dispatch_with_ddp(
50-
pytorch_function: Callable, master_addr: str, master_port: int, rank: int, world_size: int, backend: str = "nccl", *args, **kwargs
52+
pytorch_function: Callable,
53+
master_addr: Any,
54+
master_port: Any,
55+
rank: Any,
56+
world_size: Any,
57+
backend: str = "nccl",
58+
*args,
59+
**kwargs
5160
) -> Any:
5261
"""
5362
runs a pytorch function, setting up torch.distributed before execution

dask_pytorch/results.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,9 @@ def _get_results(self, futures: List[Future], raise_errors: bool = True):
6161
raise
6262
futures = result.not_done
6363

64-
def process_results(self, prefix: str, futures: List[Future], raise_errors: bool = True) -> None:
64+
def process_results(
65+
self, prefix: str, futures: List[Future], raise_errors: bool = True
66+
) -> None:
6567
"""
6668
Process the intermediate results:
6769
result objects will be dictionaries of the form {'path': path, 'data': data}

tests/test_data.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,19 @@
11
from unittest.mock import Mock, patch, ANY
22

33

4-
from dask_pytorch.data import BOTOS3ImageFolder
4+
from dask_pytorch.data import S3ImageFolder
55

66

77
def test_image_folder_constructor():
88
fake_file_list = ["d/a.jpg", "c/b.jpg"]
9-
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
9+
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
1010
fake_transform = Mock()
1111
fake_target_transform = Mock()
12-
folder = BOTOS3ImageFolder(
13-
"fake-bucket", "fake-prefix/fake-prefix", fake_transform, fake_target_transform
12+
folder = S3ImageFolder(
13+
"fake-bucket",
14+
"fake-prefix/fake-prefix",
15+
fake_transform,
16+
fake_target_transform,
1417
)
1518
assert folder.all_files == fake_file_list
1619
assert folder.classes == ["c", "d"]
@@ -21,17 +24,17 @@ def test_image_folder_constructor():
2124

2225
def test_image_folder_len():
2326
fake_file_list = ["d/a.jpg", "c/b.jpg"]
24-
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
25-
folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
27+
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
28+
folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
2629
assert len(folder) == 2
2730

2831

2932
def test_image_folder_getitem():
3033
fake_file_list = ["d/a.jpg", "c/b.jpg"]
31-
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
32-
folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
33-
with patch("dask_pytorch.data.read_s3_fileobj") as read_s3_fileobj, patch(
34-
"dask_pytorch.data.load_image_obj"
34+
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
35+
folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
36+
with patch("dask_pytorch.data._read_s3_fileobj") as read_s3_fileobj, patch(
37+
"dask_pytorch.data._load_image_obj"
3538
) as load_image_obj:
3639

3740
read_s3_fileobj.return_value = Mock()

tests/test_dispatch.py

Lines changed: 40 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
from unittest.mock import Mock, patch
34

45
from dask_pytorch.dispatch import run, dispatch_with_ddp
@@ -31,16 +32,40 @@ def test_run():
3132
output = run(client, fake_pytorch_func)
3233

3334
client.submit.assert_any_call(
34-
dispatch_with_ddp, fake_pytorch_func, host, 23456, 0, len(workers), workers=[worker_keys[0]]
35+
dispatch_with_ddp,
36+
pytorch_function=fake_pytorch_func,
37+
master_addr=host,
38+
master_port=23456,
39+
rank=0,
40+
world_size=len(workers),
41+
backend="nccl",
3542
)
3643
client.submit.assert_any_call(
37-
dispatch_with_ddp, fake_pytorch_func, host, 23456, 1, len(workers), workers=[worker_keys[1]]
44+
dispatch_with_ddp,
45+
pytorch_function=fake_pytorch_func,
46+
master_addr=host,
47+
master_port=23456,
48+
rank=1,
49+
world_size=len(workers),
50+
backend="nccl",
3851
)
3952
client.submit.assert_any_call(
40-
dispatch_with_ddp, fake_pytorch_func, host, 23456, 2, len(workers), workers=[worker_keys[2]]
53+
dispatch_with_ddp,
54+
pytorch_function=fake_pytorch_func,
55+
master_addr=host,
56+
master_port=23456,
57+
rank=2,
58+
world_size=len(workers),
59+
backend="nccl",
4160
)
4261
client.submit.assert_any_call(
43-
dispatch_with_ddp, fake_pytorch_func, host, 23456, 3, len(workers), workers=[worker_keys[3]]
62+
dispatch_with_ddp,
63+
pytorch_function=fake_pytorch_func,
64+
master_addr=host,
65+
master_port=23456,
66+
rank=3,
67+
world_size=len(workers),
68+
backend="nccl",
4469
)
4570
assert output == fake_results
4671

@@ -51,7 +76,17 @@ def test_dispatch_with_ddp():
5176
with patch.object(os, "environ", {}) as environ, patch(
5277
"dask_pytorch.dispatch.dist", return_value=Mock()
5378
) as dist:
54-
dispatch_with_ddp(pytorch_func, "master_addr", 2343, 1, 10, "a", "b", foo="bar")
79+
dispatch_with_ddp(
80+
pytorch_func,
81+
"master_addr",
82+
2343,
83+
1,
84+
10,
85+
"nccl",
86+
"a",
87+
"b",
88+
foo="bar",
89+
)
5590
assert environ["MASTER_ADDR"] == "master_addr"
5691
assert environ["MASTER_PORT"] == "2343"
5792
assert environ["RANK"] == "1"

0 commit comments

Comments
 (0)