|
28 | 28 | OneInputUnionType,
|
29 | 29 | OneOutputUnionType,
|
30 | 30 | ManyInputsUnionType,
|
| 31 | + ManyOutputsUnionType, |
31 | 32 | many_inputs_convert,
|
32 | 33 | data_conversion_docstring,
|
| 34 | + many_outputs_convert, |
33 | 35 | )
|
34 | 36 |
|
35 | 37 | from org.kie.trustyai.explainability.local.shap import (
|
36 | 38 | ShapConfig as _ShapConfig,
|
37 | 39 | ShapKernelExplainer as _ShapKernelExplainer,
|
38 | 40 | )
|
| 41 | + |
| 42 | +from org.kie.trustyai.explainability.local.shap.background import ( |
| 43 | + RandomGenerator, |
| 44 | + KMeansGenerator, |
| 45 | + CounterfactualGenerator, |
| 46 | +) |
| 47 | + |
39 | 48 | from org.kie.trustyai.explainability.model import (
|
40 | 49 | PredictionProvider,
|
41 | 50 | Saliency,
|
@@ -410,6 +419,151 @@ def _get_bokeh_plot_dict(self):
|
410 | 419 | }
|
411 | 420 |
|
412 | 421 |
|
| 422 | +class BackgroundGenerator: |
| 423 | + r"""Generate a background for the SHAP explainer via one of three algorithms: |
| 424 | +
|
| 425 | + * `sample`: Randomly sample a set of provided points |
| 426 | + * `kmeans`: Summarize a set of provided points into k centroids |
| 427 | + * `counterfactual`: Generate a set of background points that meet certain criteria |
| 428 | +
|
| 429 | + """ |
| 430 | + |
| 431 | + @data_conversion_docstring("many_inputs") |
| 432 | + def __init__(self, datapoints: ManyInputsUnionType, feature_domains=None, seed=0): |
| 433 | + r"""Initialize the :class:`BackgroundGenerator`. |
| 434 | +
|
| 435 | + Parameters |
| 436 | + ---------- |
| 437 | + datapoints : {} |
| 438 | + The set of datapoints to be used to sample/generate the background, as a: {} |
| 439 | + seed : int |
| 440 | + The random seed to use in the sampling/generation method |
| 441 | + """ |
| 442 | + self.datapoints = many_inputs_convert(datapoints, feature_domains) |
| 443 | + self.feature_domains = feature_domains |
| 444 | + self.seed = 0 |
| 445 | + self._jrandom = Random() |
| 446 | + self._jrandom.setSeed(self.seed) |
| 447 | + |
| 448 | + def sample(self, k=100): |
| 449 | + r"""Randomly sample datapoints. |
| 450 | +
|
| 451 | + Parameters |
| 452 | + ---------- |
| 453 | + k : int |
| 454 | + The number of datapoints to select |
| 455 | +
|
| 456 | + Returns |
| 457 | + ------- |
| 458 | + :list:`PredictionInput` |
| 459 | + The background dataset to pass to the :class:`~SHAPExplainer` |
| 460 | + """ |
| 461 | + perturbation_context = PerturbationContext(self._jrandom, 0) |
| 462 | + return RandomGenerator(self.datapoints, perturbation_context).generate(k) |
| 463 | + |
| 464 | + def kmeans(self, k=100): |
| 465 | + r"""Use k-means clustering over `datapoints` and return k centroids as the background data |
| 466 | + set. |
| 467 | +
|
| 468 | + Parameters |
| 469 | + ---------- |
| 470 | + k : int |
| 471 | + The number of centroids to find |
| 472 | +
|
| 473 | + Returns |
| 474 | + ------- |
| 475 | + :list:`PredictionInput` |
| 476 | + The background dataset to pass to the :class:`~SHAPExplainer` |
| 477 | + """ |
| 478 | + return KMeansGenerator(self.datapoints, self.seed).generate(k) |
| 479 | + |
| 480 | + @data_conversion_docstring("many_outputs") |
| 481 | + def counterfactual( |
| 482 | + self, |
| 483 | + goals: ManyOutputsUnionType, |
| 484 | + model: PredictionProvider, |
| 485 | + k_per_goal=100, |
| 486 | + **kwargs, |
| 487 | + ): |
| 488 | + r"""Generate a background via the CounterfactualExplainer. This lets you specify |
| 489 | + exact output values that the background dataset conforms to, and thus set the reference |
| 490 | + point by which all SHAP values compare. For example, if your model is a regression |
| 491 | + model, choosing a counterfactual goal of 0 will create a background dataset where |
| 492 | + :math:'f(x) \approx 0 \forall x \in \text{{background}}`, and as such the SHAP values |
| 493 | + will compare against zero, which is a useful baseline for regression. |
| 494 | +
|
| 495 | + Parameters |
| 496 | + ---------- |
| 497 | + goals : {} |
| 498 | + The set of background datapoints as a: {} |
| 499 | + model : :obj:`~trustyai.model.PredictionProvider` |
| 500 | + The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` |
| 501 | + k_per_goal : int |
| 502 | + The number of background datapoints to generate per goal. |
| 503 | + Keyword Arguments: |
| 504 | + * k_seeds: int |
| 505 | + (default=5) For each goal, a number of starting seeds from `datapoints` are used |
| 506 | + to start the search from. These are the `k_seeds` points within `datapoint` |
| 507 | + whose corresponding outputs are closet to the goal output. Choose a larger |
| 508 | + number to get a more diverse background dataset, but the search might require |
| 509 | + larger `max_attempt_count`, `step_count`, and `timeout_seconds` to get good results. |
| 510 | + * goal_threshold: float |
| 511 | + (default=.01) The distance (percentage) threshold defining whether |
| 512 | + a particular output satisfies the goal. Set to 0 to require an exact match, but |
| 513 | + this will likey require larger `max_attempt_count`, `step_count`, |
| 514 | + and `timeout_seconds` to get good results. |
| 515 | + * chain: boolean |
| 516 | + (default=False) If chaining is set to `true`, found counterfactual datapoints |
| 517 | + will be added to the search seeds for subsequent searches. This is useful when a |
| 518 | + range of counterfactual outputs is desired; for example, if the desired goals are |
| 519 | + [0, 1, 2, 3], whichever goal is closest to the closest point within `datapoints` will |
| 520 | + be searched for first. The found counterfactuals from that search are then included |
| 521 | + in the search for the second-closest goal, and so on. This is especially helpful |
| 522 | + if the extremes of the goal range are far outside the range produced by the |
| 523 | + `datapoints`. If only |
| 524 | + * max_attempt_count: int |
| 525 | + If no valid counterfactual can be found for a starting seed in the search, the point |
| 526 | + is slightly perturbed and search is retried. This parameter sets the maximum |
| 527 | + number of perturbation-retry cycles are allowed during generation. |
| 528 | + * step_count: int |
| 529 | + (default=10,000) The number of datapoints to evaluate during the search |
| 530 | + * timeout_seconds: int |
| 531 | + (default=30) The maximum number of seconds allowed for each counterfactual search |
| 532 | +
|
| 533 | + Returns |
| 534 | + ------- |
| 535 | + :list:`PredictionInput` |
| 536 | + The background dataset to pass to the :class:`~SHAPExplainer` |
| 537 | + """ |
| 538 | + if self.feature_domains is None: |
| 539 | + raise AttributeError( |
| 540 | + "Feature domains must be passed to perform" |
| 541 | + " meaningful counterfactual search" |
| 542 | + ) |
| 543 | + goals_converted = many_outputs_convert(goals) |
| 544 | + generator = ( |
| 545 | + CounterfactualGenerator.builder() |
| 546 | + .withModel(model) |
| 547 | + .withKSeeds(kwargs.get("k_seeds", 5)) |
| 548 | + .withRandom(self._jrandom) |
| 549 | + .withTimeoutSeconds(kwargs.get("timeout_seconds", 3)) |
| 550 | + .withStepCount(kwargs.get("step_count", 5_000)) |
| 551 | + .withGoalThreshold(kwargs.get("goal_threshold", 0.01)) |
| 552 | + .withMaxAttemptCount(kwargs.get("max_attempt_count", 5)) |
| 553 | + .build() |
| 554 | + ) |
| 555 | + |
| 556 | + if len(goals) == 1: |
| 557 | + background = generator.generate( |
| 558 | + self.datapoints, goals_converted[0], k_per_goal |
| 559 | + ) |
| 560 | + else: |
| 561 | + background = generator.generateRange( |
| 562 | + self.datapoints, goals_converted, k_per_goal, kwargs.get("chain", False) |
| 563 | + ) |
| 564 | + return background |
| 565 | + |
| 566 | + |
413 | 567 | class SHAPExplainer:
|
414 | 568 | r"""*"By how much did each feature contribute to the outputs?"*
|
415 | 569 |
|
@@ -511,6 +665,8 @@ def explain(
|
511 | 665 | outputs : {}
|
512 | 666 | The corresponding model outputs for the provided features, that is,
|
513 | 667 | ``outputs = model(input_features)``. These can take the form of a: {}
|
| 668 | + model : :obj:`~trustyai.model.PredictionProvider` |
| 669 | + The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` |
514 | 670 |
|
515 | 671 | Returns
|
516 | 672 | -------
|
|
0 commit comments