diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py index c2c889ec..c8ff50b9 100644 --- a/vdecoder/nsf_hifigan/models.py +++ b/vdecoder/nsf_hifigan/models.py @@ -146,7 +146,10 @@ def forward(self, f0, upp): rand_ini[:, 0] = 0 rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini is_half = rad_values.dtype is not torch.float32 - tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 + if f0.device.type == 'mps': + tmp_over_one = torch.cumsum(rad_values, 1) + else: + tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 if is_half: tmp_over_one = tmp_over_one.half() else: @@ -161,8 +164,9 @@ def forward(self, f0, upp): tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 cumsum_shift = torch.zeros_like(rad_values) cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - rad_values = rad_values.double() - cumsum_shift = cumsum_shift.double() + if f0.device.type != 'mps': + rad_values = rad_values.double() + cumsum_shift = cumsum_shift.double() sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) if is_half: sine_waves = sine_waves.half()