@@ -594,112 +594,118 @@ def forecast(
         Any
             The forecasted state.
         """
-        self.model.eval()
+        # NOTE we are not using decorator of the top level function as we anticipate lazy torch load
+        with torch.inference_mode():
+            self.model.eval()

-        torch.set_grad_enabled(False)
+            # Create pytorch input tensor
+            input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(
+                self.device
+            )

-        # Create pytorch input tensor
-        input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)
+            lead_time = to_timedelta(lead_time)

-        lead_time = to_timedelta(lead_time)
-
-        new_state = input_state.copy()  # We should not modify the input state
-        new_state["fields"] = dict()
-        new_state["step"] = to_timedelta(0)
-
-        start = input_state["date"]
-
-        # The variable `check` is used to keep track of which variables have been updated
-        # In the input tensor. `reset` is used to reset `check` to False except
-        # when the values are of the constant in time variables
-
-        reset = np.full((input_tensor_torch.shape[-1],), False)
-        variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
-        typed_variables = self.checkpoint.typed_variables
-        for variable, i in variable_to_input_tensor_index.items():
-            if typed_variables[variable].is_constant_in_time:
-                reset[i] = True
-
-        check = reset.copy()
-
-        if self.verbosity > 0:
-            self._print_input_tensor("First input tensor", input_tensor_torch)
-
-        for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
-            title = f"Forecasting step {step} ({date})"
-
-            new_state["date"] = date
-            new_state["previous_step"] = new_state.get("step")
-            new_state["step"] = step
-
-            if self.trace:
-                self.trace.write_input_tensor(
-                    date, s, input_tensor_torch.cpu().numpy(), variable_to_input_tensor_index, self.checkpoint.timestep
-                )
+            new_state = input_state.copy()  # We should not modify the input state
+            new_state["fields"] = dict()
+            new_state["step"] = to_timedelta(0)

-            # Predict next state of atmosphere
-            with (
-                torch.autocast(device_type=self.device.type, dtype=self.autocast),
-                ProfilingLabel("Predict step", self.use_profiler),
-                Timer(title),
-            ):
-                y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
+            start = input_state["date"]

-            output = torch.squeeze(y_pred, dim=(0, 1))  # shape: (values, variables)
+            # The variable `check` is used to keep track of which variables have been updated
+            # In the input tensor. `reset` is used to reset `check` to False except
+            # when the values are of the constant in time variables

-            # Update state
-            with ProfilingLabel("Updating state (CPU)", self.use_profiler):
-                for i in range(output.shape[1]):
-                    new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
+            reset = np.full((input_tensor_torch.shape[-1],), False)
+            variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
+            typed_variables = self.checkpoint.typed_variables
+            for variable, i in variable_to_input_tensor_index.items():
+                if typed_variables[variable].is_constant_in_time:
+                    reset[i] = True

-            if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
-                self._print_output_tensor("Output tensor", output.cpu().numpy())
+            check = reset.copy()

-            if self.trace:
-                self.trace.write_output_tensor(
-                    date,
-                    s,
-                    output.cpu().numpy(),
-                    self.checkpoint.output_tensor_index_to_variable,
-                    self.checkpoint.timestep,
-                )
+            if self.verbosity > 0:
+                self._print_input_tensor("First input tensor", input_tensor_torch)

-            yield new_state
+            for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
+                title = f"Forecasting step {step} ({date})"

-            # No need to prepare next input tensor if we are at the last step
-            if is_last_step:
-                break
+                new_state["date"] = date
+                new_state["previous_step"] = new_state.get("step")
+                new_state["step"] = step

-            # Update tensor for next iteration
-            with ProfilingLabel("Update tensor for next step", self.use_profiler):
-                check[:] = reset
                 if self.trace:
-                    self.trace.reset_sources(reset, self.checkpoint.variable_to_input_tensor_index)
-
-                input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
-
-                del y_pred  # Recover memory
+                    self.trace.write_input_tensor(
+                        date,
+                        s,
+                        input_tensor_torch.cpu().numpy(),
+                        variable_to_input_tensor_index,
+                        self.checkpoint.timestep,
+                    )
+
+                # Predict next state of atmosphere
+                with (
+                    torch.autocast(device_type=self.device.type, dtype=self.autocast),
+                    ProfilingLabel("Predict step", self.use_profiler),
+                    Timer(title),
+                ):
+                    y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
+
+                output = torch.squeeze(y_pred, dim=(0, 1))  # shape: (values, variables)
+
+                # Update state
+                with ProfilingLabel("Updating state (CPU)", self.use_profiler):
+                    for i in range(output.shape[1]):
+                        new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
+
+                if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
+                    self._print_output_tensor("Output tensor", output.cpu().numpy())

-            input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
-                input_tensor_torch, new_state, next_date, check
-            )
-            input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
-                input_tensor_torch, new_state, next_date, check
-            )
-
-            if not check.all():
-                # Not all variables have been updated
-                missing = []
-                variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
-                mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
-                for i in range(check.shape[-1]):
-                    if not check[i]:
-                        missing.append(mapping[i])
-
-                raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
-
-            if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
-                self._print_input_tensor("Next input tensor", input_tensor_torch)
+                if self.trace:
+                    self.trace.write_output_tensor(
+                        date,
+                        s,
+                        output.cpu().numpy(),
+                        self.checkpoint.output_tensor_index_to_variable,
+                        self.checkpoint.timestep,
+                    )
+
+                yield new_state
+
+                # No need to prepare next input tensor if we are at the last step
+                if is_last_step:
+                    break
+
+                # Update tensor for next iteration
+                with ProfilingLabel("Update tensor for next step", self.use_profiler):
+                    check[:] = reset
+                    if self.trace:
+                        self.trace.reset_sources(reset, self.checkpoint.variable_to_input_tensor_index)
+
+                    input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
+
+                    del y_pred  # Recover memory
+
+                    input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
+                        input_tensor_torch, new_state, next_date, check
+                    )
+                    input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
+                        input_tensor_torch, new_state, next_date, check
+                    )
+
+                    if not check.all():
+                        # Not all variables have been updated
+                        missing = []
+                        variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
+                        mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
+                        for i in range(check.shape[-1]):
+                            if not check[i]:
+                                missing.append(mapping[i])
+
+                        raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
+
+                    if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
+                        self._print_input_tensor("Next input tensor", input_tensor_torch)

     def copy_prognostic_fields_to_input_tensor(
         self, input_tensor_torch: torch.Tensor, y_pred: torch.Tensor, check: BoolArray
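The substantive change in this hunk is the replacement of the global `torch.set_grad_enabled(False)` call with a `torch.inference_mode()` context wrapping the whole autoregressive forecast loop; the rest of the diff is the resulting re-indentation and line re-wrapping. Below is a minimal, self-contained sketch of that pattern, not this repository's code: the toy model, shapes, and step count are placeholders. Unlike the global flag, `inference_mode()` is scoped to the `with` block and also skips autograd bookkeeping (version counters, view tracking) on tensors created inside it.

```python
import torch

# Hedged sketch of the pattern adopted by the diff: run an autoregressive rollout
# under torch.inference_mode() instead of toggling torch.set_grad_enabled(False).
# The model, tensor shapes, and number of steps below are made up for illustration.
model = torch.nn.Linear(8, 8)

with torch.inference_mode():
    model.eval()              # same ordering as the diff: eval() inside the block
    x = torch.randn(1, 8)
    for step in range(3):
        y = model(x)          # no autograd graph is recorded here
        x = y                 # feed the prediction back in, as the forecast loop does

print(x.requires_grad)        # False: outputs are inference tensors
```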
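The comment block about `check` and `reset` describes bookkeeping that is easy to lose in the re-indented code: constant-in-time variables are pre-marked as updated, each update pass flips its columns of `check` to True, and anything still False when the next input tensor is assembled is reported by name. A small NumPy illustration of that idea, using hypothetical variable names rather than the checkpoint's real `variable_to_input_tensor_index` mapping:

```python
import numpy as np

# Toy illustration of the check/reset bookkeeping (hypothetical variable names,
# not the checkpoint's actual mapping).
variable_to_index = {"z_500": 0, "2t": 1, "lsm": 2}
constant_in_time = {"lsm"}  # e.g. a land-sea mask never needs refreshing

reset = np.full((len(variable_to_index),), False)
for name, i in variable_to_index.items():
    if name in constant_in_time:
        reset[i] = True

check = reset.copy()
check[variable_to_index["z_500"]] = True  # pretend only "z_500" was refreshed this step

if not check.all():
    mapping = {v: k for k, v in variable_to_index.items()}
    missing = [mapping[i] for i in range(check.shape[-1]) if not check[i]]
    print(f"Missing variables in input tensor: {sorted(missing)}")  # -> ['2t']
```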