Skip to content

Commit b48aaf0

Browse files
committed
fix code to calculate site spacing & figure reformatting
1 parent 48e0744 commit b48aaf0

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

element_array_ephys/plotting/unit_level.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .. import probe
2+
from modulefinder import Module
13
import numpy as np
24
import pandas as pd
35
import plotly.graph_objs as go
@@ -65,19 +67,17 @@ def plot_correlogram(
6567
template="simple_white",
6668
width=350,
6769
height=350,
68-
yaxis_range=[0, None]
70+
yaxis_range=[0, None],
6971
)
7072
return fig
7173

7274

7375
def plot_depth_waveforms(
76+
ephys: Module,
7477
unit_key: dict,
7578
y_range: float = 60,
7679
) -> go.Figure:
7780

78-
from .. import probe
79-
from .. import ephys_no_curation as ephys
80-
8181
sampling_rate = (ephys.EphysRecording & unit_key).fetch1(
8282
"sampling_rate"
8383
) / 1e3 # in kHz
@@ -122,13 +122,13 @@ def plot_depth_waveforms(
122122
x_min, x_max = np.min(coords[:, 0]), np.max(coords[:, 0])
123123
y_min, y_max = np.min(coords[:, 1]), np.max(coords[:, 1])
124124

125-
# Spacing between channels (in um)
126-
x_inc = np.abs(np.diff(coords[:, 0])).min()
127-
y_inc = (np.abs(np.diff(coords[:, 1]))).max()
125+
# Spacing between recording sites (in um)
126+
x_inc = np.abs(np.diff(coords[coords[:, 1] == coords[0, 1]][:, 0])).mean() / 2
127+
y_inc = np.abs(np.diff(coords[coords[:, 0] == coords[0, 0]][:, 1])).mean() / 2
128128

129129
time = np.arange(waveforms.shape[1]) / sampling_rate
130130

131-
x_scale_factor = x_inc / (time[-1] + 1 / sampling_rate)
131+
x_scale_factor = x_inc / (time[-1] + 1 / sampling_rate) # correspond to 1 ms
132132
time_scaled = time * x_scale_factor
133133

134134
wf_amps = waveforms.max(axis=1) - waveforms.min(axis=1)
@@ -152,7 +152,7 @@ def plot_depth_waveforms(
152152
x=time_scaled + coord[0],
153153
y=wf_scaled + coord[1],
154154
mode="lines",
155-
line=dict(color=color, width=1),
155+
line=dict(color=color, width=1.5),
156156
hovertemplate=f"electrode {electrode}<br>"
157157
+ f"x ={coord[0]: .0f} μm<br>"
158158
+ f"y ={coord[1]: .0f} μm<extra></extra>",
@@ -164,7 +164,7 @@ def plot_depth_waveforms(
164164
yaxis_title="Distance from the probe tip (μm)",
165165
template="simple_white",
166166
width=400,
167-
height=700,
167+
height=600,
168168
xaxis_range=[x_min - x_inc / 2, x_max + x_inc * 1.2],
169169
yaxis_range=[y_min - y_inc * 2, y_max + y_inc * 2],
170170
)
@@ -173,12 +173,12 @@ def plot_depth_waveforms(
173173
fig.update_xaxes(tickvals=xtick_loc, ticktext=xtick_label)
174174

175175
# Add a scale bar
176-
x0 = xtick_loc[0] / 6
177-
y0 = y_min - y_inc * 1.5
176+
x0 = xtick_loc[0] - (x_scale_factor * 1.5)
177+
y0 = y_min - (y_inc * 1.5)
178178

179179
fig.add_trace(
180180
go.Scatter(
181-
x=[x0, xtick_loc[0] + x_scale_factor],
181+
x=[x0, x0 + x_scale_factor],
182182
y=[y0, y0],
183183
mode="lines",
184184
line=dict(color="black", width=2),

0 commit comments

Comments
 (0)