diff --git a/pyproject.toml b/pyproject.toml index 0840bbb3c..34c5d35b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "pydantic-settings", "requests", "requests-mock", + "torchmetrics", "torchvision", "tzdata; sys_platform=='win32'", "vws-auth-tools", @@ -84,7 +85,6 @@ optional-dependencies.dev = [ "sphinxcontrib-spelling==8", "sybil==6.1.1", "tenacity==8.4.2", - "torch==2.3.1", "types-docker==7.1.0.20240626", "types-pillow==10.2.0.20240520", "types-pyyaml==6.0.12.20240311", diff --git a/src/mock_vws/image_matchers.py b/src/mock_vws/image_matchers.py index 93fd9b1fa..ac5fb1452 100644 --- a/src/mock_vws/image_matchers.py +++ b/src/mock_vws/image_matchers.py @@ -1,15 +1,14 @@ """Matchers for query and duplicate requests.""" import io -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable -import piq # type: ignore[import-untyped] from PIL import Image +from torchmetrics.image import ( + StructuralSimilarityIndexMeasure, +) from torchvision.transforms import functional # type: ignore[import-untyped] -if TYPE_CHECKING: - import torch - @runtime_checkable class ImageMatcher(Protocol): @@ -81,12 +80,10 @@ def __call__( first_image_tensor_batch_dimension = first_image_tensor.unsqueeze(0) second_image_tensor_batch_dimension = second_image_tensor.unsqueeze(0) - # See https://github.com/photosynthesis-team/piq/pull/377 - # for fixing the type hint in ``piq``. - ssim_value: torch.Tensor = piq.ssim( # pyright: ignore[reportAssignmentType] - x=first_image_tensor_batch_dimension, - y=second_image_tensor_batch_dimension, - data_range=1.0, + ssim = StructuralSimilarityIndexMeasure(data_range=1.0) + ssim_value = ssim( + first_image_tensor_batch_dimension, + second_image_tensor_batch_dimension, ) ssim_score = ssim_value.item()