From 22b1f35dc67f8450a4a8688b95edf57863a960aa Mon Sep 17 00:00:00 2001 From: Josua Rieder Date: Mon, 4 Nov 2024 01:35:36 +0100 Subject: [PATCH] Use args.class_map for labels in inference.py --- inference.py | 19 ++++++++++++++----- timm/data/dataset_info.py | 5 +++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/inference.py b/inference.py index 60581978b7..368639e30d 100755 --- a/inference.py +++ b/inference.py @@ -18,6 +18,8 @@ import torch from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset +from timm.data.readers.class_map import load_class_map +from timm.data.dataset_info import get_dataset_info_from_class_map from timm.layers import apply_test_time_pool from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs @@ -241,9 +243,18 @@ def main(): to_label = None if args.label_type in ('name', 'description', 'detail'): - imagenet_subset = infer_imagenet_subset(model) - if imagenet_subset is not None: - dataset_info = ImageNetInfo(imagenet_subset) + dataset_info = None + if args.class_map: + class_map = load_class_map(args.class_map) + dataset_info = get_dataset_info_from_class_map(class_map) + else: + imagenet_subset = infer_imagenet_subset(model) + if imagenet_subset is not None: + dataset_info = ImageNetInfo(imagenet_subset) + else: + _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.") + + if dataset_info: if args.label_type == 'name': to_label = lambda x: dataset_info.index_to_label_name(x) elif args.label_type == 'detail': @@ -251,8 +262,6 @@ def main(): else: to_label = lambda x: dataset_info.index_to_description(x) to_label = np.vectorize(to_label) - else: - _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.") top_k = min(args.topk, args.num_classes) batch_time = AverageMeter() diff --git a/timm/data/dataset_info.py b/timm/data/dataset_info.py index 58e4619618..f8c6897e48 100644 --- a/timm/data/dataset_info.py +++ b/timm/data/dataset_info.py @@ -71,3 +71,8 @@ def index_to_label_name(self, index) -> str: def index_to_description(self, index: int, detailed: bool = False) -> str: label = self.index_to_label_name(index) return self.label_name_to_description(label, detailed=detailed) + + +def get_dataset_info_from_class_map(class_map): + label_names = {v: k for k, v in class_map.items()} + return CustomDatasetInfo(label_names)