|
1 | 1 | from .base import Block
|
2 | 2 | import matplotlib.pyplot as plt
|
| 3 | +import numpy as np |
3 | 4 |
|
4 | 5 |
|
5 | 6 | class Pcolormesh(Block):
|
@@ -48,16 +49,42 @@ def __len__(self):
|
48 | 49 |
|
49 | 50 |
|
50 | 51 | class Imshow(Block):
|
51 |
| - def __init__(self, X, axis, **kwargs): |
52 |
| - self.X = X |
53 |
| - super().__init__(axis) |
| 52 | + """Images a series of images |
| 53 | +
|
| 54 | + Parameters |
| 55 | + ---------- |
| 56 | + images : list of 2D/3D arrays, or a 3D or 4D array |
| 57 | + matplotlib considers arrays of the shape |
| 58 | + (n,m), (n,m,3), and (n,m,4) to be images. |
| 59 | + Images is either a list of arrays of those shapes, |
| 60 | + or an array of shape (T,n,m), (T,n,m,3), or (T,n,m,4) |
| 61 | + where T is the length of the time axis (assuming ``t_axis=0``). |
| 62 | + axis : matplotlib axis, optional |
| 63 | + The axis to attach the block to |
| 64 | + t_axis : int, optional |
| 65 | + The axis of the array that represents time. Defaults to 0. |
| 66 | + No effect if images is a list. |
| 67 | + Notes |
| 68 | + ----- |
| 69 | + This block accepts additional keyword arguments to be passed to |
| 70 | + :meth:`matplotlib.axes.Axes.imshow` |
| 71 | + """ |
| 72 | + def __init__(self, images, axis=None, t_axis=0, **kwargs): |
| 73 | + self.ims = np.asanyarray(images) |
| 74 | + super().__init__(axis, t_axis) |
| 75 | + |
| 76 | + self._is_list = isinstance(images, list) |
| 77 | + self._dim = len(self.ims.shape) |
54 | 78 |
|
55 |
| - im_slice = [slice(None)]*(len(X.shape)-1) + [slice(1)] |
56 |
| - self.im = self.ax.imshow(X[im_slice], **kwargs) |
| 79 | + Slice = self._make_slice(0, self._dim) |
| 80 | + self.im = self.ax.imshow(self.ims[Slice], **kwargs) |
57 | 81 |
|
58 | 82 | def _update(self, i):
|
59 |
| - im_slice = [slice(None)]*(len(self.X.shape)-1) + [slice(i)] |
60 |
| - self.im.set_array(self.X[im_slice]) |
| 83 | + Slice = self._make_slice(i, self._dim) |
| 84 | + self.im.set_array(self.ims[Slice]) |
| 85 | + return self.im |
61 | 86 |
|
62 | 87 | def __len__(self):
|
63 |
| - return self.X.shape[-1] |
| 88 | + if self._is_list: |
| 89 | + return self.ims.shape[0] |
| 90 | + return self.ims.shape[self.t_axis] |
0 commit comments