|
736 | 736 | _beeswarm_docstring = """
|
737 | 737 | Beeswarm plot with `SHAP-style <https://shap.readthedocs.io/en/latest/generated/shap.plots.beeswarm.html#shap.plots.beeswarm>`_ feature value coloring.
|
738 | 738 |
|
739 |
| -%(plot.scatter)s |
740 |
| -color_values : array-like, optional |
741 |
| - Values to use for coloring points. Should match the shape of the data. |
742 |
| - Enables SHAP-style feature value coloring. |
743 |
| -color_by_feature : array-like, optional |
744 |
| - Alias for color_values. Values to color points by (e.g., feature values). |
745 |
| -n_iter: int, default: 50 |
746 |
| - Number of iterations for the beeswarm algorithm. More iterations can lead long time to plot but better point separation. |
| 739 | +Parameters |
| 740 | +---------- |
| 741 | +data: array-like |
| 742 | + The data to be plotted. It is assumed the shape of `data` is (N, M) where N is the number of points and M is the number of features. |
| 743 | +levels: array-like, optional |
| 744 | + The levels to use for the beeswarm plot. If not provided, the levels are automatically determined based on the data. |
| 745 | +n_bins: int or array-like, default: 50 |
| 746 | + Number of bins to use to reduce the overlap between points. |
| 747 | + Bins are used to determine how crowded the points are for each level of the `y` coordinate. |
| 748 | + s, size, ms, markersize : float or array-like or unit-spec, optional |
| 749 | + The marker size area(s). If this is an array matching the shape of `x` and `y`, |
| 750 | + the units are scaled by `smin` and `smax`. If this contains unit string(s), it |
| 751 | + is processed by `~ultraplot.utils.units` and represents the width rather than area. |
| 752 | + c, color, colors, mc, markercolor, markercolors, fc, facecolor, facecolors \ |
| 753 | + : array-like or color-spec, optional |
| 754 | + The marker color(s). If this is an array matching the shape of `x` and `y`, |
| 755 | + the colors are generated using `cmap`, `norm`, `vmin`, and `vmax`. Otherwise, |
| 756 | + this should be a valid matplotlib color. |
| 757 | + smin, smax : float, optional |
| 758 | + The minimum and maximum marker size area in units ``points ** 2``. Ignored |
| 759 | + if `absolute_size` is ``True``. Default value for `smin` is ``1`` and for |
| 760 | + `smax` is the square of :rc:`lines.markersize`. |
| 761 | + area_size : bool, default: True |
| 762 | + Whether the marker sizes `s` are scaled by area or by radius. The default |
| 763 | + ``True`` is consistent with matplotlib. When `absolute_size` is ``True``, |
| 764 | + the `s` units are ``points ** 2`` if `area_size` is ``True`` and ``points`` |
| 765 | + if `area_size` is ``False``. |
| 766 | + absolute_size : bool, default: True or False |
| 767 | + Whether `s` should be taken to represent "absolute" marker sizes in units |
| 768 | + ``points`` or ``points ** 2`` or "relative" marker sizes scaled by `smin` |
| 769 | + and `smax`. Default is ``True`` if `s` is scalar and ``False`` if `s` is |
| 770 | + array-like or `smin` or `smax` were passed. |
| 771 | + %(plot.vmin_vmax)s |
| 772 | + %(plot.args_1d_shared)s |
| 773 | +
|
| 774 | + Other parameters |
| 775 | + ---------------- |
| 776 | + %(plot.cmap_norm)s |
| 777 | + %(plot.levels_manual)s |
| 778 | + %(plot.levels_auto)s |
| 779 | + %(plot.cycle)s |
| 780 | + lw, linewidth, linewidths, mew, markeredgewidth, markeredgewidths \ |
| 781 | + : float or sequence, optional |
| 782 | + The marker edge width(s). |
| 783 | + edgecolors, markeredgecolor, markeredgecolors \ |
| 784 | + : color-spec or sequence, optional |
| 785 | + The marker edge color(s). |
| 786 | + %(plot.error_means_{y})s |
| 787 | + %(plot.error_bars)s |
| 788 | + %(plot.error_shading)s |
| 789 | + %(plot.inbounds)s |
| 790 | + %(plot.labels_1d)s |
| 791 | + %(plot.guide)s |
| 792 | + **kwargs |
| 793 | + Passed to `~matplotlib.axes.Axes.scatter`. |
| 794 | +
|
| 795 | + See also |
| 796 | + -------- |
| 797 | + PlotAxes.scatter |
| 798 | + PlotAxes.scatterx |
| 799 | + matplotlib.axes.Axes.scatter |
747 | 800 | """
|
748 |
| -docstring._snippet_manager["plot.beeswarm"] = _scatter_docstring.format(y="y") |
| 801 | +docstring._snippet_manager["plot.beeswarm"] = _beeswarm_docstring.format(y="y") |
749 | 802 |
|
750 | 803 | # Bar function docstring
|
751 | 804 | _bar_docstring = """
|
@@ -3440,185 +3493,108 @@ def _apply_lollipop(
|
3440 | 3493 | return patch_collection, line_collection
|
3441 | 3494 |
|
3442 | 3495 | @docstring._snippet_manager
|
3443 |
| - def beeswarm(self, *args, color_values=None, color_by_feature=None, **kwargs): |
| 3496 | + def beeswarm(self, *args, **kwargs): |
3444 | 3497 | """
|
3445 | 3498 | %(plot.beeswarm)s
|
3446 |
| -
|
3447 |
| - Parameters |
3448 |
| - ---------- |
3449 |
| - color_values : array-like, optional |
3450 |
| - Values to use for coloring points. Should match the shape of the data. |
3451 |
| - Enables SHAP-style feature value coloring. |
3452 |
| - color_by_feature : array-like, optional |
3453 |
| - Alias for color_values. Values to color points by (e.g., feature values). |
3454 |
| - """ |
3455 |
| - # Allow orientation to be overridden in kwargs |
3456 |
| - orientation = kwargs.pop("orientation", "horizontal") |
| 3499 | + """ |
3457 | 3500 | return self._apply_beeswarm(
|
3458 | 3501 | *args,
|
3459 |
| - orientation=orientation, |
3460 |
| - color_values=color_values, |
3461 |
| - color_by_feature=color_by_feature, |
3462 | 3502 | **kwargs,
|
3463 | 3503 | )
|
3464 | 3504 |
|
3465 |
| - @inputs._preprocess_or_redirect("x", "y", allow_extra=True) |
3466 | 3505 | def _apply_beeswarm(
|
3467 | 3506 | self,
|
3468 |
| - x: np.ndarray, |
3469 |
| - y: np.ndarray, |
| 3507 | + data: np.ndarray, |
| 3508 | + levels: np.ndarray = None, |
| 3509 | + feature_values: np.ndarray = None, |
| 3510 | + ss: float | np.ndarray = None, |
3470 | 3511 | orientation: str = "horizontal",
|
3471 |
| - n_iter: int = 50, |
3472 |
| - *args, |
| 3512 | + n_bins: int = 50, |
3473 | 3513 | **kwargs,
|
3474 | 3514 | ) -> "Collection":
|
3475 | 3515 |
|
3476 |
| - cmap = kwargs.pop("cmap", rc["cmap.diverging"]) |
3477 |
| - size = kwargs.pop("s", kwargs.pop("size", 20)) # Default marker size |
| 3516 | + # Parse input parameters |
| 3517 | + ss, _ = self._parse_markersize(ss, **kwargs) |
3478 | 3518 | colorbar = kwargs.pop("colorbar", False)
|
3479 | 3519 | colorbar_kw = kwargs.pop("colorbar_kw", {})
|
3480 | 3520 |
|
3481 |
| - # Feature value coloring support (SHAP-style) |
3482 |
| - color_values = kwargs.pop("color_values", None) |
3483 |
| - color_by_feature = kwargs.pop("color_by_feature", None) |
3484 |
| - if color_by_feature is not None: |
3485 |
| - color_values = color_by_feature # Alias for SHAP-style naming |
3486 |
| - |
| 3521 | + flatten = False |
| 3522 | + if data.ndim == 1: |
| 3523 | + flatten = True |
| 3524 | + data = np.atleast_2d(data) |
| 3525 | + n_points, n_features = data.shape[:2] |
3487 | 3526 | # Convert to numpy arrays
|
3488 |
| - x = np.asarray(x) |
3489 |
| - y = np.asarray(y) |
3490 |
| - |
3491 |
| - # Handle 2D y array (multiple series) |
3492 |
| - if y.ndim == 2: |
3493 |
| - # x should be 1D with length matching y.shape[1] |
3494 |
| - if x.shape[-1] != y.shape[1]: |
3495 |
| - raise ValueError( |
3496 |
| - "For 2D y array, x must be 1D with length matching y.shape[1]" |
3497 |
| - ) |
3498 |
| - |
3499 |
| - # Flatten y and repeat x for each series |
3500 |
| - n_series, n_points = y.shape |
3501 |
| - x_flat = x.flatten() |
3502 |
| - if x.ndim == 1: |
3503 |
| - x_flat = np.tile(x, n_series) |
3504 |
| - |
3505 |
| - y_flat = y.flatten() |
3506 |
| - |
3507 |
| - # Handle color values for 2D case |
3508 |
| - if color_values is not None: |
3509 |
| - color_values = np.asarray(color_values) |
3510 |
| - if color_values.shape == y.shape: |
3511 |
| - # Color values match y shape - flatten them |
3512 |
| - color_flat = color_values.flatten() |
3513 |
| - elif len(color_values) == len(y_flat): |
3514 |
| - # Already flattened |
3515 |
| - color_flat = color_values |
3516 |
| - else: |
3517 |
| - raise ValueError( |
3518 |
| - "color_values must match the shape of y or be flattened to match flattened y" |
3519 |
| - ) |
3520 |
| - else: |
3521 |
| - # Create series labels for coloring when no color_values provided |
3522 |
| - color_flat = np.repeat(np.arange(n_series), n_points) |
3523 |
| - else: |
3524 |
| - # Standard 1D case |
3525 |
| - x_flat = x.flatten() |
3526 |
| - y_flat = y.flatten() |
3527 |
| - |
3528 |
| - # Handle color values for 1D case |
3529 |
| - if color_values is not None: |
3530 |
| - color_values = np.asarray(color_values).flatten() |
3531 |
| - if len(color_values) != len(y_flat): |
3532 |
| - raise ValueError("color_values must have the same length as y") |
3533 |
| - color_flat = color_values |
3534 |
| - else: |
3535 |
| - color_flat = np.zeros(len(x_flat)) |
3536 |
| - |
3537 |
| - if len(x_flat) != len(y_flat): |
3538 |
| - raise ValueError("x and y must have compatible dimensions") |
3539 |
| - |
3540 |
| - # Group data by unique x values (categories) |
3541 |
| - unique_x = np.unique(x_flat) |
3542 |
| - swarm_x = np.zeros_like(x_flat, dtype=float) |
3543 |
| - swarm_y = np.zeros_like(y_flat, dtype=float) |
3544 |
| - |
3545 |
| - # Calculate point radius from marker size (approximate) |
3546 |
| - # Marker size is in points^2, so radius is sqrt(size/pi) |
3547 |
| - point_radius = np.sqrt(size / np.pi) * 0.01 # Scale factor for data units |
3548 |
| - |
3549 |
| - for i, cat_x in enumerate(unique_x): |
3550 |
| - # Get indices for this category |
3551 |
| - mask = x_flat == cat_x |
3552 |
| - cat_y = y_flat[mask] |
3553 |
| - n_points = len(cat_y) |
3554 |
| - |
3555 |
| - if n_points == 0: |
3556 |
| - continue |
3557 |
| - |
3558 |
| - # Sort by y-value to process from bottom to top |
3559 |
| - sorted_indices = np.argsort(cat_y) |
3560 |
| - sorted_y = cat_y[sorted_indices] |
3561 |
| - |
3562 |
| - # Initialize positions |
3563 |
| - positions = [] |
3564 |
| - |
3565 |
| - for j, y_val in enumerate(sorted_y): |
3566 |
| - # Try to place point at category center first |
3567 |
| - best_x = cat_x |
3568 |
| - best_collision = True |
3569 |
| - |
3570 |
| - # Check for collisions with existing points |
3571 |
| - for attempt in range( |
3572 |
| - n_iter |
3573 |
| - ): # Max attempts to find non-overlapping position |
3574 |
| - collision = False |
3575 |
| - |
3576 |
| - for pos_x, pos_y in positions: |
3577 |
| - # Calculate distance between points |
3578 |
| - dx = best_x - pos_x |
3579 |
| - dy = y_val - pos_y |
3580 |
| - distance = np.sqrt(dx**2 + dy**2) |
3581 |
| - |
3582 |
| - # Check if points would overlap |
3583 |
| - if distance < 2 * point_radius: |
3584 |
| - collision = True |
3585 |
| - break |
3586 |
| - |
3587 |
| - if not collision: |
3588 |
| - best_collision = False |
3589 |
| - break |
3590 |
| - |
3591 |
| - # If collision, try moving horizontally |
3592 |
| - # Alternate between left and right, increasing distance |
3593 |
| - side = 1 if attempt % 2 == 0 else -1 |
3594 |
| - offset = (attempt // 2 + 1) * point_radius * 0.5 |
3595 |
| - best_x = cat_x + side * offset |
3596 |
| - |
3597 |
| - # Store the final position |
3598 |
| - positions.append((best_x, y_val)) |
3599 |
| - |
3600 |
| - # Map back to original indices |
3601 |
| - original_idx = np.where(mask)[0][sorted_indices[j]] |
3602 |
| - swarm_x[original_idx] = best_x |
3603 |
| - swarm_y[original_idx] = y_val |
3604 |
| - |
3605 |
| - # Handle orientation |
3606 |
| - if orientation == "horizontal": |
3607 |
| - plot_x, plot_y = swarm_y, swarm_x |
3608 |
| - else: # vertical |
3609 |
| - plot_x, plot_y = swarm_x, swarm_y |
3610 |
| - |
3611 |
| - # Create the scatter plot with appropriate coloring |
3612 |
| - guide_kw = _pop_params(kwargs, self._update_guide) |
3613 |
| - if "c" not in kwargs and "color" not in kwargs: |
3614 |
| - # Use our computed color values (either feature values or series labels) |
3615 |
| - objs = self.scatter( |
3616 |
| - plot_x, plot_y, s=size, c=color_flat, cmap=cmap, **kwargs |
| 3527 | + if levels is None: |
| 3528 | + levels = np.arange(n_features) |
| 3529 | + |
| 3530 | + if data.ndim > 1 and levels.ndim == 1: |
| 3531 | + levels = np.ones(data.shape) * levels[None] |
| 3532 | + |
| 3533 | + # Bin data to distribute the beeswarm |
| 3534 | + extend_range = max(levels[:, -1]) + max(abs(levels[:, -1] - levels[:, -2])) |
| 3535 | + level_widths = abs(np.diff(levels, axis=1, append=extend_range)) |
| 3536 | + |
| 3537 | + for level, d in enumerate(data.T): |
| 3538 | + # Construct a histogram to estimate |
| 3539 | + # the number of points present at a similar |
| 3540 | + # x (for horizontal beeswarm) or y value (for |
| 3541 | + # vertical beeswarm) |
| 3542 | + counts, edges = np.histogram(d, bins=n_bins) |
| 3543 | + upper_limit = levels[:, level] + level_widths[:, level] |
| 3544 | + lower_limit = levels[:, level] - level_widths[:, level] |
| 3545 | + |
| 3546 | + # Adjust the values for each bin |
| 3547 | + binned = np.clip( |
| 3548 | + np.digitize(d, edges) - 1, |
| 3549 | + 0, |
| 3550 | + len(counts) - 1, |
3617 | 3551 | )
|
3618 |
| - else: |
3619 |
| - # User provided explicit colors |
3620 |
| - objs = self.scatter(plot_x, plot_y, s=size, **kwargs) |
3621 | 3552 |
|
| 3553 | + z = counts.sum() |
| 3554 | + for bin, count in enumerate(counts): |
| 3555 | + # Skip bins without multiple points |
| 3556 | + if count == 0: |
| 3557 | + continue |
| 3558 | + # Collect the group data and extract the |
| 3559 | + # lower and upper bounds |
| 3560 | + idx = np.where(binned == bin)[0].astype(int) |
| 3561 | + lower = min(lower_limit[idx]) |
| 3562 | + upper = max(upper_limit[idx]) |
| 3563 | + # Distribute the points evenly but reduce |
| 3564 | + # the range based on the number of points |
| 3565 | + # in this bin compared to the total number of |
| 3566 | + # points |
| 3567 | + limit = ( |
| 3568 | + (count / z) * (upper - lower) * 0.5 * 0.9 |
| 3569 | + ) # give a slight space between the layers |
| 3570 | + offset = np.linspace(-limit, limit, num=count, endpoint=True) |
| 3571 | + levels[idx, level] += offset |
| 3572 | + |
| 3573 | + # Pop before plotting to avoid issues with guide_kw |
| 3574 | + guide_kw = _pop_params(kwargs, self._update_guide) |
| 3575 | + if feature_values is not None: |
| 3576 | + kwargs = self._parse_cmap(feature_values, **kwargs) |
| 3577 | + kwargs["c"] = feature_values.flat |
| 3578 | + # Use flat to get around the issue of generating |
| 3579 | + # multiple colorbars when feature_values are used |
| 3580 | + flatten = True |
| 3581 | + |
| 3582 | + # Swap the data if we are in vert mode |
| 3583 | + if orientation == "vertical": |
| 3584 | + data, levels = levels, data |
| 3585 | + |
| 3586 | + # Put size back in kwargs |
| 3587 | + if ss is not None: |
| 3588 | + kwargs["s"] = ss |
| 3589 | + |
| 3590 | + if flatten: |
| 3591 | + data, levels = data.flatten(), levels.flatten() |
| 3592 | + |
| 3593 | + objs = self.scatter( |
| 3594 | + data, |
| 3595 | + levels, |
| 3596 | + **kwargs, |
| 3597 | + ) |
3622 | 3598 | self._update_guide(objs, queue_colorbar=False, **guide_kw)
|
3623 | 3599 | if colorbar:
|
3624 | 3600 | self.colorbar(objs, loc=colorbar, **colorbar_kw)
|
|
0 commit comments