Skip to content

Commit 44df33a

Browse files
Refactor beeswarm (#254)
* Update doc string * vectorize function * vectorize function * update signatures * change order of keywords according to use * Clean-up internals * adjust test * mv logic away from plotting * change keyword in tests * add small space between the layers * do the space in fractions * simplify test --------- Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com>
1 parent b68dd16 commit 44df33a

File tree

2 files changed

+168
-205
lines changed

2 files changed

+168
-205
lines changed

ultraplot/axes/plot.py

Lines changed: 145 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -736,16 +736,69 @@
736736
_beeswarm_docstring = """
737737
Beeswarm plot with `SHAP-style <https://shap.readthedocs.io/en/latest/generated/shap.plots.beeswarm.html#shap.plots.beeswarm>`_ feature value coloring.
738738
739-
%(plot.scatter)s
740-
color_values : array-like, optional
741-
Values to use for coloring points. Should match the shape of the data.
742-
Enables SHAP-style feature value coloring.
743-
color_by_feature : array-like, optional
744-
Alias for color_values. Values to color points by (e.g., feature values).
745-
n_iter: int, default: 50
746-
Number of iterations for the beeswarm algorithm. More iterations can lead long time to plot but better point separation.
739+
Parameters
740+
----------
741+
data: array-like
742+
The data to be plotted. It is assumed the shape of `data` is (N, M) where N is the number of points and M is the number of features.
743+
levels: array-like, optional
744+
The levels to use for the beeswarm plot. If not provided, the levels are automatically determined based on the data.
745+
n_bins: int or array-like, default: 50
746+
Number of bins to use to reduce the overlap between points.
747+
Bins are used to determine how crowded the points are for each level of the `y` coordinate.
748+
s, size, ms, markersize : float or array-like or unit-spec, optional
749+
The marker size area(s). If this is an array matching the shape of `x` and `y`,
750+
the units are scaled by `smin` and `smax`. If this contains unit string(s), it
751+
is processed by `~ultraplot.utils.units` and represents the width rather than area.
752+
c, color, colors, mc, markercolor, markercolors, fc, facecolor, facecolors \
753+
: array-like or color-spec, optional
754+
The marker color(s). If this is an array matching the shape of `x` and `y`,
755+
the colors are generated using `cmap`, `norm`, `vmin`, and `vmax`. Otherwise,
756+
this should be a valid matplotlib color.
757+
smin, smax : float, optional
758+
The minimum and maximum marker size area in units ``points ** 2``. Ignored
759+
if `absolute_size` is ``True``. Default value for `smin` is ``1`` and for
760+
`smax` is the square of :rc:`lines.markersize`.
761+
area_size : bool, default: True
762+
Whether the marker sizes `s` are scaled by area or by radius. The default
763+
``True`` is consistent with matplotlib. When `absolute_size` is ``True``,
764+
the `s` units are ``points ** 2`` if `area_size` is ``True`` and ``points``
765+
if `area_size` is ``False``.
766+
absolute_size : bool, default: True or False
767+
Whether `s` should be taken to represent "absolute" marker sizes in units
768+
``points`` or ``points ** 2`` or "relative" marker sizes scaled by `smin`
769+
and `smax`. Default is ``True`` if `s` is scalar and ``False`` if `s` is
770+
array-like or `smin` or `smax` were passed.
771+
%(plot.vmin_vmax)s
772+
%(plot.args_1d_shared)s
773+
774+
Other parameters
775+
----------------
776+
%(plot.cmap_norm)s
777+
%(plot.levels_manual)s
778+
%(plot.levels_auto)s
779+
%(plot.cycle)s
780+
lw, linewidth, linewidths, mew, markeredgewidth, markeredgewidths \
781+
: float or sequence, optional
782+
The marker edge width(s).
783+
edgecolors, markeredgecolor, markeredgecolors \
784+
: color-spec or sequence, optional
785+
The marker edge color(s).
786+
%(plot.error_means_{y})s
787+
%(plot.error_bars)s
788+
%(plot.error_shading)s
789+
%(plot.inbounds)s
790+
%(plot.labels_1d)s
791+
%(plot.guide)s
792+
**kwargs
793+
Passed to `~matplotlib.axes.Axes.scatter`.
794+
795+
See also
796+
--------
797+
PlotAxes.scatter
798+
PlotAxes.scatterx
799+
matplotlib.axes.Axes.scatter
747800
"""
748-
docstring._snippet_manager["plot.beeswarm"] = _scatter_docstring.format(y="y")
801+
docstring._snippet_manager["plot.beeswarm"] = _beeswarm_docstring.format(y="y")
749802

750803
# Bar function docstring
751804
_bar_docstring = """
@@ -3440,185 +3493,108 @@ def _apply_lollipop(
34403493
return patch_collection, line_collection
34413494

34423495
@docstring._snippet_manager
3443-
def beeswarm(self, *args, color_values=None, color_by_feature=None, **kwargs):
3496+
def beeswarm(self, *args, **kwargs):
34443497
"""
34453498
%(plot.beeswarm)s
3446-
3447-
Parameters
3448-
----------
3449-
color_values : array-like, optional
3450-
Values to use for coloring points. Should match the shape of the data.
3451-
Enables SHAP-style feature value coloring.
3452-
color_by_feature : array-like, optional
3453-
Alias for color_values. Values to color points by (e.g., feature values).
3454-
"""
3455-
# Allow orientation to be overridden in kwargs
3456-
orientation = kwargs.pop("orientation", "horizontal")
3499+
"""
34573500
return self._apply_beeswarm(
34583501
*args,
3459-
orientation=orientation,
3460-
color_values=color_values,
3461-
color_by_feature=color_by_feature,
34623502
**kwargs,
34633503
)
34643504

3465-
@inputs._preprocess_or_redirect("x", "y", allow_extra=True)
34663505
def _apply_beeswarm(
34673506
self,
3468-
x: np.ndarray,
3469-
y: np.ndarray,
3507+
data: np.ndarray,
3508+
levels: np.ndarray = None,
3509+
feature_values: np.ndarray = None,
3510+
ss: float | np.ndarray = None,
34703511
orientation: str = "horizontal",
3471-
n_iter: int = 50,
3472-
*args,
3512+
n_bins: int = 50,
34733513
**kwargs,
34743514
) -> "Collection":
34753515

3476-
cmap = kwargs.pop("cmap", rc["cmap.diverging"])
3477-
size = kwargs.pop("s", kwargs.pop("size", 20)) # Default marker size
3516+
# Parse input parameters
3517+
ss, _ = self._parse_markersize(ss, **kwargs)
34783518
colorbar = kwargs.pop("colorbar", False)
34793519
colorbar_kw = kwargs.pop("colorbar_kw", {})
34803520

3481-
# Feature value coloring support (SHAP-style)
3482-
color_values = kwargs.pop("color_values", None)
3483-
color_by_feature = kwargs.pop("color_by_feature", None)
3484-
if color_by_feature is not None:
3485-
color_values = color_by_feature # Alias for SHAP-style naming
3486-
3521+
flatten = False
3522+
if data.ndim == 1:
3523+
flatten = True
3524+
data = np.atleast_2d(data)
3525+
n_points, n_features = data.shape[:2]
34873526
# Convert to numpy arrays
3488-
x = np.asarray(x)
3489-
y = np.asarray(y)
3490-
3491-
# Handle 2D y array (multiple series)
3492-
if y.ndim == 2:
3493-
# x should be 1D with length matching y.shape[1]
3494-
if x.shape[-1] != y.shape[1]:
3495-
raise ValueError(
3496-
"For 2D y array, x must be 1D with length matching y.shape[1]"
3497-
)
3498-
3499-
# Flatten y and repeat x for each series
3500-
n_series, n_points = y.shape
3501-
x_flat = x.flatten()
3502-
if x.ndim == 1:
3503-
x_flat = np.tile(x, n_series)
3504-
3505-
y_flat = y.flatten()
3506-
3507-
# Handle color values for 2D case
3508-
if color_values is not None:
3509-
color_values = np.asarray(color_values)
3510-
if color_values.shape == y.shape:
3511-
# Color values match y shape - flatten them
3512-
color_flat = color_values.flatten()
3513-
elif len(color_values) == len(y_flat):
3514-
# Already flattened
3515-
color_flat = color_values
3516-
else:
3517-
raise ValueError(
3518-
"color_values must match the shape of y or be flattened to match flattened y"
3519-
)
3520-
else:
3521-
# Create series labels for coloring when no color_values provided
3522-
color_flat = np.repeat(np.arange(n_series), n_points)
3523-
else:
3524-
# Standard 1D case
3525-
x_flat = x.flatten()
3526-
y_flat = y.flatten()
3527-
3528-
# Handle color values for 1D case
3529-
if color_values is not None:
3530-
color_values = np.asarray(color_values).flatten()
3531-
if len(color_values) != len(y_flat):
3532-
raise ValueError("color_values must have the same length as y")
3533-
color_flat = color_values
3534-
else:
3535-
color_flat = np.zeros(len(x_flat))
3536-
3537-
if len(x_flat) != len(y_flat):
3538-
raise ValueError("x and y must have compatible dimensions")
3539-
3540-
# Group data by unique x values (categories)
3541-
unique_x = np.unique(x_flat)
3542-
swarm_x = np.zeros_like(x_flat, dtype=float)
3543-
swarm_y = np.zeros_like(y_flat, dtype=float)
3544-
3545-
# Calculate point radius from marker size (approximate)
3546-
# Marker size is in points^2, so radius is sqrt(size/pi)
3547-
point_radius = np.sqrt(size / np.pi) * 0.01 # Scale factor for data units
3548-
3549-
for i, cat_x in enumerate(unique_x):
3550-
# Get indices for this category
3551-
mask = x_flat == cat_x
3552-
cat_y = y_flat[mask]
3553-
n_points = len(cat_y)
3554-
3555-
if n_points == 0:
3556-
continue
3557-
3558-
# Sort by y-value to process from bottom to top
3559-
sorted_indices = np.argsort(cat_y)
3560-
sorted_y = cat_y[sorted_indices]
3561-
3562-
# Initialize positions
3563-
positions = []
3564-
3565-
for j, y_val in enumerate(sorted_y):
3566-
# Try to place point at category center first
3567-
best_x = cat_x
3568-
best_collision = True
3569-
3570-
# Check for collisions with existing points
3571-
for attempt in range(
3572-
n_iter
3573-
): # Max attempts to find non-overlapping position
3574-
collision = False
3575-
3576-
for pos_x, pos_y in positions:
3577-
# Calculate distance between points
3578-
dx = best_x - pos_x
3579-
dy = y_val - pos_y
3580-
distance = np.sqrt(dx**2 + dy**2)
3581-
3582-
# Check if points would overlap
3583-
if distance < 2 * point_radius:
3584-
collision = True
3585-
break
3586-
3587-
if not collision:
3588-
best_collision = False
3589-
break
3590-
3591-
# If collision, try moving horizontally
3592-
# Alternate between left and right, increasing distance
3593-
side = 1 if attempt % 2 == 0 else -1
3594-
offset = (attempt // 2 + 1) * point_radius * 0.5
3595-
best_x = cat_x + side * offset
3596-
3597-
# Store the final position
3598-
positions.append((best_x, y_val))
3599-
3600-
# Map back to original indices
3601-
original_idx = np.where(mask)[0][sorted_indices[j]]
3602-
swarm_x[original_idx] = best_x
3603-
swarm_y[original_idx] = y_val
3604-
3605-
# Handle orientation
3606-
if orientation == "horizontal":
3607-
plot_x, plot_y = swarm_y, swarm_x
3608-
else: # vertical
3609-
plot_x, plot_y = swarm_x, swarm_y
3610-
3611-
# Create the scatter plot with appropriate coloring
3612-
guide_kw = _pop_params(kwargs, self._update_guide)
3613-
if "c" not in kwargs and "color" not in kwargs:
3614-
# Use our computed color values (either feature values or series labels)
3615-
objs = self.scatter(
3616-
plot_x, plot_y, s=size, c=color_flat, cmap=cmap, **kwargs
3527+
if levels is None:
3528+
levels = np.arange(n_features)
3529+
3530+
if data.ndim > 1 and levels.ndim == 1:
3531+
levels = np.ones(data.shape) * levels[None]
3532+
3533+
# Bin data to distribute the beeswarm
3534+
extend_range = max(levels[:, -1]) + max(abs(levels[:, -1] - levels[:, -2]))
3535+
level_widths = abs(np.diff(levels, axis=1, append=extend_range))
3536+
3537+
for level, d in enumerate(data.T):
3538+
# Construct a histogram to estimate
3539+
# the number of points present at a similar
3540+
# x (for horizontal beeswarm) or y value (for
3541+
# vertical beeswarm)
3542+
counts, edges = np.histogram(d, bins=n_bins)
3543+
upper_limit = levels[:, level] + level_widths[:, level]
3544+
lower_limit = levels[:, level] - level_widths[:, level]
3545+
3546+
# Adjust the values for each bin
3547+
binned = np.clip(
3548+
np.digitize(d, edges) - 1,
3549+
0,
3550+
len(counts) - 1,
36173551
)
3618-
else:
3619-
# User provided explicit colors
3620-
objs = self.scatter(plot_x, plot_y, s=size, **kwargs)
36213552

3553+
z = counts.sum()
3554+
for bin, count in enumerate(counts):
3555+
# Skip bins without multiple points
3556+
if count == 0:
3557+
continue
3558+
# Collect the group data and extract the
3559+
# lower and upper bounds
3560+
idx = np.where(binned == bin)[0].astype(int)
3561+
lower = min(lower_limit[idx])
3562+
upper = max(upper_limit[idx])
3563+
# Distribute the points evenly but reduce
3564+
# the range based on the number of points
3565+
# in this bin compared to the total number of
3566+
# points
3567+
limit = (
3568+
(count / z) * (upper - lower) * 0.5 * 0.9
3569+
) # give a slight space between the layers
3570+
offset = np.linspace(-limit, limit, num=count, endpoint=True)
3571+
levels[idx, level] += offset
3572+
3573+
# Pop before plotting to avoid issues with guide_kw
3574+
guide_kw = _pop_params(kwargs, self._update_guide)
3575+
if feature_values is not None:
3576+
kwargs = self._parse_cmap(feature_values, **kwargs)
3577+
kwargs["c"] = feature_values.flat
3578+
# Use flat to get around the issue of generating
3579+
# multiple colorbars when feature_values are used
3580+
flatten = True
3581+
3582+
# Swap the data if we are in vert mode
3583+
if orientation == "vertical":
3584+
data, levels = levels, data
3585+
3586+
# Put size back in kwargs
3587+
if ss is not None:
3588+
kwargs["s"] = ss
3589+
3590+
if flatten:
3591+
data, levels = data.flatten(), levels.flatten()
3592+
3593+
objs = self.scatter(
3594+
data,
3595+
levels,
3596+
**kwargs,
3597+
)
36223598
self._update_guide(objs, queue_colorbar=False, **guide_kw)
36233599
if colorbar:
36243600
self.colorbar(objs, loc=colorbar, **colorbar_kw)

0 commit comments

Comments
 (0)