|
17 | 17 | from ads.opctl.operator.lowcode.forecast.const import (
|
18 | 18 | AUTOMLX_METRIC_MAP,
|
19 | 19 | ForecastOutputColumns,
|
| 20 | + SpeedAccuracyMode, |
20 | 21 | SupportedModels,
|
21 | 22 | )
|
22 | 23 | from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
|
@@ -245,18 +246,18 @@ def _generate_report(self):
|
245 | 246 | # If the key is present, call the "explain_model" method
|
246 | 247 | self.explain_model()
|
247 | 248 |
|
248 |
| - # Convert the global explanation data to a DataFrame |
249 |
| - global_explanation_df = pd.DataFrame(self.global_explanation) |
| 249 | + global_explanation_section = None |
| 250 | + if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX: |
| 251 | + # Convert the global explanation data to a DataFrame |
| 252 | + global_explanation_df = pd.DataFrame(self.global_explanation) |
250 | 253 |
|
251 |
| - self.formatted_global_explanation = ( |
252 |
| - global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
253 |
| - ) |
254 |
| - self.formatted_global_explanation = ( |
255 |
| - self.formatted_global_explanation.rename( |
| 254 | + self.formatted_global_explanation = ( |
| 255 | + global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
| 256 | + ) |
| 257 | + self.formatted_global_explanation = self.formatted_global_explanation.rename( |
256 | 258 | {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
|
257 | 259 | axis=1,
|
258 | 260 | )
|
259 |
| - ) |
260 | 261 |
|
261 | 262 | aggregate_local_explanations = pd.DataFrame()
|
262 | 263 | for s_id, local_ex_df in self.local_explanation.items():
|
@@ -297,8 +298,11 @@ def _generate_report(self):
|
297 | 298 | )
|
298 | 299 |
|
299 | 300 | # Append the global explanation text and section to the "other_sections" list
|
| 301 | + if global_explanation_section: |
| 302 | + other_sections.append(global_explanation_section) |
| 303 | + |
| 304 | + # Append the local explanation text and section to the "other_sections" list |
300 | 305 | other_sections = other_sections + [
|
301 |
| - global_explanation_section, |
302 | 306 | local_explanation_section,
|
303 | 307 | ]
|
304 | 308 | except Exception as e:
|
@@ -379,3 +383,79 @@ def _custom_predict_automlx(self, data):
|
379 | 383 | return self.models.get(self.series_id).forecast(
|
380 | 384 | X=data_temp, periods=data_temp.shape[0]
|
381 | 385 | )[self.series_id]
|
| 386 | + |
| 387 | + @runtime_dependency( |
| 388 | + module="automlx", |
| 389 | + err_msg=( |
| 390 | + "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation." |
| 391 | + ), |
| 392 | + ) |
| 393 | + def explain_model(self): |
| 394 | + """ |
| 395 | + Generates explanations for the model using the AutoMLx library. |
| 396 | +
|
| 397 | + Parameters |
| 398 | + ---------- |
| 399 | + None |
| 400 | +
|
| 401 | + Returns |
| 402 | + ------- |
| 403 | + None |
| 404 | +
|
| 405 | + Notes |
| 406 | + ----- |
| 407 | + This function works by generating local explanations for each series in the dataset. |
| 408 | + It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions |
| 409 | + for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary. |
| 410 | +
|
| 411 | + If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations. |
| 412 | + Otherwise, it falls back to the default explanation generation method. |
| 413 | + """ |
| 414 | + import automlx |
| 415 | + |
| 416 | + # Loop through each series in the dataset |
| 417 | + for s_id, data_i in self.datasets.get_data_by_series( |
| 418 | + include_horizon=False |
| 419 | + ).items(): |
| 420 | + try: |
| 421 | + if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX: |
| 422 | + # Use the MLExplainer class from AutoMLx to generate explanations |
| 423 | + explainer = automlx.MLExplainer( |
| 424 | + self.models[s_id], |
| 425 | + self.datasets.additional_data.get_data_for_series(series_id=s_id) |
| 426 | + .drop(self.spec.datetime_column.name, axis=1) |
| 427 | + .head(-self.spec.horizon) |
| 428 | + if self.spec.additional_data |
| 429 | + else None, |
| 430 | + pd.DataFrame(data_i[self.spec.target_column]), |
| 431 | + task="forecasting", |
| 432 | + ) |
| 433 | + |
| 434 | + # Generate explanations for the forecast |
| 435 | + explanations = explainer.explain_prediction( |
| 436 | + X=self.datasets.additional_data.get_data_for_series(series_id=s_id) |
| 437 | + .drop(self.spec.datetime_column.name, axis=1) |
| 438 | + .tail(self.spec.horizon) |
| 439 | + if self.spec.additional_data |
| 440 | + else None, |
| 441 | + forecast_timepoints=list(range(self.spec.horizon + 1)), |
| 442 | + ) |
| 443 | + |
| 444 | + # Convert the explanations to a DataFrame |
| 445 | + explanations_df = pd.concat( |
| 446 | + [exp.to_dataframe() for exp in explanations] |
| 447 | + ) |
| 448 | + explanations_df["row"] = explanations_df.groupby("Feature").cumcount() |
| 449 | + explanations_df = explanations_df.pivot( |
| 450 | + index="row", columns="Feature", values="Attribution" |
| 451 | + ) |
| 452 | + explanations_df = explanations_df.reset_index(drop=True) |
| 453 | + |
| 454 | + # Store the explanations in the local_explanation dictionary |
| 455 | + self.local_explanation[s_id] = explanations_df |
| 456 | + else: |
| 457 | + # Fall back to the default explanation generation method |
| 458 | + super().explain_model() |
| 459 | + except Exception as e: |
| 460 | + logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.") |
| 461 | + logger.debug(f"Full Traceback: {traceback.format_exc()}") |
0 commit comments