From f49277e1717ca72c996d494c802cd41b4f341c88 Mon Sep 17 00:00:00 2001 From: Karim Ali Date: Sat, 24 Jun 2023 15:34:39 +0200 Subject: [PATCH] Update models.py I added an if statement for MPS device because double is not supported for MPS device if mps: .float() else: .double() https://discuss.pytorch.org/t/typeerror-cannot-convert-a-mps-tensor-to-float64-dtype-as-the-mps-framework-doesnt-support-float64-please-use-float32-instead/180852 --- vdecoder/nsf_hifigan/models.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py index c2c889ec..c8ff50b9 100644 --- a/vdecoder/nsf_hifigan/models.py +++ b/vdecoder/nsf_hifigan/models.py @@ -146,7 +146,10 @@ def forward(self, f0, upp): rand_ini[:, 0] = 0 rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini is_half = rad_values.dtype is not torch.float32 - tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 + if f0.device.type == 'mps': + tmp_over_one = torch.cumsum(rad_values, 1) + else: + tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 if is_half: tmp_over_one = tmp_over_one.half() else: @@ -161,8 +164,9 @@ def forward(self, f0, upp): tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 cumsum_shift = torch.zeros_like(rad_values) cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - rad_values = rad_values.double() - cumsum_shift = cumsum_shift.double() + if f0.device.type != 'mps': + rad_values = rad_values.double() + cumsum_shift = cumsum_shift.double() sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) if is_half: sine_waves = sine_waves.half()