@@ -90,9 +90,9 @@ def _pad(seq, max_len, constant_values=0):
90
90
mode = 'constant' , constant_values = constant_values )
91
91
92
92
93
- def _pad_2d (x , max_len , b_pad = 0 ):
93
+ def _pad_2d (x , max_len , b_pad = 0 , constant_values = 0 ):
94
94
x = np .pad (x , [(b_pad , max_len - len (x ) - b_pad ), (0 , 0 )],
95
- mode = "constant" , constant_values = 0 )
95
+ mode = "constant" , constant_values = constant_values )
96
96
return x
97
97
98
98
@@ -417,17 +417,19 @@ def collate_fn(batch):
417
417
# (B, T, C)
418
418
# pad for time-axis
419
419
if is_mulaw_quantize (hparams .input_type ):
420
+ padding_value = P .mulaw_quantize (0 , mu = hparams .quantize_channels )
420
421
x_batch = np .array ([_pad_2d (np_utils .to_categorical (
421
422
x [0 ], num_classes = hparams .quantize_channels ),
422
- max_input_len ) for x in batch ], dtype = np .float32 )
423
+ max_input_len , padding_value ) for x in batch ], dtype = np .float32 )
423
424
else :
424
425
x_batch = np .array ([_pad_2d (x [0 ].reshape (- 1 , 1 ), max_input_len )
425
426
for x in batch ], dtype = np .float32 )
426
427
assert len (x_batch .shape ) == 3
427
428
428
429
# (B, T)
429
430
if is_mulaw_quantize (hparams .input_type ):
430
- y_batch = np .array ([_pad (x [0 ], max_input_len ) for x in batch ], dtype = np .int )
431
+ padding_value = P .mulaw_quantize (0 , mu = hparams .quantize_channels )
432
+ y_batch = np .array ([_pad (x [0 ], max_input_len , constant_values = padding_value ) for x in batch ], dtype = np .int )
431
433
else :
432
434
y_batch = np .array ([_pad (x [0 ], max_input_len ) for x in batch ], dtype = np .float32 )
433
435
assert len (y_batch .shape ) == 2
0 commit comments