@@ -96,7 +96,7 @@ def __init__(self,
 
 
 def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
-  """A default set of length-bucket boundaries."""
+  """Create a default set of length-bucket boundaries."""
   assert length_bucket_step > 1.0
   x = min_length
   boundaries = []
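For reference, a minimal sketch of the whole helper, assuming a loop body that falls outside this hunk (append the current boundary, then advance by `max(x + 1, int(x * length_bucket_step))`); the hypothetical name `_bucket_boundaries_sketch` marks it as an illustration, not the patched code:

def _bucket_boundaries_sketch(max_length, min_length=8, length_bucket_step=1.1):
  # Only the lines shown in the hunk above are confirmed by this diff.
  assert length_bucket_step > 1.0
  x = min_length
  boundaries = []
  while x < max_length:
    boundaries.append(x)
    # Advance geometrically, but always by at least 1 token.
    x = max(x + 1, int(x * length_bucket_step))
  return boundaries

# With the defaults, steps stay at +1 until int(x * 1.1) overtakes x + 1:
# _bucket_boundaries_sketch(32) -> [8, 9, 10, ..., 19, 20, 22, 24, 26, 28, 30]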
@@ -110,28 +110,25 @@ def get_batching_scheme(batch_size: int,
                         max_length: int = None,
                         min_length_bucket: int = 8,
                         length_bucket_step: float = 1.1,
-                        drop_long_sequences: bool = False,
                         shard_multiplier: int = 1,
                         length_multiplier: int = 1,
                         min_length: int = 0) -> BatchingScheme:
-  """A batching scheme based on model hyperparameters.
+  """Create a batching scheme based on model hyperparameters.
+
   Every batch contains a number of sequences divisible by `shard_multiplier`.
+
   Args:
     batch_size: int, total number of tokens in a batch.
-    max_length: int, sequences longer than this will be skipped. Defaults to
-      batch_size.
+    max_length: int, sequences longer than this will be skipped. Defaults
+      to batch_size.
     min_length_bucket: int
     length_bucket_step: float greater than 1.0
-    drop_long_sequences: bool, if True, then sequences longer than
-      `max_length` are dropped. This prevents generating batches with
-      more than the usual number of tokens, which can cause out-of-memory
-      errors.
-    shard_multiplier: an integer increasing the batch_size to suit splitting
-      across datashards.
+    shard_multiplier: an integer increasing the batch_size to suit
+      splitting across datashards.
     length_multiplier: an integer multiplier that is used to increase the
       batch sizes and sequence length tolerance.
     min_length: int, sequences shorter than this will be skipped.
   Returns:
     A dictionary with parameters that can be passed to input_pipeline:
       * boundaries: list of bucket boundaries
       * batch_sizes: list of batch sizes for each length bucket
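As a usage sketch (the argument values here are hypothetical, not from the patch), the function turns a token budget into per-bucket sequence counts; the field names match the keyword arguments passed to `BatchingScheme` later in this diff, assuming the scheme exposes them as attributes:

scheme = get_batching_scheme(batch_size=4096, max_length=256)
# Longer buckets get proportionally smaller batch sizes, so each batch
# holds roughly batch_size tokens regardless of sequence length:
#   scheme.bucket_boundaries  -> e.g. [8, 9, 10, ..., 236]
#   scheme.bucket_batch_sizes -> one size per bucket, largest for the
#                                shortest bucket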
@@ -149,40 +146,33 @@ def get_batching_scheme(batch_size: int,
   max_length *= length_multiplier
 
   batch_sizes = [
-      max(1, batch_size // length) for length in boundaries + [max_length]
+    max(1, batch_size // length) for length in boundaries + [max_length]
   ]
   max_batch_size = max(batch_sizes)
   # Since the Datasets API only allows a single constant for window_size,
   # and it needs to divide all bucket_batch_sizes, we pick a highly-composite
-  # window size and then round down all batch sizes to divisors of that window
-  # size, so that a window can always be divided evenly into batches.
-  # TODO(noam): remove this when Dataset API improves.
+  # window size and then round down all batch sizes to divisors of that
+  # window size, so that a window can always be divided evenly into batches.
   highly_composite_numbers = [
-      1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
-      2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
-      83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
-      720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
-      7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
-      36756720, 43243200, 61261200, 73513440, 110270160
+    1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
+    1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
+    50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
+    554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
+    4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
+    21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
   ]
   window_size = max(
-      [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+    [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
   divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
   batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
   window_size *= shard_multiplier
   batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
-  # The Datasets API splits one window into multiple batches, which
-  # produces runs of many consecutive batches of the same size. This
-  # is bad for training. To solve this, we will shuffle the batches
-  # using a queue which must be several times as large as the maximum
-  # number of batches per window.
-  max_batches_per_window = window_size // min(batch_sizes)
-  shuffle_queue_size = max_batches_per_window * 3
 
   ret = BatchingScheme(bucket_boundaries=boundaries,
                        bucket_batch_sizes=batch_sizes)
   return ret
 
+
 # The protected functions below are designed to convert the ambiguous spec
 # structures to a normalized form.
 
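The window-size rounding above is easier to see with small numbers; the following self-contained sketch (values hypothetical) reproduces the divisor trick from the hunk:

batch_size = 4096
boundaries = [8, 16, 32, 64]
max_length = 128
batch_sizes = [max(1, batch_size // length)
               for length in boundaries + [max_length]]  # [512, 256, 128, 64, 32]
max_batch_size = max(batch_sizes)  # 512
hcn = [1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260]
window_size = max(i for i in hcn if i <= 3 * max_batch_size)  # 1260
divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
batch_sizes = [max(d for d in divisors if d <= bs) for bs in batch_sizes]
# [420, 252, 126, 63, 30] -- each size divides 1260, so a window of 1260
# sequences always splits evenly into whole batches.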