Skip to content

Commit bac049a

Browse files
committed
fix amp
1 parent deb97db commit bac049a

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,7 +1649,7 @@ def __init__(
16491649
num_samples = 1,
16501650
results_folder = './results',
16511651
amp = False,
1652-
fp16 = False,
1652+
mixed_precision_type = 'fp16',
16531653
use_ema = True,
16541654
split_batches = True,
16551655
dataloader = None,
@@ -1663,11 +1663,9 @@ def __init__(
16631663

16641664
self.accelerator = Accelerator(
16651665
split_batches = split_batches,
1666-
mixed_precision = 'fp16' if fp16 else 'no'
1666+
mixed_precision = mixed_precision_type if amp else 'no'
16671667
)
16681668

1669-
self.accelerator.native_amp = amp
1670-
16711669
# model
16721670

16731671
self.model = diffusion_model

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.4'
1+
__version__ = '0.1.5'

0 commit comments

Comments
 (0)