Description
Motivation: describe the problem to be solved
Median absolute error and maximum absolute error are useful metrics that capture different kinds of changes when monitoring a model. Additionally, the median is well known to be more robust than the mean, which makes it a worthwhile addition.
Describe the solution you'd like
New regression metrics, called MedianAbsoluteError ("medae") and MaxAbsoluteError ("maxae"), that are usable in the same way as the existing metrics.
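For reference, the realized values of both metrics are straightforward to compute, and scikit-learn already ships them as median_absolute_error and max_error. A minimal sketch (the toy arrays are illustrative only):

import numpy as np
from sklearn.metrics import max_error, median_absolute_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

# Median of the absolute errors -- robust to outliers
medae = median_absolute_error(y_true, y_pred)  # == np.median(np.abs(y_true - y_pred))

# Largest absolute error -- sensitive to the single worst prediction
maxae = max_error(y_true, y_pred)  # == np.max(np.abs(y_true - y_pred))

Usage within the library would presumably mirror the existing metrics, e.g. passing metrics=["mae", "medae", "maxae"] to the estimator.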
Describe alternatives you've considered
I tried implementing a custom metric myself, but it is tricky and the resulting code is not very clean. This is probably also something other people could make use of.
Additional context
I asked Copilot to create the metric (which I intended to use directly, but that clearly was not possible), so here is what it suggested, in case it is useful. It could also be hallucinating, so if you prefer not to view or use it, that is understandable.
AI suggestion
# Imports are not part of the original suggestion; the paths below assume
# NannyML's internal module layout and may need adjusting.
from typing import Any

import numpy as np
import pandas as pd

from nannyml._typing import ProblemType
from nannyml.base import common_nan_removal
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidReferenceDataException
from nannyml.performance_estimation.direct_loss_estimation.metrics import (
    Metric,
    MetricFactory,
)
from nannyml.thresholds import Threshold


@MetricFactory.register("medae", ProblemType.REGRESSION)
class MedianError(Metric):
    """Estimate regression performance using the Median Absolute Error metric.

    This metric calculates the median of the absolute differences between true
    and predicted values. It is more robust to outliers than Mean Absolute Error (MAE).
    """

    def __init__(
        self,
        feature_column_names: list[str],
        y_true: str,
        y_pred: str,
        chunker: Chunker,
        threshold: Threshold,
        tune_hyperparameters: bool,
        hyperparameter_tuning_config: dict[str, Any],
        hyperparameters: dict[str, Any],
    ):
        """Creates a new Median Absolute Error (MedAE) Metric instance.

        Parameters
        ----------
        feature_column_names: list[str]
            A list of column names indicating which columns contain feature values.
        y_true: str
            The name of the column containing target values (that are provided in reference data during fitting).
        y_pred: str
            The name of the column containing your model predictions.
        chunker: Chunker
            The `Chunker` used to split the data sets into lists of chunks.
        threshold: Threshold
            The `Threshold` instance that determines how the lower and upper threshold values will be calculated.
        tune_hyperparameters: bool
            A boolean controlling whether hyperparameter tuning should be performed on the internal regressor
            models whilst fitting on reference data. Tuning hyperparameters takes some time and does not
            guarantee better results, hence it defaults to `False`.
        hyperparameter_tuning_config: dict[str, Any]
            A dictionary that allows you to provide a custom hyperparameter tuning configuration when
            `tune_hyperparameters` has been set to `True`.
        hyperparameters: dict[str, Any]
            A dictionary used to provide your own custom hyperparameters when `tune_hyperparameters` has
            been set to `True`.
        """
        super().__init__(
            display_name="MedAE",
            column_name="medae",
            feature_column_names=feature_column_names,
            y_true=y_true,
            y_pred=y_pred,
            chunker=chunker,
            threshold=threshold,
            tune_hyperparameters=tune_hyperparameters,
            hyperparameter_tuning_config=hyperparameter_tuning_config,
            hyperparameters=hyperparameters,
        )
        # Store reference data statistics for sampling error calculation
        self._reference_abs_errors = None

    def _fit(self, reference_data: pd.DataFrame):
        # Filter out rows with missing targets or predictions
        reference_data, empty = common_nan_removal(reference_data, [self.y_true, self.y_pred])
        if empty:
            raise InvalidReferenceDataException(
                f"Cannot fit DLE for {self.display_name}, too many missing values for predictions and targets."
            )
        y_true = reference_data[self.y_true]
        y_pred = reference_data[self.y_pred]

        # Calculate absolute errors for observation-level metrics
        abs_errors = abs(y_true - y_pred)
        self._reference_abs_errors = abs_errors

        # Store basic statistics of reference data for sampling error calculation
        self._reference_median = np.median(abs_errors)
        self._reference_std = np.std(abs_errors)
        self._reference_n = len(abs_errors)

        # Train the model to predict absolute errors
        self._dee_model = self._train_direct_error_estimation_model(
            X_train=reference_data[self.feature_column_names + [self.y_pred]],
            y_train=abs_errors,
            tune_hyperparameters=self.tune_hyperparameters,
            hyperparameter_tuning_config=self.hyperparameter_tuning_config,
            hyperparameters=self.hyperparameters,
            categorical_column_names=self.categorical_column_names,
        )

    def _estimate(self, data: pd.DataFrame):
        # Predict absolute errors at the observation level
        observation_level_estimates = self._dee_model.predict(
            X=data[self.feature_column_names + [self.y_pred]]
        )
        # Clip negative predictions to 0, since absolute errors cannot be negative
        observation_level_estimates = np.maximum(0, observation_level_estimates)
        # The chunk-level estimate is the median of the predicted absolute errors
        return np.median(observation_level_estimates)

    def _sampling_error(self, data: pd.DataFrame) -> float:
        # We only expect predictions to be present and estimate sampling error based on them
        data, empty = common_nan_removal(data[[self.y_pred]], [self.y_pred])
        if empty:
            return np.nan
        # Approximate the standard error of the median as
        # sqrt(pi / 2) * std / sqrt(n) ~= 1.253 * std / sqrt(n),
        # which holds for approximately normally distributed errors.
        n = len(data)
        return 1.253 * self._reference_std / np.sqrt(min(n, self._reference_n))

    def realized_performance(self, data: pd.DataFrame) -> float:
        """Calculates the realized median absolute error of a model for a given chunk of data.

        The data needs to contain both predictions and real targets.

        Parameters
        ----------
        data: pd.DataFrame
            The data to calculate the realized performance on.

        Returns
        -------
        medae: float
            Median Absolute Error
        """
        if self.y_true not in data.columns:
            return np.nan
        data, empty = common_nan_removal(data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred])
        if empty:
            return np.nan
        y_true = data[self.y_true]
        y_pred = data[self.y_pred]
        # The realized metric is the median of the absolute errors
        return np.median(np.abs(y_true - y_pred))
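Since the request covers both metrics, the "maxae" counterpart could follow the same pattern as the class above. Below is a minimal, untested sketch that reuses the same (assumed) internal APIs; a _sampling_error implementation is omitted here and would still be needed, since the maximum has no simple standard-error formula:

@MetricFactory.register("maxae", ProblemType.REGRESSION)
class MaxError(Metric):
    """Estimate regression performance using the Max Absolute Error metric."""

    def __init__(self, **kwargs):
        # Accepts the same constructor arguments as MedianError above and
        # forwards them, changing only the display and column names.
        super().__init__(display_name="MaxAE", column_name="maxae", **kwargs)

    def _fit(self, reference_data: pd.DataFrame):
        reference_data, empty = common_nan_removal(reference_data, [self.y_true, self.y_pred])
        if empty:
            raise InvalidReferenceDataException(
                f"Cannot fit DLE for {self.display_name}, too many missing values for predictions and targets."
            )
        abs_errors = abs(reference_data[self.y_true] - reference_data[self.y_pred])
        # Train the same kind of direct error estimation model on absolute errors
        self._dee_model = self._train_direct_error_estimation_model(
            X_train=reference_data[self.feature_column_names + [self.y_pred]],
            y_train=abs_errors,
            tune_hyperparameters=self.tune_hyperparameters,
            hyperparameter_tuning_config=self.hyperparameter_tuning_config,
            hyperparameters=self.hyperparameters,
            categorical_column_names=self.categorical_column_names,
        )

    def _estimate(self, data: pd.DataFrame):
        estimates = self._dee_model.predict(X=data[self.feature_column_names + [self.y_pred]])
        # The chunk-level estimate is the largest predicted absolute error
        return np.max(np.maximum(0, estimates))

    def realized_performance(self, data: pd.DataFrame) -> float:
        if self.y_true not in data.columns:
            return np.nan
        data, empty = common_nan_removal(data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred])
        if empty:
            return np.nan
        # The realized metric is the largest absolute error in the chunk
        return np.max(np.abs(data[self.y_true] - data[self.y_pred]))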
Additionally, I forked the repo and implemented both metrics (using Copilot as well), which seem to work fine. But as I am very unfamiliar with the repo and the ideas behind some of its functions, I cannot verify that they work as intended.