diff --git a/napari_clusters_plotter/_Qt_code.py b/napari_clusters_plotter/_Qt_code.py index 585dfd10..a54cc210 100644 --- a/napari_clusters_plotter/_Qt_code.py +++ b/napari_clusters_plotter/_Qt_code.py @@ -12,8 +12,8 @@ from matplotlib.path import Path from matplotlib.widgets import LassoSelector, RectangleSelector, SpanSelector from napari.layers import Image, Layer -from qtpy.QtCore import QRect -from qtpy.QtGui import QIcon +from qtpy.QtCore import QRect, Qt +from qtpy.QtGui import QGuiApplication, QIcon from qtpy.QtWidgets import ( QAbstractItemView, QHBoxLayout, @@ -390,24 +390,54 @@ def create_options_dropdown(name: str, value, options: dict, label: str): class SelectFrom2DHistogram: - def __init__(self, parent, ax, full_data): + def __init__(self, parent, ax, full_data, histogram): self.parent = parent self.ax = ax self.canvas = ax.figure.canvas self.xys = full_data + self.cluster_id_histo_overlay = None + self.histogram = histogram self.lasso = LassoSelector(ax, onselect=self.onselect) self.ind = [] self.ind_mask = [] + def vert_to_coord(self, vert): + """ + Converts verticis to histogram coordinates in pixels + """ + + # I tried to solve it with self.ax.transData.transform... but it did not work... + xrange = self.histogram[1][-1] - self.histogram[1][0] + yrange = self.histogram[2][-1] - self.histogram[2][0] + v = ( + (vert[0] - self.histogram[1][0]) / (xrange) * self.histogram[0].shape[0], + (vert[1] - self.histogram[2][0]) / (yrange) * self.histogram[0].shape[1], + ) + + coord = tuple([int(c) for c in v]) + return coord + def onselect(self, verts): - path = Path(verts) + if self.parent.manual_clustering_method is None: + return - self.ind_mask = path.contains_points(self.xys) - self.ind = np.nonzero(self.ind_mask)[0] + modifiers = QGuiApplication.keyboardModifiers() - if self.parent.manual_clustering_method is not None: - self.parent.manual_clustering_method(self.ind_mask) + if ( + modifiers == Qt.ControlModifier and len(verts) == 2 + ): # has len of 2 when single click was done + coord_click = self.vert_to_coord(verts[0]) + cluster_id_to_delete = self.cluster_id_histo_overlay[coord_click[::-1]][0] + if cluster_id_to_delete > 0: + self.parent.manual_clustering_method( + np.zeros(shape=self.xys.shape), delete_cluster=cluster_id_to_delete + ) + return + + path = Path(verts) + self.ind_mask = path.contains_points(self.xys) + self.parent.manual_clustering_method(self.ind_mask) def disconnect(self): self.lasso.disconnect_events() @@ -521,6 +551,7 @@ class MplCanvas(FigureCanvas): def __init__(self, parent=None, width=7, height=4, manual_clustering_method=None): self.fig = Figure(figsize=(width, height), constrained_layout=True) self.manual_clustering_method = manual_clustering_method + self.parent = parent self.axes = self.fig.add_subplot(111) self.histogram = None @@ -552,6 +583,9 @@ def __init__(self, parent=None, width=7, height=4, manual_clustering_method=None self.reset() + def set_selector_cluster_id_overlay(self, overlay: np.array): + self.selector.cluster_id_histo_overlay = overlay + def reset_zoom(self): if self.xylim: self.axes.set_xlim(self.xylim[0]) @@ -619,7 +653,9 @@ def make_2d_histogram( self.axes.set_ylim(yedges[0], yedges[-1]) self.histogram = (h, xedges, yedges) self.selector.disconnect() - self.selector = SelectFrom2DHistogram(self, self.axes, self.full_data) + self.selector = SelectFrom2DHistogram( + self, self.axes, self.full_data, self.histogram + ) self.axes.figure.canvas.draw_idle() def make_1d_histogram( diff --git a/napari_clusters_plotter/_plotter.py b/napari_clusters_plotter/_plotter.py index 386407ee..78592b96 100644 --- a/napari_clusters_plotter/_plotter.py +++ b/napari_clusters_plotter/_plotter.py @@ -90,8 +90,9 @@ def __init__(self, napari_viewer): self.analysed_layer = None self.visualized_layer = None + self.cluster_id_histo_overlay = None - def manual_clustering_method(inside): + def manual_clustering_method(inside, **kwargs): inside = np.array(inside) # leads to errors sometimes otherwise if self.analysed_layer is None or len(inside) == 0: @@ -101,7 +102,14 @@ def manual_clustering_method(inside): features = get_layer_tabular_data(self.analysed_layer) modifiers = QGuiApplication.keyboardModifiers() - if modifiers == Qt.ShiftModifier and clustering_ID in features.keys(): + if "delete_cluster" in kwargs: + features[clustering_ID].mask( + features[clustering_ID] == kwargs["delete_cluster"], + other=-1, + inplace=True, + ) + + elif modifiers == Qt.ShiftModifier and clustering_ID in features.keys(): features[clustering_ID].mask( inside, other=features[clustering_ID].max() + 1, inplace=True ) @@ -677,7 +685,7 @@ def run( log_scale=self.log_scale.isChecked(), ) - rgb_img = make_cluster_overlay_img( + rgb_img, self.cluster_id_histo_overlay = make_cluster_overlay_img( cluster_id=plot_cluster_name, features=features, feature_x=self.plot_x_axis_name, @@ -686,6 +694,9 @@ def run( histogram_data=self.graphics_widget.histogram, hide_first_cluster=self.plot_hide_non_selected.isChecked(), ) + self.graphics_widget.set_selector_cluster_id_overlay( + self.cluster_id_histo_overlay + ) xedges = self.graphics_widget.histogram[1] yedges = self.graphics_widget.histogram[2] diff --git a/napari_clusters_plotter/_plotter_utilities.py b/napari_clusters_plotter/_plotter_utilities.py index b77ee5fb..b99bde64 100644 --- a/napari_clusters_plotter/_plotter_utilities.py +++ b/napari_clusters_plotter/_plotter_utilities.py @@ -549,6 +549,7 @@ def make_cluster_overlay_img( ] cluster_overlay_rgba = np.zeros((*h.shape, 4), dtype=float) + cluster_overlay_cluster_id = np.zeros((*h.shape, 1), dtype=int) output_max = np.zeros(h.shape, dtype=float) for cluster, entries in relevant_entries.groupby(cluster_id): @@ -565,5 +566,8 @@ def make_cluster_overlay_img( ] rgba.append(0.9) cluster_overlay_rgba[mask] = rgba + cluster_overlay_cluster_id[mask] = cluster - return cluster_overlay_rgba.swapaxes(0, 1) + return cluster_overlay_rgba.swapaxes(0, 1), cluster_overlay_cluster_id.swapaxes( + 0, 1 + )