Skip to content

Commit b47c313

Browse files
Merge pull request #2139 from VWS-Python/torchmetrics
Introduce torchmetrics in place of piq in one place
2 parents 6bdc0a4 + 5b553b5 commit b47c313

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dependencies = [
4242
"pydantic-settings",
4343
"requests",
4444
"requests-mock",
45+
"torchmetrics",
4546
"torchvision",
4647
"tzdata; sys_platform=='win32'",
4748
"vws-auth-tools",
@@ -84,7 +85,6 @@ optional-dependencies.dev = [
8485
"sphinxcontrib-spelling==8",
8586
"sybil==6.1.1",
8687
"tenacity==8.4.2",
87-
"torch==2.3.1",
8888
"types-docker==7.1.0.20240626",
8989
"types-pillow==10.2.0.20240520",
9090
"types-pyyaml==6.0.12.20240311",

src/mock_vws/image_matchers.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""Matchers for query and duplicate requests."""
22

33
import io
4-
from typing import TYPE_CHECKING, Protocol, runtime_checkable
4+
from typing import Protocol, runtime_checkable
55

6-
import piq # type: ignore[import-untyped]
76
from PIL import Image
7+
from torchmetrics.image import (
8+
StructuralSimilarityIndexMeasure,
9+
)
810
from torchvision.transforms import functional # type: ignore[import-untyped]
911

10-
if TYPE_CHECKING:
11-
import torch
12-
1312

1413
@runtime_checkable
1514
class ImageMatcher(Protocol):
@@ -81,12 +80,10 @@ def __call__(
8180
first_image_tensor_batch_dimension = first_image_tensor.unsqueeze(0)
8281
second_image_tensor_batch_dimension = second_image_tensor.unsqueeze(0)
8382

84-
# See https://github.com/photosynthesis-team/piq/pull/377
85-
# for fixing the type hint in ``piq``.
86-
ssim_value: torch.Tensor = piq.ssim( # pyright: ignore[reportAssignmentType]
87-
x=first_image_tensor_batch_dimension,
88-
y=second_image_tensor_batch_dimension,
89-
data_range=1.0,
83+
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
84+
ssim_value = ssim(
85+
first_image_tensor_batch_dimension,
86+
second_image_tensor_batch_dimension,
9087
)
9188
ssim_score = ssim_value.item()
9289

0 commit comments

Comments
 (0)