Migrating training scripts to torchrun #1933

Open · wants to merge 12 commits into main · Changes from 5 commits
34 changes: 21 additions & 13 deletions references/detection/README.md
@@ -27,24 +27,32 @@ python references/detection/train_pytorch.py db_resnet50 --train_path path/to/yo

### Multi-GPU support (PyTorch only)

Multi-GPU support for the detection task with PyTorch has been added.
Arguments are the same as the ones for single GPU, except:

- `--devices`: **by default, if you do not pass `--devices`, it will use all GPUs on your computer**.
You can use specific GPUs by passing a list of ids (ex: `0 1 2`). To find them, you can use the following snippet:

```python
import torch
devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
device_names = [torch.cuda.get_device_name(d) for d in devices]
```

We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same as for single-GPU training, except:

- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on
  your operating system. The fastest backend is `nccl`, according to the [PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). You can check which backends your build supports with the snippet below.
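
If you are unsure which backends your local PyTorch build supports, a quick check (a minimal snippet, independent of docTR):

```python
import torch.distributed as dist

# NCCL is the typical choice for CUDA GPUs; Gloo is the usual CPU/Windows fallback.
print("nccl available:", dist.is_nccl_available())
print("gloo available:", dist.is_gloo_available())
print("mpi available:", dist.is_mpi_available())
```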

#### Key `torchrun` parameters:
- `--nproc_per_node=<N>`
Spawn `<N>` processes on the local machine (typically equal to the number of GPUs you want to use).
- `--nnodes=<M>`
(Optional) Total number of nodes in your job. Default is 1.
- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id`
(Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details.
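
For reference, here is a minimal sketch, not the docTR training script itself, of what each worker process spawned by `torchrun` does with those environment variables. It could be launched, for example, as `torchrun --nproc_per_node=2 worker_sketch.py`, where `worker_sketch.py` is just a placeholder name:

```python
import os

import torch
import torch.distributed as dist

# torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
# for every process it spawns; the script only has to read them.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# The default env:// rendezvous picks up the remaining variables automatically.
dist.init_process_group(backend="nccl")
print(f"rank {dist.get_rank()}/{dist.get_world_size()} is using GPU {local_rank}")

dist.destroy_process_group()
```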

#### GPU selection:
By default, all visible GPUs are used. To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2:

```shell
python references/detection/train_pytorch_ddp.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --devices 0 1 --backend nccl
```

```shell
CUDA_VISIBLE_DEVICES=0,2 \
torchrun --nproc_per_node=2 references/detection/train_pytorch_ddp.py \
db_resnet50 \
--train_path path/to/train \
--val_path path/to/val \
--epochs 5 \
--backend nccl
```


## Data format

27 changes: 11 additions & 16 deletions references/detection/train_pytorch_ddp.py
@@ -18,7 +18,6 @@

# The following import is required for DDP
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
@@ -194,6 +193,13 @@
world_size (int): number of processes participating in the job
args: other arguments passed through the CLI
"""
# Setup device and distributed
rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(rank)
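# With no init_method argument, init_process_group() defaults to the env:// rendezvous
# and reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE exported by torchrun.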
dist.init_process_group(backend=args.backend)

world_size = dist.get_world_size()

slack_token = os.getenv("TQDM_SLACK_TOKEN")
slack_channel = os.getenv("TQDM_SLACK_CHANNEL")

@@ -270,12 +276,11 @@
model.load_state_dict(checkpoint)

# create default process group
device = torch.device("cuda", args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
device = torch.device("cuda", rank)
# create local model
model = model.to(device)
# construct the DDP model
model = DDP(model, device_ids=[device])
model = DDP(model, device_ids=[rank])

if rank == 0:
# Metrics
@@ -509,7 +514,7 @@
)

# DDP related args
parser.add_argument("--backend", default="nccl", type=str, help="backend to use for torch DDP")
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed")

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
@@ -567,14 +572,4 @@
args = parse_args()
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPUs. Please investigate!")

if not isinstance(args.devices, list):
args.devices = list(range(torch.cuda.device_count()))
# no of process per gpu
nprocs = len(args.devices)
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(nprocs, args), nprocs=nprocs, join=True)
main(args)

Check failure (Codacy Production / Codacy Static Code Analysis) on line 575 in references/detection/train_pytorch_ddp.py: No value for argument 'world_size' in function call
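
The failure above suggests that the detection script's `main` signature was not updated to match the new `main(args)` call site. A minimal sketch of the expected shape, assuming it mirrors the recognition script in this PR:

```python
import os

import torch
import torch.distributed as dist


def main(args):
    # rank and world_size are no longer passed in by mp.spawn; they are derived
    # from the environment that torchrun sets up for each worker.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(rank)
    dist.init_process_group(backend=args.backend)
    world_size = dist.get_world_size()
    print(f"worker {rank} of {world_size} initialised")
```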
35 changes: 22 additions & 13 deletions references/recognition/README.md
@@ -25,27 +25,36 @@ or PyTorch:
python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

### Multi-GPU support (PyTorch only - Experimental)
### Multi-GPU support (PyTorch only)

Multi-GPU support for the recognition task with PyTorch has been added. It will probably be added for other tasks.
Arguments are the same as the ones for single GPU, except:

- `--devices`: **by default, if you do not pass `--devices`, it will use all GPUs on your computer**.
You can use specific GPUs by passing a list of ids (ex: `0 1 2`). To find them, you can use the following snippet:

```python
import torch
devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
device_names = [torch.cuda.get_device_name(d) for d in devices]
```

We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same as for single-GPU training, except:

- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on
  your operating system. The fastest backend is `nccl`, according to the [PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html).

#### Key `torchrun` parameters:
- `--nproc_per_node=<N>`
Spawn `<N>` processes on the local machine (typically equal to the number of GPUs you want to use).
- `--nnodes=<M>`
(Optional) Total number of nodes in your job. Default is 1.
- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id`
(Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details.
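
Since `torchrun` exports both a global `RANK` and a per-node `LOCAL_RANK`, a tiny sketch (independent of the training scripts, meant to be launched with `torchrun`) to see how they differ on a multi-node job:

```python
import os

# RANK is global across all nodes, LOCAL_RANK restarts at 0 on every node,
# and WORLD_SIZE equals nnodes * nproc_per_node.
print("global rank:", os.environ.get("RANK"), "| local rank:", os.environ.get("LOCAL_RANK"))
```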

#### GPU selection:
By default, all visible GPUs are used. To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2:

```shell
python references/recognition/train_pytorch_ddp.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --devices 0 1 --backend nccl
```

```shell
CUDA_VISIBLE_DEVICES=0,2 \
torchrun --nproc_per_node=2 references/recognition/train_pytorch_ddp.py \
crnn_vgg16_bn \
--train_path path/to/train \
--val_path path/to/val \
--epochs 5 \
--backend nccl
```
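
To double-check which GPUs remain visible to PyTorch after setting `CUDA_VISIBLE_DEVICES`, a small snippet (independent of the training script):

```python
import torch

# With CUDA_VISIBLE_DEVICES=0,2 this reports two devices, re-indexed as 0 and 1.
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))
```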



## Data format

You need to provide both `train_path` and `val_path` arguments to start training.
31 changes: 13 additions & 18 deletions references/recognition/train_pytorch_ddp.py
@@ -16,7 +16,6 @@
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, PolynomialLR
@@ -117,13 +116,20 @@
return val_loss, result["raw"], result["unicase"]


def main(rank: int, world_size: int, args):
def main(args):

Check notice (Codacy Production / Codacy Static Code Analysis) on line 119 in references/recognition/train_pytorch_ddp.py: main is too complex (39) (MC0001)
"""
Args:
rank (int): device id to put the model on
world_size (int): number of processes participating in the job
args: other arguments passed through the CLI
"""
# Setup device and distributed
rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(rank)
dist.init_process_group(backend=args.backend)

world_size = dist.get_world_size()
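# (torchrun sets WORLD_SIZE to --nnodes * --nproc_per_node, which is what get_world_size() reports)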

slack_token = os.getenv("TQDM_SLACK_TOKEN")
slack_channel = os.getenv("TQDM_SLACK_CHANNEL")

@@ -225,13 +231,11 @@
for p in model.feat_extractor.parameters():
p.requires_grad = False

# create default process group
device = torch.device("cuda", args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
device = torch.device("cuda", rank)
# create local model
model = model.to(device)
# construct DDP model
model = DDP(model, device_ids=[device])
model = DDP(model, device_ids=[rank])

if rank == 0:
# Metrics
@@ -384,7 +388,6 @@
"train_hash": train_hash,
"val_hash": val_hash,
"pretrained": args.pretrained,
"rotation": args.rotation,
"amp": args.amp,
}

@@ -469,12 +472,12 @@
import argparse

parser = argparse.ArgumentParser(
description="DocTR DDP training script for text recognition (PyTorch)",
description="DocTR torchrun training script for text recognition (PyTorch)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# DDP related args
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP")
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed")

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
@@ -556,12 +559,4 @@
args = parse_args()
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPUs. Please investigate!")
if not isinstance(args.devices, list):
args.devices = list(range(torch.cuda.device_count()))
nprocs = len(args.devices)
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(nprocs, args), nprocs=nprocs, join=True)
main(args)