Skip to content

Commit 2adfeb4

Browse files
committed
Adding explicit channels and shape parameters to image and numpy readers
1 parent 8585567 commit 2adfeb4

File tree

5 files changed

+31
-11
lines changed

5 files changed

+31
-11
lines changed

neuralmonkey/readers/image_reader.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
from typing import Callable, Iterable, List
22
import os
3-
from typeguard import check_argument_types
3+
44
import numpy as np
5+
from typeguard import check_argument_types
56
from PIL import Image, ImageFile
7+
8+
from neuralmonkey.logging import warn
9+
10+
611
ImageFile.LOAD_TRUNCATED_IMAGES = True
712

813

914
def image_reader(pad_w: int,
1015
pad_h: int,
16+
channels: int = 3,
1117
prefix: str = "",
1218
rescale_w: bool = False,
1319
rescale_h: bool = False,
@@ -17,7 +23,8 @@ def image_reader(pad_w: int,
1723
1824
Args:
1925
pad_w: Width to which the images will be padded/cropped/resized.
20-
pad_h: Height to with the images will be padded/corpped/resized.
26+
pad_h: Height to which the images will be padded/cropped/resized.
27+
channels: Number of channels in each image (default 3 for RGB)
2128
prefix: Prefix of the paths that are listed in the image files.
2229
rescale_w: If true, image is rescaled to have given width. It is
2330
cropped/padded otherwise.
@@ -57,6 +64,8 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
5764
try:
5865
image = Image.open(path).convert(mode)
5966
except IOError:
67+
warn("Skipping image from file '{}' no. '{}'.".format(
68+
path, i + 1))
6069
image = Image.new(mode, (pad_w, pad_h))
6170

6271
image = _rescale_or_crop(image, pad_w, pad_h,
@@ -65,16 +74,22 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
6574
image_np = np.array(image)
6675

6776
if len(image_np.shape) == 2:
68-
channels = 1
77+
img_channels = 1
6978
image_np = np.expand_dims(image_np, 2)
7079
elif len(image_np.shape) == 3:
71-
channels = image_np.shape[2]
80+
img_channels = image_np.shape[2]
7281
else:
7382
raise ValueError(
7483
("Image should have either 2 (black and white) "
7584
"or three dimensions (color channels), has {} "
7685
"dimension.").format(len(image_np.shape)))
7786

87+
if channels != img_channels:
88+
raise ValueError(
89+
"Image does not have the pre-declared number of "
90+
"channels {}, but {}.".format(
91+
channels, img_channels))
92+
7893
yield _pad(image_np, pad_w, pad_h, channels)
7994

8095
return load

neuralmonkey/readers/numpy_reader.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ def single_tensor(files: List[str]) -> np.ndarray:
1515

1616

1717
def from_file_list(prefix: str,
18+
shape: List[int],
1819
suffix: str = "",
1920
default_tensor_name: str = "arr_0") -> Callable:
2021
"""Load a list of numpy arrays from a list of .npz numpy files.
2122
2223
Args:
2324
prefix: A common prefix for the files in the list.
25+
shape: The shape of the numpy arrays stored in the referenced files.
2426
suffix: An optional suffix that will be appended to each path
2527
default_tensor_name: Key of the tensors to load from the npz files.
2628
@@ -35,10 +37,11 @@ def load(files: List[str]) -> Iterable[np.ndarray]:
3537
for line in f_list:
3638
path = os.path.join(prefix, line.rstrip()) + suffix
3739
with np.load(path) as npz:
38-
yield npz[default_tensor_name]
39-
40+
arr = npz[default_tensor_name]
41+
arr_shape = list(arr.shape)
42+
if arr_shape != shape:
43+
raise ValueError(
44+
"Shapes do not match: expected {}, found {}"
45+
.format(shape, arr_shape))
46+
yield arr
4047
return load
41-
42-
43-
# pylint: disable=invalid-name
44-
numpy_file_list_reader = from_file_list(prefix="")

tests/flat-multiattention.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ num_sessions=1
2424
[numpy_reader]
2525
class=readers.numpy_reader.from_file_list
2626
prefix="tests/data/flickr30k"
27-
# shape=[8, 8, 2048]
27+
shape=[8, 8, 2048]
2828

2929
[train_data]
3030
class=dataset.load

tests/hier-multiattention.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ prefix="tests/data/flickr30k"
3232
pad_h=32
3333
pad_w=32
3434
mode="RGB"
35+
channels=3
3536

3637
[train_data]
3738
class=dataset.load

tests/str.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pad_w=256
2929
rescale_w=True
3030
rescale_h=True
3131
mode="F"
32+
channels=1
3233

3334
[train_data]
3435
class=dataset.load

0 commit comments

Comments
 (0)