From a1b7406e52724066ba51c298f961264167590256 Mon Sep 17 00:00:00 2001 From: Adam Dangoor Date: Thu, 27 Jun 2024 15:51:12 +0100 Subject: [PATCH 1/3] Use torchvision in a different way to avoid type errors --- src/mock_vws/image_matchers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mock_vws/image_matchers.py b/src/mock_vws/image_matchers.py index ac5fb1452..cfd16bc05 100644 --- a/src/mock_vws/image_matchers.py +++ b/src/mock_vws/image_matchers.py @@ -7,7 +7,7 @@ from torchmetrics.image import ( StructuralSimilarityIndexMeasure, ) -from torchvision.transforms import functional # type: ignore[import-untyped] +from torchvision import transforms # type: ignore[import-untyped] @runtime_checkable @@ -74,8 +74,10 @@ def __call__( first_image = first_image.resize(size=target_size) second_image = second_image.resize(size=target_size) - first_image_tensor = functional.to_tensor(pic=first_image) # pyright: ignore[reportUnknownMemberType] - second_image_tensor = functional.to_tensor(pic=second_image) # pyright: ignore[reportUnknownMemberType] + transform = transforms.ToTensor() + + first_image_tensor = transform(pic=first_image) + second_image_tensor = transform(pic=second_image) first_image_tensor_batch_dimension = first_image_tensor.unsqueeze(0) second_image_tensor_batch_dimension = second_image_tensor.unsqueeze(0) From 1fa22b79277c8c33fbb541ec86c81d4b996f0eb6 Mon Sep 17 00:00:00 2001 From: Adam Dangoor Date: Thu, 27 Jun 2024 15:54:25 +0100 Subject: [PATCH 2/3] Use new style in a new place --- src/mock_vws/target_raters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mock_vws/target_raters.py b/src/mock_vws/target_raters.py index a5d6f6aa6..719f2bd6b 100644 --- a/src/mock_vws/target_raters.py +++ b/src/mock_vws/target_raters.py @@ -8,7 +8,7 @@ import piq # type: ignore[import-untyped] from PIL import Image -from torchvision.transforms import functional # type: ignore[import-untyped] +from torchvision import transforms @functools.cache @@ -25,7 +25,8 @@ def _get_brisque_target_tracking_rating(image_content: bytes) -> int: """ image_file = io.BytesIO(initial_bytes=image_content) image = Image.open(fp=image_file) - image_tensor = functional.to_tensor(pic=image) * 255 # pyright: ignore[reportUnknownMemberType] + transform = transforms.ToTensor() + image_tensor = transform(pic=image) image_tensor = image_tensor.unsqueeze(0) try: brisque_score = piq.brisque(x=image_tensor, data_range=255) From cdeb0b38cb3abcd8a3110569bc281a5bc4a5440e Mon Sep 17 00:00:00 2001 From: Adam Dangoor Date: Thu, 27 Jun 2024 15:55:32 +0100 Subject: [PATCH 3/3] Add back a type: ignore --- src/mock_vws/target_raters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mock_vws/target_raters.py b/src/mock_vws/target_raters.py index 719f2bd6b..7ce255caf 100644 --- a/src/mock_vws/target_raters.py +++ b/src/mock_vws/target_raters.py @@ -8,7 +8,7 @@ import piq # type: ignore[import-untyped] from PIL import Image -from torchvision import transforms +from torchvision import transforms # type: ignore[import-untyped] @functools.cache