diff --git a/README.md b/README.md index 934ce6e..fd1b8e5 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,21 @@ This repository is an implementation of Nodes that interact with the Zauberzeug This node is used to train Yolov5 Models in the Learning Loop. It is based on [this image](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-07.html) running Python 3.10. +## Hyperparameters + +We support all native hyperparameters of YOLOv5 (cf. `hyp_det.yaml` / `hyp_cla.yaml` for reference). +In addition, we support the following hyperparameters: + +- `epochs`: The number of epochs to train the model. +- `detect_nms_conf_thres`: The confidence threshold for the NMS during inference and validation (not used for training). +- `detect_nms_iou_thres`: The IoU threshold for the NMS during inference and validation (not used for training). + +Further, we support the following hyperparameters for point detection: + +- `reset_points`: Whether to reset the size of the points after data augmentation. +- `point_sizes_by_id`: A comma-separated string of `uuid:size` entries (e.g. `"1111-2222-3333-4444:0.03,5555-6666-7777-8888:0.05"`) that maps point category uuids to the size of the points in the output (fractional size 0-1). Point categories without an entry use a size of 0.03. +- `flip_label_pairs`: A comma-separated string of `uuid:uuid` pairs (e.g. `"1111-2222-3333-4444:5555-6666-7777-8888"`) naming the point categories whose labels should be swapped when a horizontal flip is applied during data augmentation. 
+ ## Images Trainer Docker-Images are published on https://hub.docker.com/r/zauberzeug/yolov5-trainer diff --git a/trainer/app_code/yolov5/utils/dataloaders.py b/trainer/app_code/yolov5/utils/dataloaders.py index 90ee686..4de871d 100644 --- a/trainer/app_code/yolov5/utils/dataloaders.py +++ b/trainer/app_code/yolov5/utils/dataloaders.py @@ -29,11 +29,35 @@ from torch.utils.data import DataLoader, Dataset, dataloader, distributed from tqdm import tqdm -from .augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste, - letterbox, mixup, random_perspective) -from .general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements, check_yaml, - clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy, - xyxy2xywhn) +from .augmentations import ( + Albumentations, + augment_hsv, + classify_albumentations, + classify_transforms, + copy_paste, + letterbox, + mixup, + random_perspective, +) +from .general import ( + DATASETS_DIR, + LOGGER, + NUM_THREADS, + TQDM_BAR_FORMAT, + check_dataset, + check_requirements, + check_yaml, + clean_str, + cv2, + is_colab, + is_kaggle, + segments2boxes, + unzip_file, + xyn2xy, + xywh2xyxy, + xywhn2xyxy, + xyxy2xywhn, +) from .torch_utils import torch_distributed_zero_first # Parameters @@ -117,7 +141,8 @@ def create_dataloader(path, quad=False, prefix='', shuffle=False, - point_sizes_by_id: dict[int, int] = dict()): + point_sizes_by_id: dict[int, float] | None = None, + flip_label_pairs: list[tuple[int, int]] | None = None): if rect and shuffle: LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') shuffle = False @@ -135,7 +160,8 @@ def create_dataloader(path, pad=pad, image_weights=image_weights, prefix=prefix, - point_sizes_by_id=point_sizes_by_id) + point_sizes_by_id=point_sizes_by_id, + flip_label_pairs=flip_label_pairs) batch_size = min(batch_size, len(dataset)) nd = 
torch.cuda.device_count() # number of CUDA devices @@ -463,7 +489,8 @@ def __init__(self, pad=0.0, min_items=0, prefix='', - point_sizes_by_id: dict[int, int] = dict()): + point_sizes_by_id: dict[int, float] | None = None, + flip_label_pairs: list[tuple[int, int]] | None = None): self.img_size = img_size self.augment = augment @@ -475,7 +502,8 @@ def __init__(self, self.stride = stride self.path = path self.albumentations = Albumentations(size=img_size) if augment else None - self.point_sizes_by_id = point_sizes_by_id + self.point_sizes_by_id = point_sizes_by_id or {} + self.flip_label_pairs = flip_label_pairs or [] self.prefix = prefix try: @@ -729,6 +757,15 @@ def __getitem__(self, index): img = np.fliplr(img) if nl: labels[:, 1] = 1 - labels[:, 1] + if len(self.flip_label_pairs) > 0: + for label in labels: + for pair in self.flip_label_pairs: + if label[0] == pair[0]: + label[0] = pair[1] + break + if label[0] == pair[1]: + label[0] = pair[0] + break # Cutouts # labels = cutout(img, labels, p=0.5) @@ -753,10 +790,9 @@ def __getitem__(self, index): if reset_points and len(self.point_sizes_by_id) > 0: for label in labels: if label[0] in self.point_sizes_by_id: - target_point_size_px = self.point_sizes_by_id[label[0]] - # target_point_size_px = 10 - label[3] = target_point_size_px / self.img_size - label[4] = target_point_size_px / self.img_size + target_point_size_fraction = self.point_sizes_by_id[label[0]] + label[3] = target_point_size_fraction + label[4] = target_point_size_fraction labels_out = torch.zeros((nl, 6)) if nl: diff --git a/trainer/app_code/yolov5_format.py b/trainer/app_code/yolov5_format.py index 42ca07c..f33b969 100644 --- a/trainer/app_code/yolov5_format.py +++ b/trainer/app_code/yolov5_format.py @@ -5,7 +5,6 @@ from typing import Any from learning_loop_node.data_classes import Training -from learning_loop_node.enums import CategoryType from learning_loop_node.trainer.exceptions import CriticalError from ruamel.yaml import YAML from 
ruamel.yaml.scalarbool import ScalarBoolean @@ -16,17 +15,6 @@ yaml = YAML() -def get_ids_and_sizes_of_point_classes(training: Training) -> tuple[list[str], list[str]]: - """Returns a list of trainingids and sizes (in px) of point classes in the training data.""" - assert training is not None, 'Training should have data' - point_ids, point_sizes = [], [] - for i, category in enumerate(training.categories): - if category.type == CategoryType.Point: - point_ids.append(str(i)) - point_sizes.append(str(category.point_size or 20)) - return point_ids, point_sizes - - def category_lookup_from_training(training: Training) -> dict[str, str]: return {c.name: c.id for c in training.categories} @@ -162,6 +150,8 @@ def create_file_structure(training: Training) -> None: def set_hyperparameters_in_file(yaml_path: str, hyperparameter: dict[str, Any]) -> None: + """Override the hyperparameters in the yaml file used by the yolov5 trainer with the ones from the hyperparameter dict (coming from the loop configuration). 
+ The yaml file is modified in place.""" with open(yaml_path, 'r') as f: content = yaml.load(f) diff --git a/trainer/app_code/yolov5_trainer.py b/trainer/app_code/yolov5_trainer.py index ea2d024..0fe32af 100644 --- a/trainer/app_code/yolov5_trainer.py +++ b/trainer/app_code/yolov5_trainer.py @@ -44,6 +44,8 @@ def __init__(self) -> None: self.epochs = 0 self.detect_nms_conf_thres = 0.2 self.detect_nms_iou_thres = 0.45 + self.point_sizes_by_uuid: dict[str, float] = {} + self.flip_label_uuid_pairs: list[tuple[str, str]] = [] self.additional_hyperparameters_parsed = False @@ -155,7 +157,7 @@ async def _get_latest_model_files(self) -> dict[str, list[str]]: async def _detect(self, model_information: ModelInformation, images: list[str], model_folder: str) -> list[Detections]: - self._set_params_from_hyperparameters() + self._save_additional_hyperparameters() images_folder = '/tmp/imagelinks_for_detecting' shutil.rmtree(images_folder, ignore_errors=True) @@ -228,7 +230,7 @@ async def _start_training_from_model(self, model: str) -> None: async def _start(self, model: str, additional_parameters: str = ''): resolution = self.training.hyperparameters['resolution'] - self._set_params_from_hyperparameters() + self._save_additional_hyperparameters() if self.is_cla: cmd = f'python /app/train_cla.py --exist-ok --img {resolution} \ @@ -236,8 +238,6 @@ async def _start(self, model: str, additional_parameters: str = ''): --project {self.training.training_folder} --name result \ --hyp {self.hyperparameter_path} --optimizer SGD {additional_parameters}' else: - p_ids, p_sizes = yolov5_format.get_ids_and_sizes_of_point_classes(self.training) - # self._try_replace_optimized_hyperparameter() try: batch_size = await batch_size_calculation.calc(self.training.training_folder, model, self.hyperparameter_path, f'{self.training.training_folder}/dataset.yaml', resolution) @@ -245,16 +245,39 @@ async def _start(self, model: str, additional_parameters: str = ''): logging.exception('Error during 
batch size calculation:') raise NodeNeedsRestartError() from e + p_sizes_by_id = "" + for i, category in enumerate(self.training.categories): + if category.type == CategoryType.Point: + size = self.point_sizes_by_uuid.get(category.id, 0.03) + p_sizes_by_id += f"{i}:{size}," + + flip_label_pairs = "" + for uuid_i, uuid_j in self.flip_label_uuid_pairs: + id_i = None + id_j = None + for i, category in enumerate(self.training.categories): + if category.id == uuid_i: + id_i = i + if category.id == uuid_j: + id_j = i + if id_i is not None and id_j is not None: + flip_label_pairs += f"{id_i}:{id_j}," + cmd = f'python /app/train_det.py --exist-ok --patience {self.patience} \ --batch-size {batch_size} --img {resolution} --data dataset.yaml --weights {model} \ --project {self.training.training_folder} --name result --hyp {self.hyperparameter_path} \ --epochs {self.epochs} {additional_parameters}' - if p_ids: - cmd += f' --point_ids {",".join(p_ids)} --point_sizes {",".join(p_sizes)}' + if p_sizes_by_id: + cmd += f' --point_sizes_by_id {p_sizes_by_id[:-1]}' + if flip_label_pairs: + cmd += f' --flip_label_pairs {flip_label_pairs[:-1]}' await self.executor.start(cmd, env={'WANDB_MODE': 'disabled'}) - def _set_params_from_hyperparameters(self) -> None: + def _save_additional_hyperparameters(self) -> None: + """Save additional hyperparameters to attributes of self. + These parameters are not passed to the yolov5 trainer, but are used to modify the training (and inference) process. 
+ """ if self.additional_hyperparameters_parsed: return self.additional_hyperparameters_parsed = True @@ -264,25 +287,27 @@ def _set_params_from_hyperparameters(self) -> None: raise CriticalError(f'No hyperparameter file found at {self.hyperparameter_path}') with open(self.hyperparameter_path, errors='ignore') as f: - hyp = yaml.safe_load(f) # load hyps dict - hyp = {k: float(v) for k, v in hyp.items()} - hyp_str = 'hyps: ' + ', '.join(f'{k}={v}' for k, v in hyp.items()) - logging.info('parsing hyperparameters: %s', hyp_str) + hyp = dict(yaml.safe_load(f)) # load hyps dict self.epochs = int(hyp.get('epochs', self.epochs)) - self.detect_nms_conf_thres = hyp.get('detect_nms_conf_thres', self.detect_nms_conf_thres) - self.detect_nms_iou_thres = hyp.get('detect_nms_iou_thres', self.detect_nms_iou_thres) - - logging.info('parsing done: epochs: %d, detect_nms_conf_thres: %f, detect_nms_iou_thres: %f', - self.epochs, self.detect_nms_conf_thres, self.detect_nms_iou_thres) - - # def _try_replace_optimized_hyperparameter(self): - # optimized_hyp = f'{self.training.project_folder}/yolov5s6_evolved_hyperparameter.yaml' - # if os.path.exists(optimized_hyp): - # logging.info('Found optimized hyperparameter') - # shutil.copy(optimized_hyp, self.hyperparameter_path) - # else: - # logging.warning('No optimized hyperparameter found (!)') + self.detect_nms_conf_thres = float(hyp.get('detect_nms_conf_thres', self.detect_nms_conf_thres)) + self.detect_nms_iou_thres = float(hyp.get('detect_nms_iou_thres', self.detect_nms_iou_thres)) + + if point_sizes_by_id_str := str(hyp.get('point_sizes_by_id', '')): + for item in point_sizes_by_id_str.split(','): + k, v = item.split(':') + self.point_sizes_by_uuid[str(k)] = float(v) + + if flip_label_pairs_str := str(hyp.get('flip_label_pairs', '')): + for item in flip_label_pairs_str.split(','): + k, v = item.split(':') + self.flip_label_uuid_pairs.append((str(k), str(v))) + + hyp_str = ', '.join(f'{k}={v}' for k, v in hyp.items()) + 
logging.info('parsed hyperparameters %s: epochs: %d, detect_nms_conf_thres: %f, detect_nms_iou_thres: %f', + hyp_str, self.epochs, self.detect_nms_conf_thres, self.detect_nms_iou_thres) + logging.info('point_sizes_by_id: %s', self.point_sizes_by_uuid) + logging.info('flip_label_pairs: %s', self.flip_label_uuid_pairs) def _parse(self, labels_path: str, images_folder: str, model_information: ModelInformation) -> list[Detections]: detections = [] diff --git a/trainer/hyp_det.yaml b/trainer/hyp_det.yaml index 14f58b7..99cb762 100644 --- a/trainer/hyp_det.yaml +++ b/trainer/hyp_det.yaml @@ -26,14 +26,18 @@ translate: 0.245 scale: 0.898 shear: 0.602 perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 -flipud: 0 # 0.00856 # image flip up-down (probability) -fliplr: 0 # 0.5 # image flip left-right (probability) +flipud: 0.0 # 0.00856 # image flip up-down (probability) +fliplr: 0.0 # 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.243 # image mixup (probability) copy_paste: 0.0 #own parameters epochs: 2000 -reset_points: false detect_nms_conf_thres: 0.2 detect_nms_iou_thres: 0.45 + +#extra point parameters (not directly used by yolov5 but by the trainer logic) +reset_points: false +point_sizes_by_id: "" # e.g "1111-2222-3333-4444:0.03,5555-6666-7777-8888:0.05" +flip_label_pairs: "" # e.g "1111-2222-3333-4444:5555-6666-7777-8888" diff --git a/trainer/train_det.py b/trainer/train_det.py index f1fc6ad..95b537c 100644 --- a/trainer/train_det.py +++ b/trainer/train_det.py @@ -32,10 +32,6 @@ import torch import torch.nn as nn import yaml -from PIL import Image, ImageDraw, ImageFont -from torch.optim import lr_scheduler -from tqdm import tqdm - from app_code.yolov5 import val as validate # for end-of-epoch mAP from app_code.yolov5.models.experimental import attempt_load from app_code.yolov5.models.yolo import Model @@ -43,25 +39,45 @@ from app_code.yolov5.utils.callbacks import Callbacks from 
app_code.yolov5.utils.dataloaders import create_dataloader from app_code.yolov5.utils.downloads import attempt_download, is_url -from app_code.yolov5.utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, - check_dataset, check_file, - check_img_size, check_suffix, - check_yaml, colorstr, - get_latest_run, increment_path, - init_seeds, intersect_dicts, - labels_to_class_weights, - labels_to_image_weights, methods, - one_cycle, print_args, - print_mutation, strip_optimizer, - yaml_save) +from app_code.yolov5.utils.general import ( + LOGGER, + TQDM_BAR_FORMAT, + check_amp, + check_dataset, + check_file, + check_img_size, + check_suffix, + check_yaml, + colorstr, + get_latest_run, + increment_path, + init_seeds, + intersect_dicts, + labels_to_class_weights, + labels_to_image_weights, + methods, + one_cycle, + print_args, + print_mutation, + strip_optimizer, + yaml_save, +) from app_code.yolov5.utils.loggers import Loggers from app_code.yolov5.utils.loggers.comet.comet_utils import check_comet_resume from app_code.yolov5.utils.loss import ComputeLoss from app_code.yolov5.utils.metrics import fitness from app_code.yolov5.utils.plots import plot_evolve -from app_code.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, - de_parallel, select_device, - smart_optimizer, smart_resume) +from app_code.yolov5.utils.torch_utils import ( + EarlyStopping, + ModelEMA, + de_parallel, + select_device, + smart_optimizer, + smart_resume, +) +from PIL import Image, ImageDraw, ImageFont +from torch.optim import lr_scheduler +from tqdm import tqdm def draw_example(img: torch.Tensor, @@ -165,10 +181,13 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio callbacks.run('on_pretrain_routine_start') # Modification by Zauberzeug - point_ids: List[int] = [int(x) for x in opt.point_ids.split(',')] if opt.point_ids else [] # type: ignore - point_sizes: List[int] = [int(x) for x in opt.point_sizes.split(',')] if opt.point_sizes else [] # type: ignore + # 
point_sizes_by_id: "1:0.2,3:0.3,4:0.02" (mapping from point id to fractional point size if point_sizes should be overridden during augmentation) + # flip_label_pairs: "1:2,3:4" (mapping from category id to category id which should be switched when performing h_flip augmentation) - point_sizes_by_id = dict(zip(point_ids, point_sizes)) + point_sizes_by_id = {int(k): float(v) for k, v in (item.split(':') + for item in str(opt.point_sizes_by_id).split(','))} if opt.point_sizes_by_id else {} + flip_label_pairs = [(int(k), int(v)) for k, v in (item.split(':') + for item in str(opt.flip_label_pairs).split(','))] if opt.flip_label_pairs else [] # Directories examples = save_dir / 'examples' # examples dir @@ -292,7 +311,8 @@ def lf(x): return (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear quad=opt.quad, prefix=colorstr('train: '), shuffle=True, - point_sizes_by_id=point_sizes_by_id) + point_sizes_by_id=point_sizes_by_id, + flip_label_pairs=flip_label_pairs) labels = np.concatenate(dataset.labels, 0) mlc = int(labels[:, 0].max()) # max label class @@ -561,10 +581,12 @@ def parse_opt(known=False): parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)') parser.add_argument('--seed', type=int, default=0, help='Global training seed') parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') - parser.add_argument('--point_ids', type=str, default='', - help='Comma separated list of point training ids as string, e.g. --point_ids 1,3') - parser.add_argument('--point_sizes', type=str, default='', - help='Comma separated list of point sizes as string, e.g. --point_sizes 30, 50') + + # Modification by Zauberzeug + parser.add_argument('--point_sizes_by_id', type=str, default='', + help='Comma separated list of point sizes as string, e.g. 
--point_sizes_by_id "1:0.3,3:0.5"') + parser.add_argument('--flip_label_pairs', type=str, default='', + help='Comma separated list of flip label pairs as string, e.g. --flip_label_pairs "1:2,3:4"') # Logger arguments parser.add_argument('--entity', default=None, help='Entity')