Skip to content

Commit 0370e23

Browse files
committed
polish off Imshow block
1 parent a9498d4 commit 0370e23

File tree

3 files changed

+942
-8
lines changed

3 files changed

+942
-8
lines changed

animatplot/blocks/image_like.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import Block
22
import matplotlib.pyplot as plt
3+
import numpy as np
34

45

56
class Pcolormesh(Block):
@@ -48,16 +49,42 @@ def __len__(self):
4849

4950

5051
class Imshow(Block):
51-
def __init__(self, X, axis, **kwargs):
52-
self.X = X
53-
super().__init__(axis)
52+
"""Images a series of images
53+
54+
Parameters
55+
----------
56+
images : list of 2D/3D arrays, or a 3D or 4D array
57+
matplotlib considers arrays of the shape
58+
(n,m), (n,m,3), and (n,m,4) to be images.
59+
Images is either a list of arrays of those shapes,
60+
or an array of shape (T,n,m), (T,n,m,3), or (T,n,m,4)
61+
where T is the length of the time axis (assuming ``t_axis=0``).
62+
axis : matplotlib axis, optional
63+
The axis to attach the block to
64+
t_axis : int, optional
65+
The axis of the array that represents time. Defaults to 0.
66+
No effect if images is a list.
67+
Notes
68+
-----
69+
This block accepts additional keyword arguments to be passed to
70+
:meth:`matplotlib.axes.Axes.imshow`
71+
"""
72+
def __init__(self, images, axis=None, t_axis=0, **kwargs):
73+
self.ims = np.asanyarray(images)
74+
super().__init__(axis, t_axis)
75+
76+
self._is_list = isinstance(images, list)
77+
self._dim = len(self.ims.shape)
5478

55-
im_slice = [slice(None)]*(len(X.shape)-1) + [slice(1)]
56-
self.im = self.ax.imshow(X[im_slice], **kwargs)
79+
Slice = self._make_slice(0, self._dim)
80+
self.im = self.ax.imshow(self.ims[Slice], **kwargs)
5781

5882
def _update(self, i):
59-
im_slice = [slice(None)]*(len(self.X.shape)-1) + [slice(i)]
60-
self.im.set_array(self.X[im_slice])
83+
Slice = self._make_slice(i, self._dim)
84+
self.im.set_array(self.ims[Slice])
85+
return self.im
6186

6287
def __len__(self):
63-
return self.X.shape[-1]
88+
if self._is_list:
89+
return self.ims.shape[0]
90+
return self.ims.shape[self.t_axis]

0 commit comments

Comments
 (0)