@@ -44,7 +44,7 @@ def __init__(
4444 super ().__init__ (config , data_indices , statistics )
4545
4646 self .nan_locations = None
47- # weight imputed values wiht zero in loss calculation
47+ # weight imputed values with zero in loss calculation
4848 self .loss_mask_training = None
4949
5050 def _validate_indices (self ):
@@ -113,6 +113,12 @@ def get_nans(self, x: torch.Tensor) -> torch.Tensor:
113113 idx = [slice (0 , 1 )] * (x .ndim - 2 ) + [slice (None ), slice (None )]
114114 return torch .isnan (x [idx ].squeeze ())
115115
def fill_with_value(self, x, index):
    """Write each configured replacement value into the masked entries of ``x``.

    Walks ``self.index_training_input``, ``index`` and ``self.replacement``
    in lockstep. For every entry whose destination position in ``index`` is
    not ``None``, the elements of ``x[..., destination]`` selected by
    ``self._expand_subset_mask(x, source)`` are overwritten in place.

    Parameters
    ----------
    x : tensor to modify in place
    index : destination positions along the last axis of ``x``; ``None``
        entries are skipped

    Returns
    -------
    The same ``x`` object that was passed in.
    """
    triples = zip(self.index_training_input, index, self.replacement)
    for src_pos, dst_pos, fill in triples:
        if dst_pos is None:
            # Nothing to fill for this variable in the current layout
            continue
        # Mask is built by a helper defined outside this block
        mask = self._expand_subset_mask(x, src_pos)
        x[..., dst_pos][mask] = fill
    return x
121+
116122 def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
117123 """Impute missing values in the input tensor."""
118124 if not in_place :
@@ -145,10 +151,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
145151 )
146152
147153 # Replace values
148- for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
149- if idx_dst is not None :
150- x [..., idx_dst ][self ._expand_subset_mask (x , idx_src )] = value
151- return x
154+ return self .fill_with_value (x , index )
152155
153156 def inverse_transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
154157 """Impute missing values in the input tensor."""
@@ -231,13 +234,130 @@ def __init__(
231234 self ._validate_indices ()
232235
233236
class CopyImputer(BaseImputer):
    """Imputes missing values copying them from another variable.

    For every input variable, ``self.methods.get(name, self.default)`` gives
    the name of the variable whose values are copied into the masked
    positions; the literal ``"none"`` disables imputation for that variable.

    ```
    default: "none"
    variable_to_copy:
        - variable_missing_1
        - variable_missing_2
    ```
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Build the imputation index lists and validate them."""
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation.

        Fills five parallel lists, one entry per imputed variable: the
        variable's position in the training input, inference input, training
        output and inference output (``None`` where the variable is absent),
        and in ``self.replacement`` the *name* of the variable to copy from.
        """
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        (
            self.index_training_input,
            self.index_inference_input,
            self.index_training_output,
            self.index_inference_output,
            self.replacement,
        ) = ([], [], [], [], [])

        # Create indices for imputation
        for name in name_to_index_training_input:
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def fill_with_value(self, x, index):
        """Copy source-variable values into the masked entries of ``x``.

        Parameters
        ----------
        x : tensor to modify in place; its last axis must have either the
            training or the inference number of input variables
        index : destination positions along the last axis of ``x``, parallel
            to ``self.index_training_input``; ``None`` entries are skipped

        Returns
        -------
        The same ``x`` object that was passed in.

        Raises
        ------
        ValueError
            If the last axis of ``x`` matches neither input layout.
        AssertionError
            If the values to be copied contain NaNs (the check is an
            ``assert`` and is therefore skipped under ``python -O``).
        """
        # The source variable must be looked up in the SAME layout as ``x``:
        # a training-layout position is not valid in an inference-shaped tensor.
        if x.shape[-1] == self.num_training_input_vars:
            name_to_index = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            name_to_index = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values
        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
            if idx_dst is None:
                continue
            # Build the mask and gather the source values once per variable
            mask = self._expand_subset_mask(x, idx_src)
            source_values = x[..., name_to_index[value]][mask]
            assert not torch.isnan(source_values).any(), f"NaNs found in {value}."
            x[..., idx_dst][mask] = source_values
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        On the first call the NaN locations of ``x`` are cached in
        ``self.nan_locations`` and ``self.loss_mask_training`` is built from
        them; later calls reuse both.

        Parameters
        ----------
        x : input tensor whose last axis has the training or inference
            number of input variables
        in_place : when ``False`` the work is done on a clone of ``x``

        Returns
        -------
        The imputed tensor (``x`` itself when ``in_place`` is true).

        Raises
        ------
        ValueError
            If the last axis of ``x`` matches neither input layout.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        return self.fill_with_value(x, index)
342+
343+
234344class DynamicMixin :
235- """Mixin to add dynamic imputation behavior."""
345+ """
346+ Mixin to add dynamic imputation behavior.
347+ To be used when NaN maps change at different timesteps.
348+ """
236349
def get_nans(self, x: torch.Tensor) -> torch.Tensor:
    """Return an element-wise NaN mask computed from ``x`` itself.

    The mask is derived from the tensor passed in on every call, so it has
    the same shape as ``x``.
    """
    nan_mask = torch.isnan(x)
    return nan_mask
240353
def fill_with_value(self, x, index, nan_locations):
    """Overwrite the flagged entries of ``x`` with the configured values.

    ``index`` and ``self.replacement`` are walked in lockstep. For each
    position that is not ``None``, the elements of ``x[..., position]``
    flagged by ``nan_locations[..., position]`` are set in place to the
    matching replacement value.

    Parameters
    ----------
    x : tensor to modify in place
    index : positions along the last axis of ``x``; ``None`` entries are skipped
    nan_locations : boolean mask indexed with the same positions as ``x``

    Returns
    -------
    The same ``x`` object that was passed in.
    """
    for position, fill in zip(index, self.replacement):
        if position is None:
            continue
        target = x[..., position]
        flagged = nan_locations[..., position]
        # ``target`` comes from basic indexing, so this assignment writes into ``x``
        target[flagged] = fill
    return x
360+
241361 def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
242362 """Impute missing values in the input tensor."""
243363 if not in_place :
@@ -261,12 +381,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
261381 f"({ self .num_training_input_vars } ) or inference shape ({ self .num_inference_input_vars } )" ,
262382 )
263383
264- # Replace values
265- for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
266- if idx_dst is not None :
267- x [..., idx_dst ][nan_locations [..., idx_src ]] = value
268-
269- return x
384+ return self .fill_with_value (x , index , nan_locations )
270385
271386 def inverse_transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
272387 """Impute missing values in the input tensor."""
@@ -282,7 +397,7 @@ def __init__(
282397 data_indices : Optional [IndexCollection ] = None ,
283398 statistics : Optional [dict ] = None ,
284399 ) -> None :
285- super () .__init__ (config , data_indices , statistics )
400+ InputImputer .__init__ (self , config , data_indices , statistics )
286401 warnings .warn (
287402 "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
288403 The model will be trained to predict imputed values. This might deteriorate performances."
@@ -298,8 +413,51 @@ def __init__(
298413 data_indices : Optional [IndexCollection ] = None ,
299414 statistics : Optional [dict ] = None ,
300415 ) -> None :
301- super ().__init__ (config , data_indices , statistics )
416+ ConstantImputer .__init__ (self , config , data_indices , statistics )
417+ warnings .warn (
418+ "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
419+ The model will be trained to predict imputed values. This might deteriorate performances."
420+ )
421+
422+
class DynamicCopyImputer(DynamicMixin, CopyImputer):
    """Dynamic Copy imputation behavior.

    Combines the copy-from-another-variable configuration of ``CopyImputer``
    with the ``transform`` / ``inverse_transform`` of ``DynamicMixin``.
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Initialise through ``CopyImputer`` and warn about dynamic imputation."""
        CopyImputer.__init__(self, config, data_indices, statistics)
        # Adjacent literals instead of a backslash continuation inside the
        # string, so source indentation does not leak into the message.
        warnings.warn(
            "You are using a dynamic Imputer: NaN values will not be present in the model predictions. "
            "The model will be trained to predict imputed values. This might deteriorate performances."
        )

    def fill_with_value(self, x, index, nan_locations):
        """Copy source-variable values into the flagged entries of ``x``.

        Parameters
        ----------
        x : tensor to modify in place; its last axis must have either the
            training or the inference number of input variables
        index : destination positions along the last axis of ``x``, parallel
            to ``self.replacement``; ``None`` entries are skipped
        nan_locations : boolean mask indexed with the same positions as ``x``

        Returns
        -------
        The same ``x`` object that was passed in.

        Raises
        ------
        ValueError
            If the last axis of ``x`` matches neither input layout.
        AssertionError
            If the values to be copied contain NaNs (the check is an
            ``assert`` and is therefore skipped under ``python -O``).
        """
        # Source variables are looked up in the layout that matches ``x``
        if x.shape[-1] == self.num_training_input_vars:
            indices = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            indices = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values
        for idx, value in zip(index, self.replacement):
            if idx is None:
                continue
            # Slice the mask and gather the source values once per variable
            mask = nan_locations[..., idx]
            source_values = x[..., indices[value]][mask]
            assert not torch.isnan(source_values).any(), f"NaNs found in {value}."
            x[..., idx][mask] = source_values
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        Explicitly delegates to ``DynamicMixin.transform``.
        """
        return DynamicMixin.transform(self, x, in_place)

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Explicitly delegate to ``DynamicMixin.inverse_transform``."""
        return DynamicMixin.inverse_transform(self, x, in_place)
0 commit comments