Skip to content

Commit ac77fea

Browse files
authored
Merge pull request matplotlib#29435 from scottshambaugh/more_wireframe_speedups
Fix `plot_wireframe` with nonequal `rstride`, `cstride`, plus additional speedups
2 parents 306c8de + a1f8aa0 commit ac77fea

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,11 +2445,20 @@ def plot_wireframe(self, X, Y, Z, *, axlim_clip=False, **kwargs):
24452445

24462446
row_lines = np.stack([X[rii], Y[rii], Z[rii]], axis=-1)
24472447
col_lines = np.stack([tX[cii], tY[cii], tZ[cii]], axis=-1)
2448-
lines = np.concatenate([row_lines, col_lines])
24492448

2449+
# We autoscale twice because autoscaling is much faster with vectorized numpy
2450+
# arrays, but row_lines and col_lines might not be the same shape, so we can't
2451+
# stack them to check them in a single pass.
2452+
# Note that while the column and row grid points are the same, the lines
2453+
# between them may expand the view limits, so we have to check both.
2454+
self.auto_scale_xyz(row_lines[..., 0], row_lines[..., 1], row_lines[..., 2],
2455+
had_data)
2456+
self.auto_scale_xyz(col_lines[..., 0], col_lines[..., 1], col_lines[..., 2],
2457+
had_data=True)
2458+
2459+
lines = list(row_lines) + list(col_lines)
24502460
linec = art3d.Line3DCollection(lines, axlim_clip=axlim_clip, **kwargs)
24512461
self.add_collection(linec)
2452-
self.auto_scale_xyz(X, Y, Z, had_data)
24532462

24542463
return linec
24552464

Loading
Loading

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,14 @@ def test_wireframe3d():
845845
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=13)
846846

847847

848+
@mpl3d_image_comparison(['wireframe3dasymmetric.png'], style='mpl20')
849+
def test_wireframe3dasymmetric():
850+
fig = plt.figure()
851+
ax = fig.add_subplot(projection='3d')
852+
X, Y, Z = axes3d.get_test_data(0.05)
853+
ax.plot_wireframe(X, Y, Z, rcount=3, ccount=13)
854+
855+
848856
@mpl3d_image_comparison(['wireframe3dzerocstride.png'], style='mpl20')
849857
def test_wireframe3dzerocstride():
850858
fig = plt.figure()

0 commit comments

Comments
 (0)