@@ -1,10 +1,30 @@
 # Adapted from
 # https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/model/modeling_videobase.py
 
+import copy
+import logging
+import os
+from typing import Dict, Optional, Union
+
+from huggingface_hub import DDUFEntry
+from huggingface_hub.utils import validate_hf_hub_args
+
 import mindspore as ms
+from mindspore.nn.utils import no_init_parameters
 
-from mindone.diffusers import ModelMixin
+from mindone.diffusers import ModelMixin, __version__
 from mindone.diffusers.configuration_utils import ConfigMixin
+from mindone.diffusers.models.model_loading_utils import _fetch_index_file, _fetch_index_file_legacy, load_state_dict
+from mindone.diffusers.models.modeling_utils import _convert_state_dict
+from mindone.diffusers.utils import (
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    _add_variant,
+    _get_checkpoint_shard_files,
+    _get_model_file,
+)
+
+logger = logging.getLogger(__name__)
 
 
 class VideoBaseAE(ModelMixin, ConfigMixin):
@@ -23,3 +43,237 @@ def encode(self, x: ms.Tensor, *args, **kwargs):
 
     def decode(self, encoding: ms.Tensor, *args, **kwargs):
         pass
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
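+        """Instantiate a model from a pretrained checkpoint.
+
+        Adapted from `mindone.diffusers.models.modeling_utils.from_pretrained`,
+        with one addition: a preloaded `state_dict` may be passed via kwargs, in
+        which case no weight file is read from disk.
+        """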
+        # adapted from mindone.diffusers.models.modeling_utils.from_pretrained
+        state_dict = kwargs.pop("state_dict", None)  # additional keyword argument
+        cache_dir = kwargs.pop("cache_dir", None)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        from_flax = kwargs.pop("from_flax", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        mindspore_dtype = kwargs.pop("mindspore_dtype", None)
+        subfolder = kwargs.pop("subfolder", None)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
+            logger.warning(
+                f"Passed `mindspore_dtype` {mindspore_dtype} is not a `ms.Type`. Defaulting to `ms.float32`."
+            )
+            mindspore_dtype = ms.float32
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+        unused_kwargs = {}
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            dduf_entries=dduf_entries,
+            **kwargs,
+        )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # Check if `_keep_in_fp32_modules` is not None
+        # use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
+        #     hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+        # )
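+        # Note: unlike the upstream check above, which also consults a quantizer,
+        # fp32 preservation here only applies when loading in float16.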
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16)
+
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+        else:
+            keep_in_fp32_modules = []
+
+        is_sharded = False
+        resolved_model_file = None
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found, we still have to consider the legacy format;
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
+        if index_file is not None and (dduf_entries or index_file.is_file()):
+            is_sharded = True
+
+        # load model
+        if from_flax:
+            raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.")
+        else:
+            # in the sharded case, we already have the index
+            if is_sharded:
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
+                )
+            elif use_safetensors:
+                try:
+                    resolved_model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
+                    )
+
+                except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                    if not allow_pickle:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
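+            # Fall back to the pickle-based checkpoint (WEIGHTS_NAME) when no
+            # safetensors file was resolved above.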
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+
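+        # Normalize to a list so single-file and sharded checkpoints share one code path.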
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]
+
+        # set dtype to instantiate the model under:
+        # 1. If mindspore_dtype is not None, we use that dtype
+        # 2. If mindspore_dtype is float8, we don't use _set_default_mindspore_dtype and we downcast after loading the model
+        dtype_orig = None  # noqa
+        if mindspore_dtype is not None:
+            if not isinstance(mindspore_dtype, ms.Type):
+                raise ValueError(
+                    f"{mindspore_dtype} needs to be of type `mindspore.Type`, e.g. `mindspore.float16`, but is {type(mindspore_dtype)}."
+                )
+
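+        # Build the model skeleton without initializing parameters; weights are
+        # materialized from the checkpoint below.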
+        with no_init_parameters():
+            model = cls.from_config(config, **unused_kwargs)
+
+        # `state_dict` may be passed as an additional keyword argument; only load
+        # the resolved model file when it was not.
+        if state_dict is None:
+            if not is_sharded:
+                # Time to load the checkpoint
+                state_dict = load_state_dict(
+                    resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
+                )
+                # We only fix it for non-sharded checkpoints as we don't need it yet for sharded ones.
+                model._fix_state_dict_keys_on_load(state_dict)
+
+            if is_sharded:
+                loaded_keys = sharded_metadata["all_checkpoint_keys"]
+            else:
+                state_dict = _convert_state_dict(model, state_dict)
+                loaded_keys = list(state_dict.keys())
+        else:
+            # A caller-supplied state_dict still needs `loaded_keys`, which would
+            # otherwise be unbound below.
+            loaded_keys = list(state_dict.keys())
+
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
+            model,
+            state_dict,
+            resolved_model_file,
+            pretrained_model_name_or_path,
+            loaded_keys,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            dtype=mindspore_dtype,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+        )
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+            "error_msgs": error_msgs,
+        }
+
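+        # Skip the blanket cast when fp32 modules must be preserved; per-module
+        # dtypes were already handled during loading via `keep_in_fp32_modules`.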
+        if mindspore_dtype is not None and not use_keep_in_fp32_modules:
+            model = model.to(mindspore_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.set_train(False)
+
+        if output_loading_info:
+            return model, loading_info
+
+        return model
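+
+
+# Example usage (a minimal sketch: `MyCausalVAE` stands in for a concrete
+# subclass of VideoBaseAE, and the checkpoint path is illustrative):
+#
+#     vae = MyCausalVAE.from_pretrained(
+#         "path/to/vae-checkpoint",
+#         mindspore_dtype=ms.float16,
+#     )
+#     latents = vae.encode(video)  # video: ms.Tensor of shape (B, C, T, H, W)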