T2T batching #786

Open · wants to merge 3 commits into base: master
added batching schemes from tensor2tensor
varisd committed Jan 30, 2019
commit 7a623121889ceef2168f491bff18f896aef3f56f
88 changes: 88 additions & 0 deletions neuralmonkey/dataset.py
@@ -95,6 +95,94 @@ def __init__(self,
# pylint: enable=too-few-public-methods


def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
Member: Type annotations are missing.

"""A default set of length-bucket boundaries."""
assert length_bucket_step > 1.0
x = min_length
boundaries = []
while x < max_length:
boundaries.append(x)
x = max(x + 1, int(x * length_bucket_step))
return boundaries
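
The reviewer above asks for type annotations on this helper. A possible annotated signature, inferred from the default values and the list of ints the function returns (a sketch, not part of the committed diff; the typing import is an assumption):

    from typing import List

    def _bucket_boundaries(max_length: int,
                           min_length: int = 8,
                           length_bucket_step: float = 1.1) -> List[int]:
        ...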


def get_batching_scheme(batch_size: int,
                        max_length: int = None,
                        min_length_bucket: int = 8,
                        length_bucket_step: float = 1.1,
                        drop_long_sequences: bool = False,
                        shard_multiplier: int = 1,
                        length_multiplier: int = 1,
                        min_length: int = 0) -> BatchingScheme:
"""A batching scheme based on model hyperparameters.
Every batch contains a number of sequences divisible by `shard_multiplier`.
Args:
batch_size: int, total number of tokens in a batch.
max_length: int, sequences longer than this will be skipped. Defaults to
batch_size.
min_length_bucket: int
length_bucket_step: float greater than 1.0
drop_long_sequences: bool, if True, then sequences longer than
`max_length` are dropped. This prevents generating batches with
more than the usual number of tokens, which can cause out-of-memory
errors.
shard_multiplier: an integer increasing the batch_size to suit splitting
across datashards.
length_multiplier: an integer multiplier that is used to increase the
batch sizes and sequence length tolerance.
min_length: int, sequences shorter than this will be skipped.
Returns:
A dictionary with parameters that can be passed to input_pipeline:
Member: This is not true.
        * boundaries: list of bucket boundaries
        * batch_sizes: list of batch sizes for each length bucket
        * max_length: int, maximum length of an example
    Raises:
        ValueError: If min_length > max_length
    """
    max_length = max_length or batch_size
    if max_length < min_length:
        raise ValueError("max_length must be greater or equal to min_length")

Member: There should be a check here that length_bucket_step is > 1.0, raising a ValueError with a message, instead of leaving it to the assert in the helper function.
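
A minimal sketch of the early check the reviewer suggests, placed before the boundaries are computed (hypothetical, not part of this commit; the error message is made up for illustration):

    if length_bucket_step <= 1.0:
        raise ValueError("length_bucket_step must be greater than 1.0")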

    boundaries = _bucket_boundaries(max_length, min_length_bucket,
                                    length_bucket_step)
    boundaries = [boundary * length_multiplier for boundary in boundaries]
    max_length *= length_multiplier

    batch_sizes = [
        max(1, batch_size // length) for length in boundaries + [max_length]
    ]
    max_batch_size = max(batch_sizes)
    # Since the Datasets API only allows a single constant for window_size,
    # and it needs to divide all bucket_batch_sizes, we pick a highly-composite
    # window size and then round down all batch sizes to divisors of that
    # window size, so that a window can always be divided evenly into batches.
    # TODO(noam): remove this when Dataset API improves.
    highly_composite_numbers = [
        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
        2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
        83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
        720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
        7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
        36756720, 43243200, 61261200, 73513440, 110270160
    ]
    window_size = max(
        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
    batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
    window_size *= shard_multiplier
    batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
    # The Datasets API splits one window into multiple batches, which
    # produces runs of many consecutive batches of the same size. This
    # is bad for training. To solve this, we will shuffle the batches
    # using a queue which must be several times as large as the maximum
    # number of batches per window.
    max_batches_per_window = window_size // min(batch_sizes)
    shuffle_queue_size = max_batches_per_window * 3

    ret = BatchingScheme(bucket_boundaries=boundaries,
                         bucket_batch_sizes=batch_sizes)
    return ret
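
For illustration, a minimal sketch of calling the new function (the concrete numbers are assumptions for the example, and it is assumed that BatchingScheme exposes the two fields it is constructed with above):

    scheme = get_batching_scheme(batch_size=4096, max_length=256)
    # scheme.bucket_boundaries holds the length-bucket boundaries
    # (starting at min_length_bucket=8), and scheme.bucket_batch_sizes
    # holds a per-bucket batch size (in sentences) chosen so that each
    # batch stays near the 4096-token budget.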

# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.