
Commit 7a62312

added batching schemes from tensor2tensor
1 parent 299c1bc commit 7a62312

File tree

1 file changed (+88 −0)


neuralmonkey/dataset.py

Lines changed: 88 additions & 0 deletions
@@ -95,6 +95,94 @@ def __init__(self,
# pylint: enable=too-few-public-methods


def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
    """A default set of length-bucket boundaries."""
    assert length_bucket_step > 1.0
    x = min_length
    boundaries = []
    while x < max_length:
        boundaries.append(x)
        # Advance by at least 1 so the loop terminates even while
        # int(x * length_bucket_step) == x for small x.
        x = max(x + 1, int(x * length_bucket_step))
    return boundaries
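For intuition, here is a quick check of the geometric spacing these defaults produce (a sketch; the printed lists were traced by hand from the loop above):

# With the default length_bucket_step of 1.1, boundaries grow by roughly
# 10% per bucket, but never by less than one token:
print(_bucket_boundaries(16))
# [8, 9, 10, 11, 12, 13, 14, 15]
print(_bucket_boundaries(32, min_length=8, length_bucket_step=1.5))
# [8, 12, 18, 27]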
def get_batching_scheme(batch_size: int,
                        max_length: int = None,
                        min_length_bucket: int = 8,
                        length_bucket_step: float = 1.1,
                        drop_long_sequences: bool = False,
                        shard_multiplier: int = 1,
                        length_multiplier: int = 1,
                        min_length: int = 0) -> BatchingScheme:
    """A batching scheme based on model hyperparameters.

    Every batch contains a number of sequences divisible by
    `shard_multiplier`.

    Args:
        batch_size: int, total number of tokens in a batch.
        max_length: int, sequences longer than this will be skipped.
            Defaults to batch_size.
        min_length_bucket: int, the lower boundary of the first length
            bucket.
        length_bucket_step: float greater than 1.0, the ratio by which
            successive bucket boundaries grow.
        drop_long_sequences: bool, if True, then sequences longer than
            `max_length` are dropped. This prevents generating batches with
            more than the usual number of tokens, which can cause
            out-of-memory errors.
        shard_multiplier: an integer increasing the batch_size to suit
            splitting across datashards.
        length_multiplier: an integer multiplier that is used to increase
            the batch sizes and sequence length tolerance.
        min_length: int, sequences shorter than this will be skipped.

    Returns:
        A BatchingScheme with parameters that can be passed to the input
        pipeline:
        * bucket_boundaries: list of bucket boundaries
        * bucket_batch_sizes: list of batch sizes for each length bucket

    Raises:
        ValueError: If min_length > max_length.
    """
    max_length = max_length or batch_size
    if max_length < min_length:
        raise ValueError("max_length must be greater than or equal to "
                         "min_length")

    boundaries = _bucket_boundaries(max_length, min_length_bucket,
                                    length_bucket_step)
    boundaries = [boundary * length_multiplier for boundary in boundaries]
    max_length *= length_multiplier

    # Each bucket gets as many sequences as fit into the token budget.
    batch_sizes = [
        max(1, batch_size // length) for length in boundaries + [max_length]
    ]
    max_batch_size = max(batch_sizes)
    # Since the Datasets API only allows a single constant for window_size,
    # and it needs to divide all bucket_batch_sizes, we pick a highly
    # composite window size and then round down all batch sizes to divisors
    # of that window size, so that a window can always be divided evenly
    # into batches.
    # TODO(noam): remove this when Dataset API improves.
    highly_composite_numbers = [
        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
        1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
        50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
        554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
        4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
        21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
    ]
    window_size = max(
        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
    batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
    window_size *= shard_multiplier
    batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
    # The Datasets API splits one window into multiple batches, which
    # produces runs of many consecutive batches of the same size. This
    # is bad for training. To solve this, we will shuffle the batches
    # using a queue which must be several times as large as the maximum
    # number of batches per window.
    max_batches_per_window = window_size // min(batch_sizes)
    shuffle_queue_size = max_batches_per_window * 3

    return BatchingScheme(bucket_boundaries=boundaries,
                          bucket_batch_sizes=batch_sizes)
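A minimal usage sketch, assuming the BatchingScheme built above exposes its bucket_boundaries and bucket_batch_sizes constructor arguments as attributes: with a fixed token budget per batch, short sentences travel in large batches and long sentences in small ones.

scheme = get_batching_scheme(batch_size=4096, max_length=256)
for boundary, size in zip(scheme.bucket_boundaries,
                          scheme.bucket_batch_sizes):
    print("<= {:3d} tokens -> {:3d} sequences per batch".format(
        boundary, size))
# The shortest bucket (<= 8 tokens) ends up with batches of 420 sequences:
# 4096 // 8 = 512, rounded down to the largest divisor of 1260, which is
# the biggest highly composite number not exceeding 3 * max_batch_size.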

# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.
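The divisor-rounding trick in the middle of get_batching_scheme is easy to verify in isolation. A standalone sketch, reusing the numbers from the example above:

window_size = 1260  # highly composite: 2**2 * 3**2 * 5 * 7
divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
for raw in (512, 455, 409):  # token-budget batch sizes before rounding
    rounded = max(d for d in divisors if d <= raw)
    assert window_size % rounded == 0  # a window splits into whole batches
    print(raw, "->", rounded)
# 512 -> 420, 455 -> 420, 409 -> 315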

0 commit comments
