diff --git a/src/mock_vws/image_matchers.py b/src/mock_vws/image_matchers.py index ac5fb1452..cfd16bc05 100644 --- a/src/mock_vws/image_matchers.py +++ b/src/mock_vws/image_matchers.py @@ -7,7 +7,7 @@ from torchmetrics.image import ( StructuralSimilarityIndexMeasure, ) -from torchvision.transforms import functional # type: ignore[import-untyped] +from torchvision import transforms # type: ignore[import-untyped] @runtime_checkable @@ -74,8 +74,10 @@ def __call__( first_image = first_image.resize(size=target_size) second_image = second_image.resize(size=target_size) - first_image_tensor = functional.to_tensor(pic=first_image) # pyright: ignore[reportUnknownMemberType] - second_image_tensor = functional.to_tensor(pic=second_image) # pyright: ignore[reportUnknownMemberType] + transform = transforms.ToTensor() + + first_image_tensor = transform(pic=first_image) + second_image_tensor = transform(pic=second_image) first_image_tensor_batch_dimension = first_image_tensor.unsqueeze(0) second_image_tensor_batch_dimension = second_image_tensor.unsqueeze(0) diff --git a/src/mock_vws/target_raters.py b/src/mock_vws/target_raters.py index a5d6f6aa6..7ce255caf 100644 --- a/src/mock_vws/target_raters.py +++ b/src/mock_vws/target_raters.py @@ -8,7 +8,7 @@ import piq # type: ignore[import-untyped] from PIL import Image -from torchvision.transforms import functional # type: ignore[import-untyped] +from torchvision import transforms # type: ignore[import-untyped] @functools.cache @@ -25,7 +25,8 @@ def _get_brisque_target_tracking_rating(image_content: bytes) -> int: """ image_file = io.BytesIO(initial_bytes=image_content) image = Image.open(fp=image_file) - image_tensor = functional.to_tensor(pic=image) * 255 # pyright: ignore[reportUnknownMemberType] + transform = transforms.ToTensor() + image_tensor = transform(pic=image) image_tensor = image_tensor.unsqueeze(0) try: brisque_score = piq.brisque(x=image_tensor, data_range=255)