Commit a97affc

fixing failed travis tests
1 parent 7a62312 commit a97affc

File tree

2 files changed: +21 -31 lines


neuralmonkey/dataset.py

+20 -30

@@ -96,7 +96,7 @@ def __init__(self,
 
 
 def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
-    """A default set of length-bucket boundaries."""
+    """Create a default set of length-bucket boundaries."""
     assert length_bucket_step > 1.0
     x = min_length
     boundaries = []
@@ -110,28 +110,25 @@ def get_batching_scheme(batch_size: int,
                          max_length: int = None,
                          min_length_bucket: int = 8,
                          length_bucket_step: float = 1.1,
-                         drop_long_sequences: bool = False,
                          shard_multiplier: int = 1,
                          length_multiplier: int = 1,
                          min_length: int = 0) -> BatchingScheme:
-    """A batching scheme based on model hyperparameters.
+    """Create a batching scheme based on model hyperparameters.
+
     Every batch contains a number of sequences divisible by `shard_multiplier`.
+
     Args:
         batch_size: int, total number of tokens in a batch.
-        max_length: int, sequences longer than this will be skipped. Defaults to
-            batch_size.
+        max_length: int, sequences longer than this will be skipped. Defaults
+            to batch_size.
         min_length_bucket: int
         length_bucket_step: float greater than 1.0
-        drop_long_sequences: bool, if True, then sequences longer than
-            `max_length` are dropped. This prevents generating batches with
-            more than the usual number of tokens, which can cause out-of-memory
-            errors.
-        shard_multiplier: an integer increasing the batch_size to suit splitting
-            across datashards.
+        shard_multiplier: an integer increasing the batch_size to suit
+            splitting across datashards.
         length_multiplier: an integer multiplier that is used to increase the
             batch sizes and sequence length tolerance.
        min_length: int, sequences shorter than this will be skipped.
-    Returns:
+    Return:
        A dictionary with parameters that can be passed to input_pipeline:
            * boundaries: list of bucket boundaries
            * batch_sizes: list of batch sizes for each length bucket
@@ -149,40 +146,33 @@ def get_batching_scheme(batch_size: int,
     max_length *= length_multiplier
 
     batch_sizes = [
-      max(1, batch_size // length) for length in boundaries + [max_length]
+        max(1, batch_size // length) for length in boundaries + [max_length]
     ]
     max_batch_size = max(batch_sizes)
     # Since the Datasets API only allows a single constant for window_size,
     # and it needs divide all bucket_batch_sizes, we pick a highly-composite
-    # window size and then round down all batch sizes to divisors of that window
-    # size, so that a window can always be divided evenly into batches.
-    # TODO(noam): remove this when Dataset API improves.
+    # window size and then round down all batch sizes to divisors of that
+    # window size, so that a window can always be divided evenly into batches.
     highly_composite_numbers = [
-      1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
-      2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
-      83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
-      720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
-      7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
-      36756720, 43243200, 61261200, 73513440, 110270160
+        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
+        1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
+        50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
+        554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
+        4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
+        21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
     ]
     window_size = max(
-      [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
     divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
     batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
     window_size *= shard_multiplier
     batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
-    # The Datasets API splits one window into multiple batches, which
-    # produces runs of many consecutive batches of the same size. This
-    # is bad for training. To solve this, we will shuffle the batches
-    # using a queue which must be several times as large as the maximum
-    # number of batches per window.
-    max_batches_per_window = window_size // min(batch_sizes)
-    shuffle_queue_size = max_batches_per_window * 3
 
     ret = BatchingScheme(bucket_boundaries=boundaries,
                          bucket_batch_sizes=batch_sizes)
     return ret
 
+
 # The protected functions below are designed to convert the ambiguous spec
 # structures to a normalized form.
 
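The two functions touched above implement length bucketing for batch construction: `_bucket_boundaries` grows boundaries geometrically from `min_length` by `length_bucket_step`, and `get_batching_scheme` gives each bucket roughly `batch_size // boundary` sequences per batch, rounded down to a divisor of a highly composite window size so that a window always splits evenly into batches. Below is a minimal, self-contained sketch of that logic. It reuses the names from the diff, but the loop body of `_bucket_boundaries` is not part of this hunk (it follows the tensor2tensor code the module appears to be adapted from), and `toy_batching_scheme` with its truncated number table is illustrative only, not the module's API.

# Minimal sketch of the bucketing logic shown in the diff above; not the
# neuralmonkey module itself.

def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
    """Create a default set of length-bucket boundaries."""
    assert length_bucket_step > 1.0
    x = min_length
    boundaries = []
    while x < max_length:
        boundaries.append(x)
        # Grow boundaries geometrically, advancing by at least one token.
        x = max(x + 1, int(x * length_bucket_step))
    return boundaries


def toy_batching_scheme(batch_size, max_length):
    """Illustrate how per-bucket batch sizes follow from a token budget."""
    boundaries = _bucket_boundaries(max_length)
    # Each batch should hold roughly `batch_size` tokens, so buckets with
    # longer sequences get proportionally fewer of them.
    batch_sizes = [max(1, batch_size // length)
                   for length in boundaries + [max_length]]
    # Truncated copy of the highly-composite-number table from the diff: the
    # chosen window size must be divisible by every bucket batch size, so each
    # batch size is rounded down to a divisor of that window.
    highly_composite_numbers = [1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180,
                                240, 360, 720, 840, 1260, 1680, 2520]
    max_batch_size = max(batch_sizes)
    window_size = max(i for i in highly_composite_numbers
                      if i <= 3 * max_batch_size)
    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
    batch_sizes = [max(d for d in divisors if d <= bs) for bs in batch_sizes]
    return boundaries, batch_sizes


if __name__ == "__main__":
    bounds, sizes = toy_batching_scheme(batch_size=4096, max_length=50)
    print(bounds)  # [8, 9, 10, 11, ..., 42, 46]
    print(sizes)   # shortest buckets get the most sequences per batch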

neuralmonkey/readers/string_vector_reader.py

+1 -1

@@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:
 
         return np.array(numbers, dtype=dtype)
 
-    def reader(files: List[str])-> Iterable[List[np.ndarray]]:
+    def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
         for path in files:
             current_line = 0
 
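The change above only adds the missing space before the return annotation, but the signature documents the reader contract: given a list of file paths, produce NumPy vectors parsed from whitespace-separated numbers. A standalone sketch with the corrected signature follows; the module's per-line validation and grouping are not visible in this hunk, so yielding one list of vectors per file is an assumption made here for illustration.

# Illustrative sketch only: a reader matching the corrected signature.
from typing import Iterable, List

import numpy as np


def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
    for path in files:
        with open(path, encoding="utf-8") as handle:
            # One whitespace-separated numeric vector per non-empty line.
            yield [np.array(line.split(), dtype=np.float32)
                   for line in handle if line.strip()]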