@@ -95,6 +95,94 @@ def __init__(self,
# pylint: enable=too-few-public-methods
+def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
+    """A default set of length-bucket boundaries."""
+    assert length_bucket_step > 1.0
+    x = min_length
+    boundaries = []
+    while x < max_length:
+        boundaries.append(x)
+        x = max(x + 1, int(x * length_bucket_step))
+    return boundaries
+
+
109
+def get_batching_scheme(batch_size: int,
+                        max_length: int = None,
+                        min_length_bucket: int = 8,
+                        length_bucket_step: float = 1.1,
+                        drop_long_sequences: bool = False,
+                        shard_multiplier: int = 1,
+                        length_multiplier: int = 1,
+                        min_length: int = 0) -> BatchingScheme:
+    """A batching scheme based on model hyperparameters.
+
+    Every batch contains a number of sequences divisible by `shard_multiplier`.
+
+    Args:
+        batch_size: int, total number of tokens in a batch.
+        max_length: int, sequences longer than this will be skipped.
+            Defaults to batch_size.
+        min_length_bucket: int, the shortest bucket boundary.
+        length_bucket_step: float greater than 1.0, the ratio by which
+            successive bucket boundaries grow.
+        drop_long_sequences: bool, if True, then sequences longer than
+            `max_length` are dropped. This prevents generating batches with
+            more than the usual number of tokens, which can cause
+            out-of-memory errors.
+        shard_multiplier: an integer increasing the batch size to suit
+            splitting across data shards.
+        length_multiplier: an integer multiplier that is used to increase
+            the batch sizes and sequence length tolerance.
+        min_length: int, sequences shorter than this will be skipped.
+
+    Returns:
+        A `BatchingScheme` instance with:
+            * bucket_boundaries: list of bucket boundaries
+            * bucket_batch_sizes: list of batch sizes for each length bucket
+
+    Raises:
+        ValueError: If min_length > max_length.
+    """
+    max_length = max_length or batch_size
+    if max_length < min_length:
+        raise ValueError("max_length must be greater than or equal to min_length")
+
+    boundaries = _bucket_boundaries(max_length, min_length_bucket,
+                                    length_bucket_step)
+    boundaries = [boundary * length_multiplier for boundary in boundaries]
+    max_length *= length_multiplier
+
+    batch_sizes = [
+        max(1, batch_size // length) for length in boundaries + [max_length]
+    ]
+    max_batch_size = max(batch_sizes)
+    # Since the Datasets API only allows a single constant for window_size,
+    # and it needs to divide all bucket_batch_sizes, we pick a highly
+    # composite window size and then round down all batch sizes to divisors
+    # of that window size, so that a window can always be divided evenly
+    # into batches.
+    # TODO(noam): remove this when the Dataset API improves.
+    highly_composite_numbers = [
+        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
+        1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
+        50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
+        554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
+        4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
+        21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
+    ]
+    window_size = max(
+        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
+    batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
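+    # Worked example with illustrative values: for batch_size=4096 and the
+    # other arguments at their defaults, the largest bucket batch size is
+    # 4096 // 8 = 512, so window_size becomes 1260 (the largest highly
+    # composite number <= 3 * 512) and 512 is rounded down to its divisor 420.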
+    window_size *= shard_multiplier
+    batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
+    # The Datasets API splits one window into multiple batches, which
+    # produces runs of many consecutive batches of the same size. This
+    # is bad for training. To solve this, we will shuffle the batches
+    # using a queue which must be several times as large as the maximum
+    # number of batches per window.
+    max_batches_per_window = window_size // min(batch_sizes)
+    shuffle_queue_size = max_batches_per_window * 3
+
+    ret = BatchingScheme(bucket_boundaries=boundaries,
+                         bucket_batch_sizes=batch_sizes)
+    return ret
+
# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.
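For reference, a minimal usage sketch of the new function; the argument values are illustrative, and the attribute access assumes `BatchingScheme` exposes its constructor arguments as attributes of the same names:

    scheme = get_batching_scheme(batch_size=4096, max_length=256)
    print(scheme.bucket_boundaries)   # token-length bucket boundaries, e.g. [8, 9, 10, ...]
    print(scheme.bucket_batch_sizes)  # one batch size per length bucket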