
Commit 6754ab5

add pytorch modifier
1 parent b2f8a8b commit 6754ab5

4 files changed: +53 -27 lines changed


contrib/segmentation/job.yml

+23 -0

@@ -0,0 +1,23 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
+code:
+  local_path: .
+command: >-
+  python train.py
+  --train-dir {inputs.dummy}
+  --val-dir {inputs.dummy}
+  --patch-dim "256, 256"
+inputs:
+  dummy:
+    data:
+      local_path: data/train
+    mode: mount
+compute:
+  target: azureml:cpu-cluster
+  instance_count: 1
+# distribution:
+#   type: pytorch
+#   process_count: 2
+# azureml:<environment-name>:<version>
+environment: azureml:semantic-segmentation:1
+experiment_name: pytorch-semantic-segmentation
+description: Train a Semantic Segmentation Model on the Semantic Segmentation Drone Dataset
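The distribution block is left commented out, so this job runs as a single process on the CPU cluster. If it were enabled, train.py would also need to join a PyTorch process group. A minimal sketch, assuming the launcher exports the standard RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables; the helper below is illustrative and not part of this commit:

# Illustrative helper, not in this commit: roughly what enabling
# `distribution: type: pytorch` in job.yml would require inside train.py.
import os
import torch.distributed as dist

def maybe_init_process_group() -> bool:
    """Join a process group only when launched with more than one process."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size <= 1:
        return False  # single-process run, as job.yml currently configures
    # "gloo" suits the cpu-cluster target above; use "nccl" on GPU nodes.
    dist.init_process_group(backend="gloo", init_method="env://")
    return True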

contrib/segmentation/src/datasets/semantic_segmentation.py

+12 -15

@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 from PIL import Image
+from torch.utils.data.dataset import Dataset
 
 from .coco import CocoDataset
 

@@ -398,21 +399,22 @@ def __len__(self):
         return self.length
 
 
-class SemanticSegmentationDataset(torch.utils.data.Dataset):
+class SemanticSegmentationPyTorchDataset(torch.utils.data.Dataset):
 
     _available_patch_strategies = set(
         ["resize", "deterministic_center_crop", "crop_all"]
     )
 
     # NC24sv3 Azure VMs have 440GiB of RAM
-    # This allows the SemanticSegmentationDataset to be stored in memory
+    # This allows the SemanticSegmentationPyTorchDataset to be stored in memory
     # However, when multiple workers are used in PyTorch Dataloader,
     # a separate deepcopy of the dataset is made per instance
     # Thus, disk is currently the only shared memory pool between processes
     _available_cache_strategies = set([None, "none", "disk"])
 
     def __init__(
         self,
+        dataset: Dataset,
         labels_filepath: str,
         classes: List[int],
         annotation_format: str,
@@ -427,15 +429,15 @@ def __init__(
     ):
         if (
             patch_strategy
-            not in SemanticSegmentationDataset._available_patch_strategies
+            not in SemanticSegmentationPyTorchDataset._available_patch_strategies
         ):
             raise ValueError(
                 f"Parameter `patch_strategy` must be one of {self._available_patch_strategies}"
             )
 
         if (
             cache_strategy
-            not in SemanticSegmentationDataset._available_cache_strategies
+            not in SemanticSegmentationPyTorchDataset._available_cache_strategies
         ):
             raise ValueError(
                 f"Parameter `cache_strategy` must be one of {self._available_cache_strategies}"
@@ -456,24 +458,19 @@ def __init__(
                 'Parameter `patch_dim` must not be None if `patch_strategy is "crop_all"'
             )
 
-        coco = CocoDataset(
-            labels_filepath=labels_filepath,
-            root_dir=root_dir,
-            classes=classes,
-            annotation_format=annotation_format,
-        )
-
         if patch_strategy == "resize":
-            self.dataset = SemanticSegmentationResizeDataset(coco, resize_dim)
+            self.dataset = SemanticSegmentationResizeDataset(
+                dataset, resize_dim
+            )
         elif patch_strategy == "deterministic_center_crop":
             self.dataset = (
                 SemanticSegmentationWithDeterministicPatchingDataset(
-                    coco, patch_dim
+                    dataset, patch_dim
                 )
             )
         elif patch_strategy == "crop_all":
             self.dataset = SemanticSegmentationDatasetFullCoverage(
-                coco, patch_dim
+                dataset, patch_dim
             )
 
         self.root_dir = root_dir
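With this change the wrapped dataset is injected by the caller instead of being built inside __init__ from a hard-coded CocoDataset. A usage sketch of the new constructor follows; the argument names past `dataset` match the diff above, but the import path for CocoDataset, the remaining keyword names, and all values are assumptions for illustration:

# Sketch only: paths, format names, and most keyword arguments are placeholders.
from src.datasets.coco import CocoDataset
from src.datasets.semantic_segmentation import SemanticSegmentationPyTorchDataset

coco = CocoDataset(
    labels_filepath="data/train/labels.json",  # placeholder path
    root_dir="data/train",
    classes=[1, 2],
    annotation_format="coco",  # placeholder format name
)
dataset = SemanticSegmentationPyTorchDataset(
    dataset=coco,  # the new injected dataset parameter
    labels_filepath="data/train/labels.json",
    classes=[1, 2],
    annotation_format="coco",
    patch_strategy="resize",
    resize_dim=(256, 256),
    cache_strategy="none",
)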
@@ -570,7 +567,7 @@ class ToySemanticSegmentationDataset(torch.utils.data.Dataset):
     """Toy semantic segmentation dataset for integration testing purposes"""
 
     def __init__(self, *args, **kwargs):
-        self._dataset = SemanticSegmentationDataset(*args, **kwargs)
+        self._dataset = SemanticSegmentationPyTorchDataset(*args, **kwargs)
 
     def __getitem__(self, idx):
         return self._dataset[idx]
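The comments added in this file point out that every PyTorch DataLoader worker process receives its own copy of the dataset, so "disk" is the only cache_strategy shared between workers. A short sketch of how such a dataset would typically be consumed; batch size and worker count are arbitrary:

# Sketch: with num_workers > 0 each worker holds its own copy of `dataset`,
# which is why only the "disk" cache is shared across processes.
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,         # a SemanticSegmentationPyTorchDataset instance
    batch_size=8,    # arbitrary for illustration
    shuffle=True,
    num_workers=4,
    drop_last=True,  # train.py drops the last batch to protect batch norm layers
)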

contrib/segmentation/tests/datasets/test_semantic_segmentation.py

+2 -2

@@ -5,7 +5,7 @@
 from PIL import Image
 
 from src.datasets.semantic_segmentation import (
-    SemanticSegmentationDataset,
+    SemanticSegmentationPyTorchDataset,
 )
 
 
@@ -53,7 +53,7 @@ def test_semantic_segmentation_dataset(
         "src.datasets.semantic_segmentation.Image.open",
         return_value=high_resolution_image,
     )
-    dataset = SemanticSegmentationDataset(
+    dataset = SemanticSegmentationPyTorchDataset(
         standard_labels_filepath,
         root_dir="data",
         classes=classes,
contrib/segmentation/train.py

+16 -10

@@ -16,7 +16,7 @@
 from torch.utils.data.dataloader import DataLoader
 
 from src.datasets.semantic_segmentation import (
-    SemanticSegmentationDataset,
+    SemanticSegmentationPyTorchDataset,
     SemanticSegmentationStochasticPatchingDataset,
     ToySemanticSegmentationDataset,
 )
@@ -151,7 +151,7 @@ def forward(self, x):
         default="",
     )
     parser.add_argument("--toy", type=bool, required=False, default=False)
-    parser.add_argument("--classes", type=str, default="1, 2, 3, 4")
+    parser.add_argument("--classes", type=str, default="1, 2")
     parser.add_argument(
         "--log-file", type=str, required=False, default="train.log"
     )
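The new default keeps --classes as the comma-separated string "1, 2", so the script still has to split it into integer class ids. That parsing code is not part of this hunk; a plausible one-liner, shown only as an assumption:

# Assumed parsing step (not visible in this diff): "1, 2" -> [1, 2]
classes = [int(c) for c in args.classes.split(",")]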
@@ -195,7 +195,7 @@ def forward(self, x):
         "--class-balance", type=str2bool, required=False, default=False
     )
     parser.add_argument(
-        "--cache-strategy", type=str, required=False, default="memory"
+        "--cache-strategy", type=str, required=False, default="none"
     )
     args = parser.parse_args()
 
@@ -204,8 +204,8 @@ def forward(self, x):
 
     train_dir = str(args.train_dir)
     val_dir = str(args.val_dir)
-    experiment_dir = str(uuid.uuid4())
-    model_dir = join(train_dir, experiment_dir)
+
+    model_dir = join("outputs", "models")
     Path(model_dir).mkdir(parents=True, exist_ok=True)
 
     if args.cache_dir is not None:
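Writing under ./outputs matters on Azure ML: files placed in that folder are uploaded with the run, so checkpoints saved to outputs/models remain available after the compute is released. A short sketch of saving a checkpoint there; the file name is a placeholder and `model`/`model_dir` refer to the objects defined in train.py:

# Illustrative only: persist weights into the Azure ML-tracked outputs folder.
import torch
from os.path import join

checkpoint_path = join(model_dir, "model_best.pt")  # placeholder file name
torch.save(model.state_dict(), checkpoint_path)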
@@ -285,7 +285,7 @@ def forward(self, x):
     )
     # Toy Dataset for Integration Testing Purposes
     Dataset = (
-        SemanticSegmentationDataset
+        SemanticSegmentationPyTorchDataset
        if not is_toy
        else ToySemanticSegmentationDataset
    )
@@ -355,14 +355,18 @@ def forward(self, x):
         f"Validation dataset number of images: {dataset_val_len} | Batch size: {batch_size} | Expected number of batches: {tot_validation_batches}"
     )
 
-    num_classes: int = classes[-1] + 1  # Plus 1 for background
-    classes = [class_id_to_class_name[i] for i in range(num_classes)]
+    num_classes: int = len(classes) + 1  # Plus 1 for background
 
     # define training and validation data loaders
     # drop_last True to avoid single instances which throw an error on batch norm layers
 
     # Maxing the num_workers at 8 due to shared memory limitations
-    num_workers = min(int(round(multiprocessing.cpu_count() * 2 / 3)), 8)
+    num_workers = min(
+        # Preferably use 2/3's of total cpus. If the cpu count is 1, it will be set to 0 which will result
+        # in dataloader using the main thread
+        int(round(multiprocessing.cpu_count() * 2 / 3)),
+        8,
+    )
 
     dataloader = DataLoader(
         dataset,
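A quick check of the changed class-count arithmetic with the new --classes default of "1, 2":

# Worked example of the new formula (values taken from the --classes default).
classes = [1, 2]
num_classes = len(classes) + 1  # 2 foreground classes + background = 3
# The old expression, classes[-1] + 1, also gives 3 here, but for sparse ids
# such as [1, 4] it would size the model head for 5 classes instead of 3.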
@@ -386,7 +390,9 @@ def forward(self, x):
         model = get_fcn_resnet50(num_classes, pretrained=pretrained)
     elif model_name == "deeplab":
         model = DeepLabModelWrapper(
-            num_classes, pretrained=pretrained
+            num_classes,
+            pretrained=pretrained,
+            is_feature_extracting=pretrained,
         )  # get_deeplabv3(num_classes, is_feature_extracting=pretrained)
     else:
         raise ValueError(
