@@ -515,7 +515,7 @@ def convert(self, state_dict):
515
515
class WanVideoVAE (PreTrainedModel ):
516
516
converter = WanVideoVAEStateDictConverter ()
517
517
518
- def __init__ (self , z_dim = 16 , parallelism : int = 1 , device : str = "cuda:0" , dtype : torch .dtype = torch .float32 ):
518
+ def __init__ (self , z_dim = 16 , device : str = "cuda:0" , dtype : torch .dtype = torch .float32 ):
519
519
super ().__init__ ()
520
520
521
521
mean = [
@@ -561,12 +561,11 @@ def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype
561
561
# init model
562
562
self .model = VideoVAE (z_dim = z_dim ).eval ().requires_grad_ (False )
563
563
self .upsampling_factor = 8
564
- self .parallelism = parallelism
565
564
566
565
@classmethod
567
- def from_state_dict (cls , state_dict , parallelism = 1 , device = "cuda:0" , dtype = torch .float32 ) -> "WanVideoVAE" :
566
+ def from_state_dict (cls , state_dict , device = "cuda:0" , dtype = torch .float32 ) -> "WanVideoVAE" :
568
567
with no_init_weights ():
569
- model = torch .nn .utils .skip_init (cls , parallelism = parallelism , device = device , dtype = dtype )
568
+ model = torch .nn .utils .skip_init (cls , device = device , dtype = dtype )
570
569
model .load_state_dict (state_dict , assign = True )
571
570
model .to (device = device , dtype = dtype , non_blocking = True )
572
571
return model
@@ -607,7 +606,7 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
607
606
h_ , w_ = h + size_h , w + size_w
608
607
tasks .append ((h , h_ , w , w_ ))
609
608
610
- data_device = device if self . parallelism > 1 else "cpu"
609
+ data_device = device if dist . is_initialized () else "cpu"
611
610
computation_device = device
612
611
613
612
out_T = T * 4 - 3
@@ -622,9 +621,9 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
622
621
device = data_device ,
623
622
)
624
623
625
- hide_progress_bar = self . parallelism > 1 and dist .get_rank () != 0
626
- for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE DECODING" , disable = hide_progress_bar )):
627
- if self . parallelism > 1 and (i % dist .get_world_size () != dist .get_rank ()):
624
+ hide_progress = dist . is_initialized () and dist .get_rank () != 0
625
+ for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE DECODING" , disable = hide_progress )):
626
+ if dist . is_initialized () and (i % dist .get_world_size () != dist .get_rank ()):
628
627
continue
629
628
hidden_states_batch = hidden_states [:, :, :, h :h_ , w :w_ ].to (computation_device )
630
629
hidden_states_batch = self .model .decode (hidden_states_batch , self .scale ).to (data_device )
@@ -654,11 +653,11 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
654
653
target_h : target_h + hidden_states_batch .shape [3 ],
655
654
target_w : target_w + hidden_states_batch .shape [4 ],
656
655
] += mask
657
- if progress_callback is not None and not hide_progress_bar :
656
+ if progress_callback is not None and not hide_progress :
658
657
progress_callback (i + 1 , len (tasks ), "VAE DECODING" )
659
- if progress_callback is not None and not hide_progress_bar :
658
+ if progress_callback is not None and not hide_progress :
660
659
progress_callback (len (tasks ), len (tasks ), "VAE DECODING" )
661
- if self . parallelism > 1 :
660
+ if dist . is_initialized () :
662
661
dist .all_reduce (values )
663
662
dist .all_reduce (weight )
664
663
values = values / weight
@@ -681,7 +680,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
681
680
h_ , w_ = h + size_h , w + size_w
682
681
tasks .append ((h , h_ , w , w_ ))
683
682
684
- data_device = device if self . parallelism > 1 else "cpu"
683
+ data_device = device if dist . is_initialized () else "cpu"
685
684
computation_device = device
686
685
687
686
out_T = (T + 3 ) // 4
@@ -696,9 +695,9 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
696
695
device = data_device ,
697
696
)
698
697
699
- hide_progress_bar = self . parallelism > 1 and dist .get_rank () != 0
698
+ hide_progress_bar = dist . is_initialized () and dist .get_rank () != 0
700
699
for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE ENCODING" , disable = hide_progress_bar )):
701
- if self . parallelism > 1 and (i % dist .get_world_size () != dist .get_rank ()):
700
+ if dist . is_initialized () and (i % dist .get_world_size () != dist .get_rank ()):
702
701
continue
703
702
hidden_states_batch = video [:, :, :, h :h_ , w :w_ ].to (computation_device )
704
703
hidden_states_batch = self .model .encode (hidden_states_batch , self .scale ).to (data_device )
@@ -732,7 +731,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
732
731
progress_callback (i + 1 , len (tasks ), "VAE ENCODING" )
733
732
if progress_callback is not None and not hide_progress_bar :
734
733
progress_callback (len (tasks ), len (tasks ), "VAE ENCODING" )
735
- if self . parallelism > 1 :
734
+ if dist . is_initialized () :
736
735
dist .all_reduce (values )
737
736
dist .all_reduce (weight )
738
737
values = values / weight
0 commit comments