diff --git a/pyneuroml/plot/PlotMorphologyVispy.py b/pyneuroml/plot/PlotMorphologyVispy.py index d026dc78..5f6fc70f 100644 --- a/pyneuroml/plot/PlotMorphologyVispy.py +++ b/pyneuroml/plot/PlotMorphologyVispy.py @@ -727,7 +727,7 @@ def plot_interactive_3D( except KeyError: pass - plot_3D_cell_morphology( + meshdata, segment_info = plot_3D_cell_morphology( offset=pos, cell=cell, color=color, @@ -753,7 +753,7 @@ def plot_interactive_3D( logger.info( f"More meshes than threshold ({len(meshdata.keys())}/{precision[1]}), reducing precision to {precision[0]} and re-calculating." ) - plot_interactive_3D( + meshdata, segment_info = plot_interactive_3D( nml_file=plottable_nml_model, min_width=min_width, verbose=verbose, @@ -775,7 +775,9 @@ def plot_interactive_3D( if not nogui: if pbar is not None: pbar.finish() - create_instanced_meshes(meshdata, plot_type, current_view, min_width) + create_instanced_meshes( + meshdata, plot_type, current_view, min_width, segment_info, current_canvas + ) if pynml_in_jupyter: display(current_canvas) else: @@ -964,10 +966,14 @@ def plot_3D_cell_morphology( if meshdata is None: meshdata = {} + segment_info = {} for seg in cell.morphology.segments: p = cell.get_actual_proximal(seg.id) d = seg.distal length = cell.get_segment_length(seg.id) + position = (p.x + offset[0], p.y + offset[1], p.z + offset[2]) + position = tuple(numpy.float32(x) for x in position) + segment_info[position] = [seg.id, cell] # round up to precision r1 = round(p.diameter / 2, mesh_precision) @@ -1037,32 +1043,55 @@ def plot_3D_cell_morphology( logger.debug(f"meshdata added: {key}: {(p, d, seg_color, offset)}") if not nogui: - create_instanced_meshes(meshdata, plot_type, current_view, min_width) + create_instanced_meshes( + meshdata, plot_type, current_view, min_width, segment_info, current_canvas + ) if pynml_in_jupyter: display(current_canvas) else: current_canvas.show() app.run() - return meshdata + return meshdata, segment_info -def create_instanced_meshes(meshdata, plot_type, current_view, min_width): - """Internal function to plot instanced meshes from mesh data. +def clicked_on_seg(position, segment_info): + """ + Associates instance_position to segment's proximal position and returns the segment's id and other information - It is more efficient to collect all the segments that require the same - cylindrical mesh and to create instanced meshes for them. + :param position: coordinates + :type position: tuple(numpy.float32, numpy.float32, numpy.float32) + :segment_info: dictionary with positions as keys and segment ids and cell objects and values + :type: {position: [seg_id, neuroml.Cell]} + """ + seg_id = segment_info[position][0] + cell = segment_info[position][1] + print(f"the segment id is {seg_id}") + print(cell.get_segment_location_info(seg_id)) - See: https://vispy.org/api/vispy.scene.visuals.html#vispy.scene.visuals.InstancedMesh - :param meshdata: meshdata to plot: dictionary with: - key: (r1, r2, length) - value: [(prox, dist, color, offset)] - :param plot_type: type of plot - :type plot_type: str - :param current_view: vispy viewbox to use - :type current_view: ViewBox - :param min_width: minimum width of tubes - :type min_width: float +def create_instanced_meshes( + meshdata, plot_type, current_view, min_width, segment_info=None, current_canvas=None +): + """Internal function to plot instanced meshes from mesh data. + create_insta + It is more efficient to collect all the segments that require the same + cylindrical mesh and to create instanced meshes for them. + + See: https://vispy.org/api/vispy.scene.visuals.html#vispy.scene.visuals.InstancedMesh + + :param meshdata: meshdata to plot: dictionary with: + key: (r1, r2, length) + value: [(prox, dist, color, offset)] + :param plot_type: type of plot + :type plot_type: str + :param current_view: vispy viewbox to use + :type current_view: ViewBox + :param min_width: minimum width of tubes + :type min_width: float + :segment_info: dictionary with positions as keys and segment ids and cell objects and values + :type: {position: [seg_id, neuroml.Cell]} + :param: current_canvas: vispy canvas to use + :type: SceneCanvas """ total_mesh_instances = 0 for d, i in meshdata.items(): @@ -1175,11 +1204,64 @@ def create_instanced_meshes(meshdata, plot_type, current_view, min_width): instance_colors=instance_colors, parent=current_view.scene, ) + mesh.interactive = True + # TODO: add a shading filter for light? assert mesh is not None + + @current_canvas.events.mouse_press.connect + def on_mouse_press(event): + clicked_mesh = current_canvas.visual_at(event.pos) + if isinstance(clicked_mesh, InstancedMesh): + pos1, min, min_pos = get_view_axis_in_scene_coordinates( + event.pos, clicked_mesh + ) + print(f"visual at : {clicked_mesh}") + print(f"event.pos : {event.pos}") + print(f"min distance : {min} and min_pos : {min_pos}") + + # TODO handle when there multiple segments with same proximal + if segment_info is not None: + clicked_on_seg(tuple(min_pos), segment_info) + pbar.finish() +def get_view_axis_in_scene_coordinates(pos, mesh): + """ + Gets the event position (of the click) that is in 2d coordinates and an InstancedMesh object, converts + the instanced_positions from visual to canvas coordinates and finds the instanced_position that is closest to the + clicked position. + Returns a list of the instance_positions projected on canvas coordinates, the minimum distance (float) and the closest instance_position(list) + + :param pos: the event position + :type pos: list [float, float] + :param mesh: InstancedMesh object that was clicked + :type mesh: InstancedMesh object + """ + event_pos = numpy.array([pos[0], pos[1], 0, 1]) # in homogeneous screen coordinates + instance_on_canvas = [] + # Translate each position to corresponding 2d canvas coordinates + for instance in mesh.instance_positions: + on_canvas = mesh.get_transform(map_from="visual", map_to="canvas").map(instance) + on_canvas /= on_canvas[3:] + instance_on_canvas.append(on_canvas) + + min = 10000 + min_pos = None + # Find the closest position to the clicked position + for i, instance_pos in enumerate(instance_on_canvas): + # Not minding z axis + temp_min = numpy.linalg.norm( + numpy.array(event_pos[:2]) - numpy.array(instance_pos[:2]) + ) + if temp_min < min: + min = temp_min + min_pos = i + + return instance_on_canvas, min, mesh.instance_positions[min_pos] + + def plot_3D_schematic( cell: Cell, segment_groups: typing.Optional[typing.List[SegmentGroup]],