@@ -247,17 +247,19 @@ def _generate_report(self):
247
247
self .explain_model ()
248
248
249
249
global_explanation_section = None
250
- if self .spec .explanations_accuracy_mode != SpeedAccuracyMode .AUTOMLX :
251
- # Convert the global explanation data to a DataFrame
252
- global_explanation_df = pd .DataFrame (self .global_explanation )
253
250
254
- self .formatted_global_explanation = (
255
- global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
256
- )
257
- self .formatted_global_explanation = self .formatted_global_explanation .rename (
251
+ # Convert the global explanation data to a DataFrame
252
+ global_explanation_df = pd .DataFrame (self .global_explanation )
253
+
254
+ self .formatted_global_explanation = (
255
+ global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
256
+ )
257
+ self .formatted_global_explanation = (
258
+ self .formatted_global_explanation .rename (
258
259
{self .spec .datetime_column .name : ForecastOutputColumns .DATE },
259
260
axis = 1 ,
260
261
)
262
+ )
261
263
262
264
aggregate_local_explanations = pd .DataFrame ()
263
265
for s_id , local_ex_df in self .local_explanation .items ():
@@ -269,11 +271,15 @@ def _generate_report(self):
269
271
self .formatted_local_explanation = aggregate_local_explanations
270
272
271
273
if not self .target_cat_col :
272
- self .formatted_global_explanation = self .formatted_global_explanation .rename (
273
- {"Series 1" : self .original_target_column },
274
- axis = 1 ,
274
+ self .formatted_global_explanation = (
275
+ self .formatted_global_explanation .rename (
276
+ {"Series 1" : self .original_target_column },
277
+ axis = 1 ,
278
+ )
279
+ )
280
+ self .formatted_local_explanation .drop (
281
+ "Series" , axis = 1 , inplace = True
275
282
)
276
- self .formatted_local_explanation .drop ("Series" , axis = 1 , inplace = True )
277
283
278
284
# Create a markdown section for the global explainability
279
285
global_explanation_section = rc .Block (
@@ -422,7 +428,9 @@ def explain_model(self):
422
428
# Use the MLExplainer class from AutoMLx to generate explanations
423
429
explainer = automlx .MLExplainer (
424
430
self .models [s_id ],
425
- self .datasets .additional_data .get_data_for_series (series_id = s_id )
431
+ self .datasets .additional_data .get_data_for_series (
432
+ series_id = s_id
433
+ )
426
434
.drop (self .spec .datetime_column .name , axis = 1 )
427
435
.head (- self .spec .horizon )
428
436
if self .spec .additional_data
@@ -433,7 +441,9 @@ def explain_model(self):
433
441
434
442
# Generate explanations for the forecast
435
443
explanations = explainer .explain_prediction (
436
- X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
444
+ X = self .datasets .additional_data .get_data_for_series (
445
+ series_id = s_id
446
+ )
437
447
.drop (self .spec .datetime_column .name , axis = 1 )
438
448
.tail (self .spec .horizon )
439
449
if self .spec .additional_data
@@ -445,7 +455,9 @@ def explain_model(self):
445
455
explanations_df = pd .concat (
446
456
[exp .to_dataframe () for exp in explanations ]
447
457
)
448
- explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
458
+ explanations_df ["row" ] = explanations_df .groupby (
459
+ "Feature"
460
+ ).cumcount ()
449
461
explanations_df = explanations_df .pivot (
450
462
index = "row" , columns = "Feature" , values = "Attribution"
451
463
)
@@ -454,14 +466,17 @@ def explain_model(self):
454
466
# Store the explanations in the local_explanation dictionary
455
467
self .local_explanation [s_id ] = explanations_df
456
468
457
- self .global_explanation [s_id ] = dict (zip (
458
- data_i .columns [1 :],
459
- np .average (np .absolute (explanations_df [:, 1 :]), axis = 0 ),
469
+ self .global_explanation [s_id ] = dict (
470
+ zip (
471
+ self .local_explanation [s_id ].columns ,
472
+ np .nanmean ((self .local_explanation [s_id ]), axis = 0 ),
460
473
)
461
474
)
462
475
else :
463
476
# Fall back to the default explanation generation method
464
477
super ().explain_model ()
465
478
except Exception as e :
466
- logger .warning (f"Failed to generate explanations for series { s_id } with error: { e } ." )
479
+ logger .warning (
480
+ f"Failed to generate explanations for series { s_id } with error: { e } ."
481
+ )
467
482
logger .debug (f"Full Traceback: { traceback .format_exc ()} " )
0 commit comments