Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions napari_clusters_plotter/_Qt_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from matplotlib.path import Path
from matplotlib.widgets import LassoSelector, RectangleSelector, SpanSelector
from napari.layers import Image, Layer
from qtpy.QtCore import QRect
from qtpy.QtGui import QIcon
from qtpy.QtCore import QRect, Qt
from qtpy.QtGui import QGuiApplication, QIcon
from qtpy.QtWidgets import (
QAbstractItemView,
QHBoxLayout,
Expand Down Expand Up @@ -390,24 +390,54 @@ def create_options_dropdown(name: str, value, options: dict, label: str):


class SelectFrom2DHistogram:
def __init__(self, parent, ax, full_data):
def __init__(self, parent, ax, full_data, histogram):
self.parent = parent
self.ax = ax
self.canvas = ax.figure.canvas
self.xys = full_data
self.cluster_id_histo_overlay = None
self.histogram = histogram

self.lasso = LassoSelector(ax, onselect=self.onselect)
self.ind = []
self.ind_mask = []

def vert_to_coord(self, vert):
"""
Converts verticis to histogram coordinates in pixels
"""

# I tried to solve it with self.ax.transData.transform... but it did not work...
xrange = self.histogram[1][-1] - self.histogram[1][0]
yrange = self.histogram[2][-1] - self.histogram[2][0]
v = (
(vert[0] - self.histogram[1][0]) / (xrange) * self.histogram[0].shape[0],
(vert[1] - self.histogram[2][0]) / (yrange) * self.histogram[0].shape[1],
)

coord = tuple([int(c) for c in v])
return coord

def onselect(self, verts):
path = Path(verts)
if self.parent.manual_clustering_method is None:
return

self.ind_mask = path.contains_points(self.xys)
self.ind = np.nonzero(self.ind_mask)[0]
modifiers = QGuiApplication.keyboardModifiers()

if self.parent.manual_clustering_method is not None:
self.parent.manual_clustering_method(self.ind_mask)
if (
modifiers == Qt.ControlModifier and len(verts) == 2
): # has len of 2 when single click was done
coord_click = self.vert_to_coord(verts[0])
cluster_id_to_delete = self.cluster_id_histo_overlay[coord_click[::-1]][0]
if cluster_id_to_delete > 0:
self.parent.manual_clustering_method(
np.zeros(shape=self.xys.shape), delete_cluster=cluster_id_to_delete
)
return

path = Path(verts)
self.ind_mask = path.contains_points(self.xys)
self.parent.manual_clustering_method(self.ind_mask)

def disconnect(self):
self.lasso.disconnect_events()
Expand Down Expand Up @@ -521,6 +551,7 @@ class MplCanvas(FigureCanvas):
def __init__(self, parent=None, width=7, height=4, manual_clustering_method=None):
self.fig = Figure(figsize=(width, height), constrained_layout=True)
self.manual_clustering_method = manual_clustering_method
self.parent = parent

self.axes = self.fig.add_subplot(111)
self.histogram = None
Expand Down Expand Up @@ -551,6 +582,9 @@ def __init__(self, parent=None, width=7, height=4, manual_clustering_method=None

self.reset()

def set_selector_cluster_id_overlay(self, overlay: np.array):
self.selector.cluster_id_histo_overlay = overlay

def reset_zoom(self):
if self.xylim:
self.axes.set_xlim(self.xylim[0])
Expand Down Expand Up @@ -617,7 +651,9 @@ def make_2d_histogram(

full_data = pd.concat([pd.DataFrame(data_x), pd.DataFrame(data_y)], axis=1)
self.selector.disconnect()
self.selector = SelectFrom2DHistogram(self, self.axes, full_data)
self.selector = SelectFrom2DHistogram(
self, self.axes, full_data, self.histogram
)
self.axes.figure.canvas.draw_idle()

def make_1d_histogram(
Expand Down
17 changes: 14 additions & 3 deletions napari_clusters_plotter/_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def __init__(self, napari_viewer):

self.analysed_layer = None
self.visualized_layer = None
self.cluster_id_histo_overlay = None

def manual_clustering_method(inside):
def manual_clustering_method(inside, **kwargs):
inside = np.array(inside) # leads to errors sometimes otherwise

if self.analysed_layer is None or len(inside) == 0:
Expand All @@ -101,7 +102,14 @@ def manual_clustering_method(inside):
features = get_layer_tabular_data(self.analysed_layer)

modifiers = QGuiApplication.keyboardModifiers()
if modifiers == Qt.ShiftModifier and clustering_ID in features.keys():
if "delete_cluster" in kwargs:
features[clustering_ID].mask(
features[clustering_ID] == kwargs["delete_cluster"],
other=-1,
inplace=True,
)

elif modifiers == Qt.ShiftModifier and clustering_ID in features.keys():
features[clustering_ID].mask(
inside, other=features[clustering_ID].max() + 1, inplace=True
)
Expand Down Expand Up @@ -677,7 +685,7 @@ def run(
log_scale=self.log_scale.isChecked(),
)

rgb_img = make_cluster_overlay_img(
rgb_img, self.cluster_id_histo_overlay = make_cluster_overlay_img(
cluster_id=plot_cluster_name,
features=features,
feature_x=self.plot_x_axis_name,
Expand All @@ -686,6 +694,9 @@ def run(
histogram_data=self.graphics_widget.histogram,
hide_first_cluster=self.plot_hide_non_selected.isChecked(),
)
self.graphics_widget.set_selector_cluster_id_overlay(
self.cluster_id_histo_overlay
)
xedges = self.graphics_widget.histogram[1]
yedges = self.graphics_widget.histogram[2]

Expand Down
6 changes: 5 additions & 1 deletion napari_clusters_plotter/_plotter_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def make_cluster_overlay_img(
]

cluster_overlay_rgba = np.zeros((*h.shape, 4), dtype=float)
cluster_overlay_cluster_id = np.zeros((*h.shape, 1), dtype=int)
output_max = np.zeros(h.shape, dtype=float)

for cluster, entries in relevant_entries.groupby(cluster_id):
Expand All @@ -565,5 +566,8 @@ def make_cluster_overlay_img(
]
rgba.append(0.9)
cluster_overlay_rgba[mask] = rgba
cluster_overlay_cluster_id[mask] = cluster

return cluster_overlay_rgba.swapaxes(0, 1)
return cluster_overlay_rgba.swapaxes(0, 1), cluster_overlay_cluster_id.swapaxes(
0, 1
)