@@ -96,7 +96,7 @@ def __init__(self,
 
 
 def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
-  """A default set of length-bucket boundaries."""
+  """Create a default set of length-bucket boundaries."""
   assert length_bucket_step > 1.0
   x = min_length
   boundaries = []
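For orientation, the loop that follows `boundaries = []` is cut off by this hunk; a minimal standalone sketch of the geometric bucketing it implements might look like the following. The `max(x + 1, int(x * length_bucket_step))` update is an assumption, since the diff does not show the loop body.

def bucket_boundaries_sketch(max_length, min_length=8, length_bucket_step=1.1):
  # Hypothetical helper, not part of this commit: grow bucket edges
  # geometrically from min_length until they reach max_length.
  assert length_bucket_step > 1.0
  x = min_length
  boundaries = []
  while x < max_length:
    boundaries.append(x)
    # Assumed update rule: advance by at least one token so the loop
    # terminates even when the geometric step rounds down to the same value.
    x = max(x + 1, int(x * length_bucket_step))
  return boundaries

With the default step of 1.1, consecutive boundaries differ by one token for short lengths and by roughly 10% once lengths grow large.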
@@ -110,28 +110,25 @@ def get_batching_scheme(batch_size: int,
                         max_length: int = None,
                         min_length_bucket: int = 8,
                         length_bucket_step: float = 1.1,
-                        drop_long_sequences: bool = False,
                         shard_multiplier: int = 1,
                         length_multiplier: int = 1,
                         min_length: int = 0) -> BatchingScheme:
-  """A batching scheme based on model hyperparameters.
+  """Create a batching scheme based on model hyperparameters.
+
   Every batch contains a number of sequences divisible by `shard_multiplier`.
+
   Args:
     batch_size: int, total number of tokens in a batch.
-    max_length: int, sequences longer than this will be skipped. Defaults to
-      batch_size.
+    max_length: int, sequences longer than this will be skipped. Defaults
+      to batch_size.
     min_length_bucket: int
     length_bucket_step: float greater than 1.0
-    drop_long_sequences: bool, if True, then sequences longer than
-      `max_length` are dropped. This prevents generating batches with
-      more than the usual number of tokens, which can cause out-of-memory
-      errors.
-    shard_multiplier: an integer increasing the batch_size to suit splitting
-      across datashards.
+    shard_multiplier: an integer increasing the batch_size to suit
+      splitting across datashards.
     length_multiplier: an integer multiplier that is used to increase the
       batch sizes and sequence length tolerance.
     min_length: int, sequences shorter than this will be skipped.
-  Returns:
+  Return:
     A dictionary with parameters that can be passed to input_pipeline:
       * boundaries: list of bucket boundaries
       * batch_sizes: list of batch sizes for each length bucket
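To make the docstring's return value concrete, here is a small sketch of how `boundaries` and `batch_sizes` relate under a fixed token budget; the concrete numbers are assumptions chosen for illustration, not values from this commit.

# Each length bucket gets a batch size of roughly batch_size // bucket_length,
# so every batch carries a comparable number of tokens.
batch_size = 4096                     # illustrative token budget
boundaries = [8, 16, 32, 64]          # illustrative bucket boundaries
max_length = 128
batch_sizes = [max(1, batch_size // length)
               for length in boundaries + [max_length]]
# batch_sizes == [512, 256, 128, 64, 32]; note there is one more batch size
# than there are boundaries, covering sequences up to max_length.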
@@ -149,40 +146,33 @@ def get_batching_scheme(batch_size: int,
   max_length *= length_multiplier
 
   batch_sizes = [
-      max(1, batch_size // length) for length in boundaries + [max_length]
+      max(1, batch_size // length) for length in boundaries + [max_length]
   ]
   max_batch_size = max(batch_sizes)
   # Since the Datasets API only allows a single constant for window_size,
   # and it needs divide all bucket_batch_sizes, we pick a highly-composite
-  # window size and then round down all batch sizes to divisors of that window
-  # size, so that a window can always be divided evenly into batches.
-  # TODO(noam): remove this when Dataset API improves.
+  # window size and then round down all batch sizes to divisors of that
+  # window size, so that a window can always be divided evenly into batches.
   highly_composite_numbers = [
-      1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
-      2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
-      83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
-      720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
-      7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
-      36756720, 43243200, 61261200, 73513440, 110270160
+      1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
+      1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
+      50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
+      554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
+      4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
+      21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
   ]
   window_size = max(
-      [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+      [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
   divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
   batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
   window_size *= shard_multiplier
   batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
-  # The Datasets API splits one window into multiple batches, which
-  # produces runs of many consecutive batches of the same size. This
-  # is bad for training. To solve this, we will shuffle the batches
-  # using a queue which must be several times as large as the maximum
-  # number of batches per window.
-  max_batches_per_window = window_size // min(batch_sizes)
-  shuffle_queue_size = max_batches_per_window * 3
 
   ret = BatchingScheme(bucket_boundaries=boundaries,
                        bucket_batch_sizes=batch_sizes)
   return ret
 
+
 # The protected functions below are designed to convert the ambiguous spec
 # structures to a normalized form.
 
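As a worked illustration of the window-size rounding kept by this hunk, the sketch below runs the same steps on assumed inputs; the bucket batch sizes and the truncated highly-composite list are illustrative, not taken from the commit.

highly_composite_numbers = [1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240,
                            360, 720, 840, 1260, 1680, 2520, 5040]
batch_sizes = [512, 256, 128, 64, 32]   # assumed bucket batch sizes
max_batch_size = max(batch_sizes)       # 512
# Largest highly composite number no bigger than 3 * max_batch_size = 1536.
window_size = max(i for i in highly_composite_numbers
                  if i <= 3 * max_batch_size)           # 1260
divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
# Round each batch size down to a divisor of the window so that a window
# always splits evenly into whole batches.
batch_sizes = [max(d for d in divisors if d <= bs) for bs in batch_sizes]
# batch_sizes == [420, 252, 126, 63, 30]; each entry divides 1260 exactly.

Rounding to divisors of one shared window size is what lets the Datasets API's single window_size constant be split evenly no matter which length bucket a window falls into.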