Skip to content

Commit d7ed4c6

Browse files
Merge pull request #2140 from VWS-Python/rm-pyright-ignore-torchvision
Use torchvision in a different way to avoid type errors
2 parents b47c313 + cdeb0b3 commit d7ed4c6

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/mock_vws/image_matchers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchmetrics.image import (
88
StructuralSimilarityIndexMeasure,
99
)
10-
from torchvision.transforms import functional # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
1111

1212

1313
@runtime_checkable
@@ -74,8 +74,10 @@ def __call__(
7474
first_image = first_image.resize(size=target_size)
7575
second_image = second_image.resize(size=target_size)
7676

77-
first_image_tensor = functional.to_tensor(pic=first_image) # pyright: ignore[reportUnknownMemberType]
78-
second_image_tensor = functional.to_tensor(pic=second_image) # pyright: ignore[reportUnknownMemberType]
77+
transform = transforms.ToTensor()
78+
79+
first_image_tensor = transform(pic=first_image)
80+
second_image_tensor = transform(pic=second_image)
7981

8082
first_image_tensor_batch_dimension = first_image_tensor.unsqueeze(0)
8183
second_image_tensor_batch_dimension = second_image_tensor.unsqueeze(0)

src/mock_vws/target_raters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import piq # type: ignore[import-untyped]
1010
from PIL import Image
11-
from torchvision.transforms import functional # type: ignore[import-untyped]
11+
from torchvision import transforms # type: ignore[import-untyped]
1212

1313

1414
@functools.cache
@@ -25,7 +25,8 @@ def _get_brisque_target_tracking_rating(image_content: bytes) -> int:
2525
"""
2626
image_file = io.BytesIO(initial_bytes=image_content)
2727
image = Image.open(fp=image_file)
28-
image_tensor = functional.to_tensor(pic=image) * 255 # pyright: ignore[reportUnknownMemberType]
28+
transform = transforms.ToTensor()
29+
image_tensor = transform(pic=image)
2930
image_tensor = image_tensor.unsqueeze(0)
3031
try:
3132
brisque_score = piq.brisque(x=image_tensor, data_range=255)

0 commit comments

Comments
 (0)