diff --git a/references/detection/README.md b/references/detection/README.md index e927d9d05..c62d793be 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -27,24 +27,32 @@ python references/detection/train_pytorch.py db_resnet50 --train_path path/to/yo ### Multi-GPU support (PyTorch only) -Multi-GPU support on Detection task with PyTorch has been added. -Arguments are the same than the ones from single GPU, except: - -- `--devices`: **by default, if you do not pass `--devices`, it will use all GPUs on your computer**. -You can use specific GPUs by passing a list of ids (ex: `0 1 2`). To find them, you can use the following snippet: - -```python -import torch -devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())] -device_names = [torch.cuda.get_device_name(d) for d in devices] -``` +We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same than the ones from single GPU, except: - `--backend`: you can specify another `backend` for `DistribuedDataParallel` if the default one is not available on your operating system. Fastest one is `nccl` according to [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). +#### Key `torchrun` parameters: +- `--nproc_per_node=` + Spawn `` processes on the local machine (typically equal to the number of GPUs you want to use). +- `--nnodes=` + (Optional) Total number of nodes in your job. Default is 1. +- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id` + (Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details. + +#### GPU selection: +By default all visible GPUs will be used. 
To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2: + ```shell -python references/detection/train_pytorch_ddp.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --devices 0 1 --backend nccl -``` +CUDA_VISIBLE_DEVICES=0,2 \ +torchrun --nproc_per_node=2 references/detection/train_pytorch.py \ + db_resnet50 \ + --train_path path/to/train \ + --val_path path/to/val \ + --epochs 5 \ + --backend nccl + ``` + ## Data format diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index bf4363a17..483d37039 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -10,14 +10,19 @@ import datetime import hashlib import logging -import multiprocessing as mp +import multiprocessing import time from pathlib import Path import numpy as np import torch + +# The following import is required for DDP +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): @@ -103,14 +108,14 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder -def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: scaler = torch.cuda.amp.GradScaler() model.train() # Iterate over the batches of the dataset epoch_train_loss, batch_cnt = 0, 0 - pbar = tqdm(train_loader, 
dynamic_ncols=True) + pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0)) for images, targets in pbar: if torch.cuda.is_available(): images = images.cuda() @@ -137,7 +142,8 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a last_lr = scheduler.get_last_lr()[0] pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}") - log(train_loss=train_loss.item(), lr=last_lr) + if log: + log(train_loss=train_loss.item(), lr=last_lr) epoch_train_loss += train_loss.item() batch_cnt += 1 @@ -147,7 +153,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a @torch.no_grad() -def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=None): +def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, log=None): # Model in eval mode model.eval() # Reset val metric @@ -174,7 +180,8 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) pbar.set_description(f"Validation loss: {out['loss'].item():.6}") - log(val_loss=out["loss"].item()) + if log: + log(val_loss=out["loss"].item()) val_loss += out["loss"].item() batch_cnt += 1 @@ -185,65 +192,105 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non def main(args): + """ + Args: + rank (int): device id to put the model on + world_size (int): number of processes participating in the job + args: other arguments passed through the CLI + """ + world_size = int(os.environ.get("WORLD_SIZE", 1)) + distributed = world_size > 1 + + # GPU setup + if distributed: + rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group(backend=args.backend) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + + else: + # single process + rank = 0 + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + device = torch.device("cuda", args.device) + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + logging.warning("No accessible GPU, target device set to CPU.") + device = torch.device("cpu") + slack_token = os.getenv("TQDM_SLACK_TOKEN") slack_channel = os.getenv("TQDM_SLACK_CHANNEL") - pbar = tqdm(disable=False if slack_token and slack_channel else True) + pbar = tqdm(disable=False if (slack_token and slack_channel) and (rank == 0) else True) if slack_token and slack_channel: # Monkey patch tqdm write method to send messages directly to Slack pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) pbar.write(str(args)) - if args.push_to_hub: + if rank == 0 and args.push_to_hub: login_to_hub() if not isinstance(args.workers, int): - args.workers = min(16, mp.cpu_count()) + args.workers = min(16, multiprocessing.cpu_count()) torch.backends.cudnn.benchmark = True - st = time.time() - val_set = DetectionDataset( - img_folder=os.path.join(args.val_path, "images"), - label_path=os.path.join(args.val_path, "labels.json"), - sample_transforms=T.SampleCompose( - ( - [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] - if not args.rotation or args.eval_straight - else [] - ) - + ( - [ - T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad - T.RandomApply(T.RandomRotate(90, expand=True), 0.5), - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), - ] - if args.rotation and not args.eval_straight - else [] - ) - ), - use_polygons=args.rotation and not args.eval_straight, - ) - val_loader = DataLoader( - val_set, - batch_size=args.batch_size, - drop_last=False, - num_workers=args.workers, - sampler=SequentialSampler(val_set), - 
pin_memory=torch.cuda.is_available(), - collate_fn=val_set.collate_fn, - ) - pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)") - with open(os.path.join(args.val_path, "labels.json"), "rb") as f: - val_hash = hashlib.sha256(f.read()).hexdigest() + if rank == 0: + # validation dataset related code + st = time.time() + val_set = DetectionDataset( + img_folder=os.path.join(args.val_path, "images"), + label_path=os.path.join(args.val_path, "labels.json"), + sample_transforms=T.SampleCompose( + ( + [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] + if not args.rotation or args.eval_straight + else [] + ) + + ( + [ + T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + if args.rotation and not args.eval_straight + else [] + ) + ), + use_polygons=args.rotation and not args.eval_straight, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + pbar.write( + f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)" + ) + with open(os.path.join(args.val_path, "labels.json"), "rb") as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + class_names = val_set.class_names + else: + class_names = None batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)) - # Load doctr model + # Load docTR model model = detection.__dict__[args.arch]( pretrained=args.pretrained, assume_straight_pages=not args.rotation, - class_names=val_set.class_names, + class_names=class_names, ) # Resume weights @@ -252,27 +299,15 @@ def main(args): checkpoint = 
torch.load(args.resume, map_location="cpu") model.load_state_dict(checkpoint) - # GPU - if isinstance(args.device, int): - if not torch.cuda.is_available(): - raise AssertionError("PyTorch cannot access your GPU. Please investigate!") - if args.device >= torch.cuda.device_count(): - raise ValueError("Invalid device index") - # Silent default switch to GPU if available - elif torch.cuda.is_available(): - args.device = 0 - else: - logging.warning("No accessible GPU, target device set to CPU.") - if torch.cuda.is_available(): - torch.cuda.set_device(args.device) - model = model.cuda() - - # Metrics - val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight) + if rank == 0: + # Metrics + val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight) - if args.test_only: + if rank == 0 and args.test_only: pbar.write("Running evaluation") - val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + val_loss, recall, precision, mean_iou = evaluate( + model, val_loader, batch_transforms, val_metric, args, amp=args.amp + ) pbar.write( f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " f"Mean IoU: {mean_iou:.2%})" @@ -331,29 +366,45 @@ def main(args): use_polygons=args.rotation, ) + if distributed: + sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + else: + sampler = RandomSampler(train_set) + train_loader = DataLoader( train_set, batch_size=args.batch_size, drop_last=True, num_workers=args.workers, - sampler=RandomSampler(train_set), + sampler=sampler, pin_memory=torch.cuda.is_available(), collate_fn=train_set.collate_fn, ) - pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)") + if rank == 0: + pbar.write( + f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)" + ) + with 
open(os.path.join(args.train_path, "labels.json"), "rb") as f: train_hash = hashlib.sha256(f.read()).hexdigest() - if args.show_samples: + if rank == 0 and args.show_samples: x, target = next(iter(train_loader)) plot_samples(x, target) - return + # return # Backbone freezing if args.freeze_backbone: for p in model.feat_extractor.parameters(): p.requires_grad = False + if torch.cuda.is_available(): + torch.cuda.set_device(device) + model = model.to(device) + + if distributed: + # construct DDP model + model = DDP(model, device_ids=[rank]) # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( @@ -373,7 +424,7 @@ def main(args): ) # LR Finder - if args.find_lr: + if rank == 0 and args.find_lr: lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) plot_recorder(lrs, losses) return @@ -389,22 +440,24 @@ def main(args): # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name - config = { - "learning_rate": args.lr, - "epochs": args.epochs, - "weight_decay": args.weight_decay, - "batch_size": args.batch_size, - "architecture": args.arch, - "input_size": args.input_size, - "optimizer": args.optim, - "framework": "pytorch", - "scheduler": args.sched, - "train_hash": train_hash, - "val_hash": val_hash, - "pretrained": args.pretrained, - "rotation": args.rotation, - "amp": args.amp, - } + + if rank == 0: + config = { + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": args.optim, + "framework": "pytorch", + "scheduler": args.sched, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "rotation": args.rotation, + "amp": args.amp, + } global global_step global_step = 0 # Shared global step counter @@ -470,58 +523,62 @@ def 
log_at_step(train_loss=None, val_loss=None, lr=None): # Training loop for epoch in range(args.epochs): train_loss, actual_lr = fit_one_epoch( - model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step + model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step, rank=rank ) pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}") - # Validation loop at the end of each epoch - val_loss, recall, precision, mean_iou = evaluate( - model, val_loader, batch_transforms, val_metric, amp=args.amp, log=log_at_step - ) - if val_loss < min_loss: - pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") - min_loss = val_loss - if args.save_interval_epoch: - pbar.write(f"Saving state at epoch: {epoch + 1}") - torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") - log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " - if any(val is None for val in (recall, precision, mean_iou)): - log_msg += "(Undefined metric value, caused by empty GTs or predictions)" - else: - log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})" - pbar.write(log_msg) - # W&B - if args.wb: - wandb.log({ - "train_loss": train_loss, - "val_loss": val_loss, - "learning_rate": actual_lr, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) - - # ClearML - if args.clearml: - from clearml import Logger - logger = Logger.current_logger() - logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) - logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) - logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) - logger.report_scalar(title="Recall", 
series="recall", value=recall, iteration=epoch) - logger.report_scalar(title="Precision", series="precision", value=precision, iteration=epoch) - logger.report_scalar(title="Mean IoU", series="mean_iou", value=mean_iou, iteration=epoch) - - if args.early_stop and early_stopper.early_stop(val_loss): - pbar.write("Training halted early due to reaching patience limit.") - break + if rank == 0: + # Validation loop at the end of each epoch + val_loss, recall, precision, mean_iou = evaluate( + model, val_loader, batch_transforms, val_metric, args, amp=args.amp, log=log_at_step + ) + if val_loss < min_loss: + pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + params = model.module if hasattr(model, "module") else model + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") + min_loss = val_loss + if args.save_interval_epoch: + pbar.write(f"Saving state at epoch: {epoch + 1}") + torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + if any(val is None for val in (recall, precision, mean_iou)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})" + pbar.write(log_msg) + # W&B + if args.wb: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": actual_lr, + "recall": recall, + "precision": precision, + "mean_iou": mean_iou, + }) + + # ClearML + if args.clearml: + from clearml import Logger + + logger = Logger.current_logger() + logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) + logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) + logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) + 
logger.report_scalar(title="Recall", series="recall", value=recall, iteration=epoch) + logger.report_scalar(title="Precision", series="precision", value=precision, iteration=epoch) + logger.report_scalar(title="Mean IoU", series="mean_iou", value=mean_iou, iteration=epoch) + + if args.early_stop and early_stopper.early_stop(val_loss): + pbar.write("Training halted early due to reaching patience limit.") + break - if args.wb: - run.finish() + if rank == 0: + if args.wb: + run.finish() - if args.push_to_hub: - push_to_hf_hub(model, exp_name, task="detection", run_config=args) + if args.push_to_hub: + push_to_hf_hub(model, exp_name, task="detection", run_config=args) def parse_args(): @@ -532,6 +589,14 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + # DDP related args + parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed") + parser.add_argument( + "--device", + default=None, + type=int, + help="Specify gpu device for single-gpu training. 
In distributed setting, this parameter is ignored",
+    )
     parser.add_argument("arch", type=str, help="text-detection model to train")
     parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
     parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
@@ -539,14 +604,13 @@ def parse_args():
     parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
     parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
     parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
-    parser.add_argument("--device", default=None, type=int, help="device")
     parser.add_argument(
         "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
     )
     parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
     parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)")
     parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
-    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
+    parser.add_argument("-j", "--workers", type=int, default=0, help="number of workers used for dataloading")
     parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
     parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
     parser.add_argument(
diff --git a/references/detection/train_pytorch_ddp.py b/references/detection/train_pytorch_ddp.py
deleted file mode 100644
index 641b6790a..000000000
--- a/references/detection/train_pytorch_ddp.py
+++ /dev/null
@@ -1,580 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
- -import os - -os.environ["USE_TORCH"] = "1" - -import datetime -import hashlib -import multiprocessing -import time -from pathlib import Path - -import numpy as np -import torch - -# The following import is required for DDP -import torch.distributed as dist -import torch.multiprocessing as mp -import wandb -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR -from torch.utils.data import DataLoader, SequentialSampler -from torch.utils.data.distributed import DistributedSampler -from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort - -if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): - from tqdm.contrib.slack import tqdm -else: - from tqdm.auto import tqdm - -from doctr import transforms as T -from doctr.datasets import DetectionDataset -from doctr.models import detection, login_to_hub, push_to_hf_hub -from doctr.utils.metrics import LocalizationConfusion -from utils import EarlyStopper, plot_recorder, plot_samples - - -def record_lr( - model: torch.nn.Module, - train_loader: DataLoader, - batch_transforms, - optimizer, - start_lr: float = 1e-7, - end_lr: float = 1, - num_it: int = 100, - amp: bool = False, -): - """Gridsearch the optimal learning rate for the training. 
- Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py - """ - if num_it > len(train_loader): - raise ValueError("the value of `num_it` needs to be lower than the number of available batches") - - model = model.train() - # Update param groups & LR - optimizer.defaults["lr"] = start_lr - for pgroup in optimizer.param_groups: - pgroup["lr"] = start_lr - - gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) - scheduler = MultiplicativeLR(optimizer, lambda step: gamma) - - lr_recorder = [start_lr * gamma**idx for idx in range(num_it)] - loss_recorder = [] - - if amp: - scaler = torch.cuda.amp.GradScaler() - - for batch_idx, (images, targets) in enumerate(train_loader): - if torch.cuda.is_available(): - images = images.cuda() - - images = batch_transforms(images) - - # Forward, Backward & update - optimizer.zero_grad() - if amp: - with torch.cuda.amp.autocast(): - train_loss = model(images, targets)["loss"] - scaler.scale(train_loss).backward() - # Gradient clipping - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) - # Update the params - scaler.step(optimizer) - scaler.update() - else: - train_loss = model(images, targets)["loss"] - train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) - optimizer.step() - # Update LR - scheduler.step() - - # Record - if not torch.isfinite(train_loss): - if batch_idx == 0: - raise ValueError("loss value is NaN or inf.") - else: - break - loss_recorder.append(train_loss.item()) - # Stop after the number of iterations - if batch_idx + 1 == num_it: - break - - return lr_recorder[: len(loss_recorder)], loss_recorder - - -def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False): - if amp: - scaler = torch.cuda.amp.GradScaler() - - model.train() - # Iterate over the batches of the dataset - epoch_train_loss, batch_cnt = 0, 0 - pbar = tqdm(train_loader, dynamic_ncols=True) - for images, targets in pbar: - if 
torch.cuda.is_available(): - images = images.cuda() - images = batch_transforms(images) - - optimizer.zero_grad() - if amp: - with torch.cuda.amp.autocast(): - train_loss = model(images, targets)["loss"] - scaler.scale(train_loss).backward() - # Gradient clipping - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) - # Update the params - scaler.step(optimizer) - scaler.update() - else: - train_loss = model(images, targets)["loss"] - train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) - optimizer.step() - - scheduler.step() - last_lr = scheduler.get_last_lr()[0] - - pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}") - epoch_train_loss += train_loss.item() - batch_cnt += 1 - - epoch_train_loss /= batch_cnt - return epoch_train_loss, last_lr - - -@torch.no_grad() -def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False): - # Model in eval mode - model.eval() - # Reset val metric - val_metric.reset() - # Validation loop - val_loss, batch_cnt = 0, 0 - pbar = tqdm(val_loader, dynamic_ncols=True) - for images, targets in pbar: - if torch.cuda.is_available(): - images = images.cuda() - images = batch_transforms(images) - if amp: - with torch.cuda.amp.autocast(): - out = model(images, targets, return_preds=True) - else: - out = model(images, targets, return_preds=True) - # Compute metric - loc_preds = out["preds"] - for target, loc_pred in zip(targets, loc_preds): - for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()): - if args.rotation and args.eval_straight: - # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 5, 2 (with scores) --> N, 4 - boxes_pred = np.concatenate((boxes_pred[:, :4].min(axis=1), boxes_pred[:, :4].max(axis=1)), axis=-1) - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) - - pbar.set_description(f"Validation loss: {out['loss'].item():.6}") - - val_loss += out["loss"].item() - batch_cnt += 1 - - val_loss /= batch_cnt - 
recall, precision, mean_iou = val_metric.summary() - return val_loss, recall, precision, mean_iou - - -def main(rank: int, world_size: int, args): - """ - Args: - rank (int): device id to put the model on - world_size (int): number of processes participating in the job - args: other arguments passed through the CLI - """ - slack_token = os.getenv("TQDM_SLACK_TOKEN") - slack_channel = os.getenv("TQDM_SLACK_CHANNEL") - - pbar = tqdm(disable=False if slack_token and slack_channel else True) - if slack_token and slack_channel: - # Monkey patch tqdm write method to send messages directly to Slack - pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) - pbar.write(str(args)) - - if rank == 0 and args.push_to_hub: - login_to_hub() - - if not isinstance(args.workers, int): - args.workers = min(16, multiprocessing.cpu_count()) - - torch.backends.cudnn.benchmark = True - - if rank == 0: - # validation dataset related code - st = time.time() - val_set = DetectionDataset( - img_folder=os.path.join(args.val_path, "images"), - label_path=os.path.join(args.val_path, "labels.json"), - sample_transforms=T.SampleCompose( - ( - [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] - if not args.rotation or args.eval_straight - else [] - ) - + ( - [ - T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad - T.RandomApply(T.RandomRotate(90, expand=True), 0.5), - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), - ] - if args.rotation and not args.eval_straight - else [] - ) - ), - use_polygons=args.rotation and not args.eval_straight, - ) - val_loader = DataLoader( - val_set, - batch_size=args.batch_size, - drop_last=False, - num_workers=args.workers, - sampler=SequentialSampler(val_set), - pin_memory=torch.cuda.is_available(), - collate_fn=val_set.collate_fn, - ) - pbar.write( - f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} 
samples in {len(val_loader)} batches)" - ) - with open(os.path.join(args.val_path, "labels.json"), "rb") as f: - val_hash = hashlib.sha256(f.read()).hexdigest() - - class_names = val_set.class_names - else: - class_names = None - - batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)) - - # Load docTR model - model = detection.__dict__[args.arch]( - pretrained=args.pretrained, - assume_straight_pages=not args.rotation, - class_names=class_names, - ) - - # Resume weights - if isinstance(args.resume, str): - pbar.write(f"Resuming {args.resume}") - checkpoint = torch.load(args.resume, map_location="cpu") - model.load_state_dict(checkpoint) - - # create default process group - device = torch.device("cuda", args.devices[rank]) - dist.init_process_group(args.backend, rank=rank, world_size=world_size) - # create local model - model = model.to(device) - # construct the DDP model - model = DDP(model, device_ids=[device]) - - if rank == 0: - # Metrics - val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight) - - if rank == 0 and args.test_only: - pbar.write("Running evaluation") - val_loss, recall, precision, mean_iou = evaluate( - model, val_loader, batch_transforms, val_metric, args, amp=args.amp - ) - pbar.write( - f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " - f"Mean IoU: {mean_iou:.2%})" - ) - return - - st = time.time() - # Augmentations - # Image augmentations - img_transforms = T.OneOf([ - Compose([ - T.RandomApply(T.ColorInversion(), 0.3), - T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2), - ]), - Compose([ - T.RandomApply(T.RandomShadow(), 0.3), - T.RandomApply(T.GaussianNoise(), 0.1), - T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomGrayscale(p=0.15), - ]), - RandomPhotometricDistort(p=0.3), - lambda x: x, # Identity no transformation - ]) - # Image + target augmentations - sample_transforms = T.SampleCompose( - ( - [ - 
T.RandomHorizontalFlip(0.15), - T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), - ]), - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), - ] - if not args.rotation - else [ - T.RandomHorizontalFlip(0.15), - T.OneOf([ - T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), - T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), - ]), - # Rotation augmentation - T.Resize(args.input_size, preserve_aspect_ratio=True), - T.RandomApply(T.RandomRotate(90, expand=True), 0.5), - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), - ] - ) - ) - - # Load both train and val data generators - train_set = DetectionDataset( - img_folder=os.path.join(args.train_path, "images"), - label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=img_transforms, - sample_transforms=sample_transforms, - use_polygons=args.rotation, - ) - - train_loader = DataLoader( - train_set, - batch_size=args.batch_size, - drop_last=True, - num_workers=args.workers, - sampler=DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True), - pin_memory=torch.cuda.is_available(), - collate_fn=train_set.collate_fn, - ) - pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)") - - with open(os.path.join(args.train_path, "labels.json"), "rb") as f: - train_hash = hashlib.sha256(f.read()).hexdigest() - - if rank == 0 and args.show_samples: - x, target = next(iter(train_loader)) - plot_samples(x, target) - # return - - # Backbone freezing - if args.freeze_backbone: - for p in model.feat_extractor.parameters(): - p.requires_grad = False - - # Optimizer - if args.optim == "adam": - optimizer = torch.optim.Adam( - [p for p in model.parameters() if p.requires_grad], - 
args.lr, - betas=(0.95, 0.999), - eps=1e-6, - weight_decay=args.weight_decay, - ) - elif args.optim == "adamw": - optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=args.weight_decay or 1e-4, - ) - - # LR Finder - if args.find_lr: - lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) - plot_recorder(lrs, losses) - return - - # Scheduler - if args.sched == "cosine": - scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) - elif args.sched == "onecycle": - scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) - elif args.sched == "poly": - scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) - - # Training monitoring - current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name - - if rank == 0: - config = { - "learning_rate": args.lr, - "epochs": args.epochs, - "weight_decay": args.weight_decay, - "batch_size": args.batch_size, - "architecture": args.arch, - "input_size": args.input_size, - "optimizer": args.optim, - "framework": "pytorch", - "scheduler": args.sched, - "train_hash": train_hash, - "val_hash": val_hash, - "pretrained": args.pretrained, - "rotation": args.rotation, - "amp": args.amp, - } - - # W&B - if rank == 0 and args.wb: - run = wandb.init( - name=exp_name, - project="text-detection", - config=config, - ) - - # ClearML - if rank == 0 and args.clearml: - from clearml import Task - - task = Task.init(project_name="docTR/text-detection", task_name=exp_name, reuse_last_task_id=False) - task.upload_artifact("config", config) - - # Create loss queue - min_loss = np.inf - if args.early_stop: - early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) - - # Training loop - for epoch in range(args.epochs): - train_loss, actual_lr 
= fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp) - pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}") - - if rank == 0: - # Validation loop at the end of each epoch - val_loss, recall, precision, mean_iou = evaluate( - model, val_loader, batch_transforms, val_metric, args, amp=args.amp - ) - if val_loss < min_loss: - pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") - min_loss = val_loss - if args.save_interval_epoch: - pbar.write(f"Saving state at epoch: {epoch + 1}") - torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") - log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " - if any(val is None for val in (recall, precision, mean_iou)): - log_msg += "(Undefined metric value, caused by empty GTs or predictions)" - else: - log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})" - pbar.write(log_msg) - # W&B - if args.wb: - wandb.log({ - "train_loss": train_loss, - "val_loss": val_loss, - "learning_rate": actual_lr, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) - - # ClearML - if args.clearml: - from clearml import Logger - - logger = Logger.current_logger() - logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) - logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) - logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) - logger.report_scalar(title="Recall", series="recall", value=recall, iteration=epoch) - logger.report_scalar(title="Precision", series="precision", value=precision, iteration=epoch) - logger.report_scalar(title="Mean IoU", series="mean_iou", value=mean_iou, iteration=epoch) - - 
if args.early_stop and early_stopper.early_stop(val_loss): - pbar.write("Training halted early due to reaching patience limit.") - break - - if rank == 0: - if args.wb: - run.finish() - - if args.push_to_hub: - push_to_hf_hub(model, exp_name, task="detection", run_config=args) - - -def parse_args(): - import argparse - - parser = argparse.ArgumentParser( - description="DocTR DDP training script for text detection (PyTorch)", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # DDP related args - parser.add_argument("--backend", default="nccl", type=str, help="backend to use for torch DDP") - - parser.add_argument("arch", type=str, help="text-detection model to train") - parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") - parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") - parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") - parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") - parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") - parser.add_argument("--devices", default=None, nargs="+", type=int, help="GPU devices to use for training") - parser.add_argument( - "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" - ) - parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") - parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") - parser.add_argument("-j", "--workers", type=int, default=0, help="number of workers used for dataloading") - 
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") - parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") - parser.add_argument( - "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" - ) - parser.add_argument( - "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples" - ) - parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases") - parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML") - parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub") - parser.add_argument( - "--pretrained", - dest="pretrained", - action="store_true", - help="Load pretrained parameters before starting the training", - ) - parser.add_argument("--rotation", dest="rotation", action="store_true", help="train with rotated documents") - parser.add_argument( - "--eval-straight", - action="store_true", - help="metrics evaluation with straight boxes instead of polygons to save time + memory", - ) - parser.add_argument("--optim", type=str, default="adam", choices=["adam", "adamw"], help="optimizer to use") - parser.add_argument( - "--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use" - ) - parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") - parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") - parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") - parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping") - parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping") - args = parser.parse_args() - - return args - - -if __name__ == 
"__main__": - args = parse_args() - if not torch.cuda.is_available(): - raise AssertionError("PyTorch cannot access your GPUs. Please investigate!") - - if not isinstance(args.devices, list): - args.devices = list(range(torch.cuda.device_count())) - # no of process per gpu - nprocs = len(args.devices) - # Environment variables which need to be - # set when using c10d's default "env" - # initialization mode. - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - mp.spawn(main, args=(nprocs, args), nprocs=nprocs, join=True) diff --git a/references/recognition/README.md b/references/recognition/README.md index 84d326859..96628e597 100644 --- a/references/recognition/README.md +++ b/references/recognition/README.md @@ -25,27 +25,36 @@ or PyTorch: python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` -### Multi-GPU support (PyTorch only - Experimental) +### Multi-GPU support (PyTorch only) -Multi-GPU support on recognition task with PyTorch has been added. It'll be probably added for other tasks. -Arguments are the same than the ones from single GPU, except: - -- `--devices`: **by default, if you do not pass `--devices`, it will use all GPUs on your computer**. -You can use specific GPUs by passing a list of ids (ex: `0 1 2`). To find them, you can use the following snippet: - -```python -import torch -devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())] -device_names = [torch.cuda.get_device_name(d) for d in devices] -``` +We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. 
Arguments are the same than the ones from single GPU, except: - `--backend`: you can specify another `backend` for `DistribuedDataParallel` if the default one is not available on your operating system. Fastest one is `nccl` according to [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). +#### Key `torchrun` parameters: +- `--nproc_per_node=` + Spawn `` processes on the local machine (typically equal to the number of GPUs you want to use). +- `--nnodes=` + (Optional) Total number of nodes in your job. Default is 1. +- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id` + (Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details. + +#### GPU selection: +By default all visible GPUs will be used. To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2: + ```shell -python references/recognition/train_pytorch_ddp.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --devices 0 1 --backend nccl +CUDA_VISIBLE_DEVICES=0,2 \ +torchrun --nproc_per_node=2 references/recognition/train_pytorch.py \ + crnn_vgg16_bn \ + --train_path path/to/train \ + --val_path path/to/val \ + --epochs 5 \ + --backend nccl ``` + + ## Data format You need to provide both `train_path` and `val_path` arguments to start training. 
diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py index ed18ad10d..d2862d1de 100644 --- a/references/recognition/train_pytorch.py +++ b/references/recognition/train_pytorch.py @@ -10,14 +10,17 @@ import datetime import hashlib import logging -import multiprocessing as mp +import multiprocessing import time from pathlib import Path import numpy as np import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import ( Compose, Normalize, @@ -110,17 +113,17 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder -def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): +def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: scaler = torch.cuda.amp.GradScaler() model.train() # Iterate over the batches of the dataset epoch_train_loss, batch_cnt = 0, 0 - pbar = tqdm(train_loader, dynamic_ncols=True) + pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0)) for images, targets in pbar: if torch.cuda.is_available(): - images = images.cuda() + images = images.to(device) images = batch_transforms(images) optimizer.zero_grad() @@ -144,7 +147,8 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a last_lr = scheduler.get_last_lr()[0] pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}") - log(train_loss=train_loss.item(), lr=last_lr) + if log: + log(train_loss=train_loss.item(), lr=last_lr) epoch_train_loss += train_loss.item() batch_cnt += 1 @@ -154,7 +158,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, 
optimizer, scheduler, a @torch.no_grad() -def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=None): +def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False, log=None): # Model in eval mode model.eval() # Reset val metric @@ -163,8 +167,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non val_loss, batch_cnt = 0, 0 pbar = tqdm(val_loader, dynamic_ncols=True) for images, targets in pbar: - if torch.cuda.is_available(): - images = images.cuda() + images = images.to(device) images = batch_transforms(images) if amp: with torch.cuda.amp.autocast(): @@ -179,7 +182,8 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non val_metric.update(targets, words) pbar.set_description(f"Validation loss: {out['loss'].item():.6}") - log(val_loss=out["loss"].item()) + if log: + log(val_loss=out["loss"].item()) val_loss += out["loss"].item() batch_cnt += 1 @@ -190,88 +194,119 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non def main(args): + # Detect distributed setup + # variable is set by torchrun + world_size = int(os.environ.get("WORLD_SIZE", 1)) + distributed = world_size > 1 + + # GPU setup + if distributed: + rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group(backend=args.backend) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + + else: + # single process + rank = 0 + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + device = torch.device("cuda", args.device) + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + logging.warning("No accessible GPU, target device set to CPU.") + device = torch.device("cpu") + slack_token = os.getenv("TQDM_SLACK_TOKEN") slack_channel = os.getenv("TQDM_SLACK_CHANNEL") - pbar = tqdm(disable=False if slack_token and slack_channel else True) + pbar = tqdm(disable=False if (slack_token and slack_channel) and (rank == 0) else True) if slack_token and slack_channel: # Monkey patch tqdm write method to send messages directly to Slack pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) pbar.write(str(args)) - if args.push_to_hub: + if rank == 0 and args.push_to_hub: login_to_hub() if not isinstance(args.workers, int): - args.workers = min(16, mp.cpu_count()) + args.workers = min(16, multiprocessing.cpu_count()) torch.backends.cudnn.benchmark = True vocab = VOCABS[args.vocab] fonts = args.font.split(",") - # Load val data generator - st = time.time() - if isinstance(args.val_path, str): - with open(os.path.join(args.val_path, "labels.json"), "rb") as f: - val_hash = hashlib.sha256(f.read()).hexdigest() - - val_set = RecognitionDataset( - img_folder=os.path.join(args.val_path, "images"), - labels_path=os.path.join(args.val_path, "labels.json"), - img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), - ) - elif args.val_datasets: - val_hash = None - val_datasets = args.val_datasets - - val_set = datasets.__dict__[val_datasets[0]]( - train=False, - download=True, - recognition_task=True, - use_polygons=True, - img_transforms=Compose([ - T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - ]), + if rank == 0: + # Load 
val data generator + st = time.time() + if isinstance(args.val_path, str): + with open(os.path.join(args.val_path, "labels.json"), "rb") as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + val_set = RecognitionDataset( + img_folder=os.path.join(args.val_path, "images"), + labels_path=os.path.join(args.val_path, "labels.json"), + img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + ) + elif args.val_datasets: + val_hash = None + val_datasets = args.val_datasets + + val_set = datasets.__dict__[val_datasets[0]]( + train=False, + download=True, + recognition_task=True, + use_polygons=True, + img_transforms=Compose([ + T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + ]), + ) + if len(val_datasets) > 1: + for dataset_name in val_datasets[1:]: + _ds = datasets.__dict__[dataset_name]( + train=False, + download=True, + recognition_task=True, + use_polygons=True, + ) + val_set.data.extend((np_img, target) for np_img, target in _ds.data) + else: + val_hash = None + # Load synthetic data generator + val_set = WordGenerator( + vocab=vocab, + min_chars=args.min_chars, + max_chars=args.max_chars, + num_samples=args.val_samples * len(vocab), + font_family=fonts, + img_transforms=Compose([ + T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + # Ensure we have a 90% split of white-background images + T.RandomApply(T.ColorInversion(), 0.9), + ]), + ) + + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, ) - if len(val_datasets) > 1: - for dataset_name in val_datasets[1:]: - _ds = datasets.__dict__[dataset_name]( - train=False, - download=True, - recognition_task=True, - use_polygons=True, - ) - val_set.data.extend((np_img, target) for np_img, 
target in _ds.data) - else: - val_hash = None - # Load synthetic data generator - val_set = WordGenerator( - vocab=vocab, - min_chars=args.min_chars, - max_chars=args.max_chars, - num_samples=args.val_samples * len(vocab), - font_family=fonts, - img_transforms=Compose([ - T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), - # Ensure we have a 90% split of white-background images - T.RandomApply(T.ColorInversion(), 0.9), - ]), + pbar.write( + f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)" ) - val_loader = DataLoader( - val_set, - batch_size=args.batch_size, - drop_last=False, - num_workers=args.workers, - sampler=SequentialSampler(val_set), - pin_memory=torch.cuda.is_available(), - collate_fn=val_set.collate_fn, - ) - pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)") - batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)) # Load doctr model @@ -283,27 +318,28 @@ def main(args): checkpoint = torch.load(args.resume, map_location="cpu") model.load_state_dict(checkpoint) - # GPU - if isinstance(args.device, int): - if not torch.cuda.is_available(): - raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") - if args.device >= torch.cuda.device_count(): - raise ValueError("Invalid device index") - # Silent default switch to GPU if available - elif torch.cuda.is_available(): - args.device = 0 - else: - logging.warning("No accessible GPU, target device set to CPU.") + # Backbone freezing + if args.freeze_backbone: + for p in model.feat_extractor.parameters(): + p.requires_grad = False + if torch.cuda.is_available(): - torch.cuda.set_device(args.device) - model = model.cuda() + torch.cuda.set_device(device) + model = model.to(device) + + if distributed: + # construct DDP model + model = DDP(model, device_ids=[rank]) - # Metrics - val_metric = TextMatch() + if rank == 0: + # Metrics + val_metric = TextMatch() - if args.test_only: + if rank == 0 and args.test_only: pbar.write("Running evaluation") - val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + val_loss, exact_match, partial_match = evaluate( + model, device, val_loader, batch_transforms, val_metric, amp=args.amp + ) pbar.write(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") return @@ -385,28 +421,30 @@ def main(args): RandomPerspective(distortion_scale=0.2, p=0.3), ]), ) + if distributed: + sampler = DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) + else: + sampler = RandomSampler(train_set) train_loader = DataLoader( train_set, batch_size=args.batch_size, drop_last=True, num_workers=args.workers, - sampler=RandomSampler(train_set), + sampler=sampler, pin_memory=torch.cuda.is_available(), collate_fn=train_set.collate_fn, ) - pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)") + if rank == 0: + pbar.write( + f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)" + ) - if args.show_samples: + if rank == 0 and args.show_samples: x, target = 
next(iter(train_loader)) plot_samples(x, target) return - # Backbone freezing - if args.freeze_backbone: - for p in model.feat_extractor.parameters(): - p.requires_grad = False - # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( @@ -425,8 +463,8 @@ def main(args): weight_decay=args.weight_decay or 1e-4, ) - # LR Finder - if args.find_lr: + # LR finder + if rank == 0 and args.find_lr: lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) plot_recorder(lrs, losses) return @@ -443,30 +481,35 @@ def main(args): current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name - config = { - "learning_rate": args.lr, - "epochs": args.epochs, - "weight_decay": args.weight_decay, - "batch_size": args.batch_size, - "architecture": args.arch, - "input_size": args.input_size, - "optimizer": args.optim, - "framework": "pytorch", - "scheduler": args.sched, - "vocab": args.vocab, - "train_hash": train_hash, - "val_hash": val_hash, - "pretrained": args.pretrained, - } + if rank == 0: + config = { + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": args.optim, + "framework": "pytorch", + "scheduler": args.sched, + "vocab": args.vocab, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "amp": args.amp, + } global global_step global_step = 0 # Shared global step counter - # W&B - if args.wb: + if rank == 0 and args.wb: import wandb - run = wandb.init(name=exp_name, project="text-recognition", config=config) + run = wandb.init( + name=exp_name, + project="text-recognition", + config=config, + ) def wandb_log_at_step(train_loss=None, val_loss=None, lr=None): wandb.log({ @@ -476,8 +519,8 @@ def wandb_log_at_step(train_loss=None, val_loss=None, lr=None): }) # ClearML - if 
args.clearml: - from clearml import Logger, Task + if rank == 0 and args.clearml: + from clearml import Task task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False) task.upload_artifact("config", config) @@ -506,7 +549,6 @@ def clearml_log_at_step(train_loss=None, val_loss=None, lr=None): value=lr, ) - # Unified logger def log_at_step(train_loss=None, val_loss=None, lr=None): global global_step if args.wb: @@ -517,57 +559,75 @@ def log_at_step(train_loss=None, val_loss=None, lr=None): # Create loss queue min_loss = np.inf - # Training loop if args.early_stop: early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) + # Training loop for epoch in range(args.epochs): train_loss, actual_lr = fit_one_epoch( - model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step + model, + device, + train_loader, + batch_transforms, + optimizer, + scheduler, + amp=args.amp, + log=log_at_step, + rank=rank, ) - pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}") - # Validation loop at the end of each epoch - val_loss, exact_match, partial_match = evaluate( - model, val_loader, batch_transforms, val_metric, amp=args.amp, log=log_at_step - ) - if val_loss < min_loss: - pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") + if rank == 0: + pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}") + + # Validation loop at the end of each epoch + val_loss, exact_match, partial_match = evaluate( + model, device, val_loader, batch_transforms, val_metric, amp=args.amp, log=log_at_step + ) + if val_loss < min_loss: + # All processes should see same parameters as they all start from same + # random parameters and gradients are synchronized in backward passes. 
+ # Therefore, saving it in one process is sufficient. + pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + params = model.module if hasattr(model, "module") else model + + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss - pbar.write( - f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " - f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})" - ) - # W&B - if args.wb: - wandb.log({ - "train_loss": train_loss, - "val_loss": val_loss, - "learning_rate": actual_lr, - "exact_match": exact_match, - "partial_match": partial_match, - }) - - # ClearML - if args.clearml: - from clearml import Logger + pbar.write( + f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})" + ) + # W&B + if args.wb: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": actual_lr, + "exact_match": exact_match, + "partial_match": partial_match, + }) + + # ClearML + if args.clearml: + from clearml import Logger + + logger = Logger.current_logger() + logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) + logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) + logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) + logger.report_scalar(title="Exact Match", series="exact_match", value=exact_match, iteration=epoch) + logger.report_scalar( + title="Partial Match", series="partial_match", value=partial_match, iteration=epoch + ) - logger = Logger.current_logger() - logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) - logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) - logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) - 
logger.report_scalar(title="Exact Match", series="exact_match", value=exact_match, iteration=epoch) - logger.report_scalar(title="Partial Match", series="partial_match", value=partial_match, iteration=epoch) - - if args.early_stop and early_stopper.early_stop(val_loss): - pbar.write("Training halted early due to reaching patience limit.") - break + if args.early_stop and early_stopper.early_stop(val_loss): + pbar.write("Training halted early due to reaching patience limit.") + break - if args.wb: - run.finish() + if rank == 0: + if args.wb: + run.finish() - if args.push_to_hub: - push_to_hf_hub(model, exp_name, task="recognition", run_config=args) + if args.push_to_hub: + push_to_hf_hub(model, exp_name, task="recognition", run_config=args) def parse_args(): @@ -578,22 +638,13 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + # DDP related args + parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed") + parser.add_argument("arch", type=str, help="text-recognition model to train") parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)") parser.add_argument("--val_path", type=str, default=None, help="path to val data folder") - parser.add_argument( - "--train-samples", - type=int, - default=1000, - help="Multiplied by the vocab length gets you the number of synthetic training samples that will be used.", - ) - parser.add_argument( - "--val-samples", - type=int, - default=20, - help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.", - ) parser.add_argument( "--train_datasets", type=str, @@ -610,6 +661,18 @@ def parse_args(): default=None, help="Built-in datasets to use for validation", ) + parser.add_argument( + "--train-samples", + type=int, + default=1000, + help="Multiplied by the vocab 
length gets you the number of synthetic training samples that will be used.", + ) + parser.add_argument( + "--val-samples", + type=int, + default=20, + help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.", + ) parser.add_argument( "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used" ) @@ -618,8 +681,15 @@ def parse_args(): parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training") - parser.add_argument("--device", default=None, type=int, help="device") + parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H") + parser.add_argument( + "--device", + default=None, + type=int, + help="Specify gpu device for single-gpu training. In destributed setting, this parameter is ignored", + ) + parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") diff --git a/references/recognition/train_pytorch_ddp.py b/references/recognition/train_pytorch_ddp.py deleted file mode 100644 index b5948b3e1..000000000 --- a/references/recognition/train_pytorch_ddp.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright (C) 2021-2025, Mindee. - -# This program is licensed under the Apache License 2.0. -# See LICENSE or go to for full license details. 
import os

os.environ["USE_TORCH"] = "1"

import datetime
import hashlib
import multiprocessing
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import (
    Compose,
    Normalize,
    RandomGrayscale,
    RandomPerspective,
    RandomPhotometricDistort,
)

if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"):
    from tqdm.contrib.slack import tqdm
else:
    from tqdm.auto import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
from doctr.models import login_to_hub, push_to_hf_hub, recognition
from doctr.utils.metrics import TextMatch
from utils import EarlyStopper, plot_samples


def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False):
    """Run a single training epoch over the (rank-local) shard of the train set.

    Args:
        model: DDP-wrapped recognition model; forward returns a dict with a "loss" entry
        device: CUDA device this process trains on
        train_loader: training dataloader (sharded via DistributedSampler)
        batch_transforms: transforms applied to the already-collated image batch
        optimizer: parameter optimizer
        scheduler: per-step learning-rate scheduler
        amp: if True, train with Automatic Mixed Precision (GradScaler + autocast)

    Returns:
        tuple: (mean training loss over the epoch, last learning rate used)
    """
    if amp:
        scaler = torch.cuda.amp.GradScaler()

    model.train()
    # Iterate over the batches of the dataset
    epoch_train_loss, batch_cnt = 0, 0
    pbar = tqdm(train_loader, dynamic_ncols=True)
    for images, targets in pbar:
        images = images.to(device)
        images = batch_transforms(images)

        optimizer.zero_grad()
        if amp:
            with torch.cuda.amp.autocast():
                train_loss = model(images, targets)["loss"]
            scaler.scale(train_loss).backward()
            # Gradient clipping (gradients must be unscaled before clipping)
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            # Update the params
            scaler.step(optimizer)
            scaler.update()
        else:
            train_loss = model(images, targets)["loss"]
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        scheduler.step()
        last_lr = scheduler.get_last_lr()[0]

        pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}")
        epoch_train_loss += train_loss.item()
        batch_cnt += 1

    epoch_train_loss /= batch_cnt
    return epoch_train_loss, last_lr


@torch.no_grad()
def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False):
    """Evaluate the model on the validation set (intended to run on rank 0 only).

    Args:
        model: DDP-wrapped recognition model
        device: CUDA device used for inference
        val_loader: validation dataloader (sequential, not sharded)
        batch_transforms: transforms applied to the already-collated image batch
        val_metric: TextMatch metric accumulator, reset before the loop
        amp: if True, run inference under autocast

    Returns:
        tuple: (mean validation loss, exact-match ratio, unicase-match ratio)
    """
    # Model in eval mode
    model.eval()
    # Reset val metric
    val_metric.reset()
    # Validation loop
    val_loss, batch_cnt = 0, 0
    pbar = tqdm(val_loader, dynamic_ncols=True)
    for images, targets in pbar:
        images = images.to(device)
        images = batch_transforms(images)
        if amp:
            with torch.cuda.amp.autocast():
                out = model(images, targets, return_preds=True)
        else:
            out = model(images, targets, return_preds=True)
        # Compute metric
        if len(out["preds"]):
            words, _ = zip(*out["preds"])
        else:
            words = []
        val_metric.update(targets, words)

        pbar.set_description(f"Validation loss: {out['loss'].item():.6}")

        val_loss += out["loss"].item()
        batch_cnt += 1

    val_loss /= batch_cnt
    result = val_metric.summary()
    return val_loss, result["raw"], result["unicase"]


def main(rank: int, world_size: int, args):
    """Per-process DDP entry point, spawned once per selected GPU.

    Args:
        rank (int): device id to put the model on
        world_size (int): number of processes participating in the job
        args: other arguments passed through the CLI
    """
    slack_token = os.getenv("TQDM_SLACK_TOKEN")
    slack_channel = os.getenv("TQDM_SLACK_CHANNEL")

    pbar = tqdm(disable=False if slack_token and slack_channel else True)
    if slack_token and slack_channel:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg)
    pbar.write(str(args))

    if rank == 0 and args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, multiprocessing.cpu_count())

    torch.backends.cudnn.benchmark = True

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    if rank == 0:
        # Load val data generator
        st = time.time()
        if isinstance(args.val_path, str):
            with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
                val_hash = hashlib.sha256(f.read()).hexdigest()

            val_set = RecognitionDataset(
                img_folder=os.path.join(args.val_path, "images"),
                labels_path=os.path.join(args.val_path, "labels.json"),
                img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
            )
        elif args.val_datasets:
            val_hash = None
            val_datasets = args.val_datasets

            val_set = datasets.__dict__[val_datasets[0]](
                train=False,
                download=True,
                recognition_task=True,
                use_polygons=True,
                img_transforms=Compose([
                    T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                    # Augmentations
                    T.RandomApply(T.ColorInversion(), 0.1),
                ]),
            )
            if len(val_datasets) > 1:
                for dataset_name in val_datasets[1:]:
                    _ds = datasets.__dict__[dataset_name](
                        train=False,
                        download=True,
                        recognition_task=True,
                        use_polygons=True,
                    )
                    val_set.data.extend((np_img, target) for np_img, target in _ds.data)
        else:
            val_hash = None
            # Load synthetic data generator
            val_set = WordGenerator(
                vocab=vocab,
                min_chars=args.min_chars,
                max_chars=args.max_chars,
                num_samples=args.val_samples * len(vocab),
                font_family=fonts,
                img_transforms=Compose([
                    T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                    # Ensure we have a 90% split of white-background images
                    T.RandomApply(T.ColorInversion(), 0.9),
                ]),
            )

        val_loader = DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            num_workers=args.workers,
            sampler=SequentialSampler(val_set),
            pin_memory=torch.cuda.is_available(),
            collate_fn=val_set.collate_fn,
        )
        pbar.write(
            f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

    # Load doctr model
    model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=vocab)

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint)

    # Backbone freezing
    if args.freeze_backbone:
        for p in model.feat_extractor.parameters():
            p.requires_grad = False

    # create default process group
    device = torch.device("cuda", args.devices[rank])
    dist.init_process_group(args.backend, rank=rank, world_size=world_size)
    # create local model
    model = model.to(device)
    # construct DDP model
    model = DDP(model, device_ids=[device])

    if rank == 0:
        # Metrics
        val_metric = TextMatch()

    if args.test_only:
        # FIX: every rank must return here. Previously only rank 0 returned,
        # leaving the other ranks to fall through into training and hang on
        # the first gradient all-reduce.
        if rank == 0:
            pbar.write("Running evaluation")
            val_loss, exact_match, partial_match = evaluate(
                model, device, val_loader, batch_transforms, val_metric, amp=args.amp
            )
            pbar.write(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        dist.destroy_process_group()
        return

    st = time.time()

    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = (
            [base_path]
            if base_path.joinpath("labels.json").is_file()
            else [base_path.joinpath(sub) for sub in os.listdir(base_path)]
        )
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.1),
                RandomGrayscale(p=0.1),
                RandomPhotometricDistort(p=0.1),
                T.RandomApply(T.RandomShadow(), p=0.4),
                T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
                T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
                RandomPerspective(distortion_scale=0.2, p=0.3),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                )
    elif args.train_datasets:
        train_hash = None
        train_datasets = args.train_datasets

        train_set = datasets.__dict__[train_datasets[0]](
            train=True,
            download=True,
            recognition_task=True,
            use_polygons=True,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.1),
            ]),
        )
        if len(train_datasets) > 1:
            for dataset_name in train_datasets[1:]:
                _ds = datasets.__dict__[dataset_name](
                    train=True,
                    download=True,
                    recognition_task=True,
                    use_polygons=True,
                )
                train_set.data.extend((np_img, target) for np_img, target in _ds.data)
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                RandomGrayscale(p=0.1),
                RandomPhotometricDistort(p=0.1),
                T.RandomApply(T.RandomShadow(), p=0.4),
                T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
                T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
                RandomPerspective(distortion_scale=0.2, p=0.3),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        sampler=DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True),
        pin_memory=torch.cuda.is_available(),
        collate_fn=train_set.collate_fn,
    )
    pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # FIX: every rank must return here (see --test-only above); only rank 0 plots.
        if rank == 0:
            x, target = next(iter(train_loader))
            plot_samples(x, target)
        dist.destroy_process_group()
        return

    # Optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            args.lr,
            betas=(0.95, 0.999),
            eps=1e-6,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            args.lr,
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=args.weight_decay or 1e-4,
        )

    # Scheduler
    if args.sched == "cosine":
        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
    elif args.sched == "onecycle":
        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
    elif args.sched == "poly":
        scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    if rank == 0:
        config = {
            "learning_rate": args.lr,
            "epochs": args.epochs,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "architecture": args.arch,
            "input_size": args.input_size,
            "optimizer": args.optim,
            "framework": "pytorch",
            "scheduler": args.sched,
            "train_hash": train_hash,
            "val_hash": val_hash,
            "pretrained": args.pretrained,
            # FIX: parse_args() defines no --rotation flag in this script, so
            # `args.rotation` raised AttributeError on rank 0 (copy-paste from
            # the detection script). Guard with getattr to keep the config key.
            "rotation": getattr(args, "rotation", False),
            "amp": args.amp,
        }

    # W&B
    if rank == 0 and args.wb:
        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config=config,
        )

    # ClearML
    if rank == 0 and args.clearml:
        from clearml import Task

        task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False)
        task.upload_artifact("config", config)

    # Create loss queue
    min_loss = np.inf
    if args.early_stop:
        early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
    # Training loop
    for epoch in range(args.epochs):
        # FIX: without set_epoch, DistributedSampler produces the identical
        # shuffle order every epoch.
        train_loader.sampler.set_epoch(epoch)
        train_loss, actual_lr = fit_one_epoch(
            model, device, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp
        )
        pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}")

        if rank == 0:
            # Validation loop at the end of each epoch
            val_loss, exact_match, partial_match = evaluate(
                model, device, val_loader, batch_transforms, val_metric, amp=args.amp
            )
            if val_loss < min_loss:
                # All processes should see same parameters as they all start from same
                # random parameters and gradients are synchronized in backward passes.
                # Therefore, saving it in one process is sufficient.
                pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
                torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
                min_loss = val_loss
            pbar.write(
                f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
                f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
            )
            # W&B
            if args.wb:
                wandb.log({
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": actual_lr,
                    "exact_match": exact_match,
                    "partial_match": partial_match,
                })

            # ClearML
            if args.clearml:
                from clearml import Logger

                logger = Logger.current_logger()
                logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch)
                logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
                logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch)
                logger.report_scalar(title="Exact Match", series="exact_match", value=exact_match, iteration=epoch)
                logger.report_scalar(
                    title="Partial Match", series="partial_match", value=partial_match, iteration=epoch
                )

        if args.early_stop:
            # FIX: the early-stop decision must be shared by every rank.
            # Previously only rank 0 executed `break`, leaving the remaining
            # ranks blocked forever in the next gradient all-reduce.
            # Rank 0 decides (it is the only rank with val_loss) and the flag
            # is broadcast so all processes leave the loop together.
            should_stop = torch.tensor(
                [int(rank == 0 and early_stopper.early_stop(val_loss))], device=device
            )
            dist.broadcast(should_stop, src=0)
            if should_stop.item():
                if rank == 0:
                    pbar.write("Training halted early due to reaching patience limit.")
                break

    if rank == 0:
        if args.wb:
            run.finish()

        if args.push_to_hub:
            push_to_hf_hub(model, exp_name, task="recognition", run_config=args)

    # FIX: release the process group to match init_process_group above.
    dist.destroy_process_group()


def parse_args():
    """Build and parse the CLI arguments for DDP text-recognition training."""
    import argparse

    parser = argparse.ArgumentParser(
        description="DocTR DDP training script for text recognition (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # DDP related args
    parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP")

    parser.add_argument("arch", type=str, help="text-recognition model to train")
    parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
    parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
    parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
    parser.add_argument(
        "--train_datasets",
        type=str,
        nargs="+",
        choices=["CORD", "FUNSD", "IC03", "IIIT5K", "SVHN", "SVT", "SynthText"],
        default=None,
        help="Built-in datasets to use for training",
    )
    parser.add_argument(
        "--val_datasets",
        type=str,
        nargs="+",
        choices=["CORD", "FUNSD", "IC03", "IIIT5K", "SVHN", "SVT", "SynthText"],
        default=None,
        help="Built-in datasets to use for validation",
    )
    parser.add_argument(
        "--train-samples",
        type=int,
        default=1000,
        help="Multiplied by the vocab length gets you the number of synthetic training samples that will be used.",
    )
    parser.add_argument(
        "--val-samples",
        type=int,
        default=20,
        help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
    )
    parser.add_argument(
        "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
    )
    parser.add_argument("--min-chars", type=int, default=1, help="Minimum number of characters per synthetic sample")
    parser.add_argument("--max-chars", type=int, default=12, help="Maximum number of characters per synthetic sample")
    parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
    parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
    parser.add_argument("--devices", default=None, nargs="+", type=int, help="GPU devices to use for training")
    parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)")
    parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
    parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
    parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
    parser.add_argument(
        "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
    )
    parser.add_argument(
        "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
    )
    parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases")
    parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
    parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub")
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        action="store_true",
        help="Load pretrained parameters before starting the training",
    )
    parser.add_argument("--optim", type=str, default="adam", choices=["adam", "adamw"], help="optimizer to use")
    parser.add_argument(
        "--sched", type=str, default="cosine", choices=["cosine", "onecycle", "poly"], help="scheduler to use"
    )
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
    parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping")
    parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping")
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    if not torch.cuda.is_available():
        raise AssertionError("PyTorch cannot access your GPUs. Please investigate!")
    # Default to all visible GPUs when --devices is not given
    if not isinstance(args.devices, list):
        args.devices = list(range(torch.cuda.device_count()))
    nprocs = len(args.devices)
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    # Spawn one process per selected device: main(rank, world_size, args)
    mp.spawn(main, args=(nprocs, args), nprocs=nprocs, join=True)