Commit 69b1664

setup amp correctly
1 parent 9704aed commit 69b1664

2 files changed: +3, -5 lines


rin_pytorch/rin_pytorch.py

Lines changed: 2 additions & 4 deletions
@@ -906,20 +906,18 @@ def __init__(
         num_samples = 25,
         results_folder = './results',
         amp = False,
-        fp16 = False,
+        mixed_precision_type = 'fp16',
         split_batches = True,
         convert_image_to = None
     ):
         super().__init__()

         self.accelerator = Accelerator(
             split_batches = split_batches,
-            mixed_precision = 'fp16' if fp16 else 'no',
+            mixed_precision = mixed_precision_type if amp else 'no',
             kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=True)]
         )

-        self.accelerator.native_amp = amp
-
         self.model = diffusion_model

         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
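With this change, the `amp` flag toggles mixed precision on and off, while `mixed_precision_type` selects the dtype handed to Accelerate ('fp16' by default). A minimal sketch of the new mapping, built only from the Accelerator call shown in the diff (the flag values here are illustrative, not taken from the repo):

    from accelerate import Accelerator, DistributedDataParallelKwargs

    amp = True                     # illustrative: enable automatic mixed precision
    mixed_precision_type = 'fp16'  # dtype used by Accelerate when amp is on

    accelerator = Accelerator(
        split_batches = True,
        # 'no' disables mixed precision entirely; otherwise the chosen dtype is used
        mixed_precision = mixed_precision_type if amp else 'no',
        kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters = True)]
    )

Dropping the manual `self.accelerator.native_amp = amp` assignment is consistent with this, since Accelerate manages its own amp state once a non-'no' value is passed to `mixed_precision`.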

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.5',
+  version = '0.7.6',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
