Skip to content

Commit 559a314

Browse files
committed
ENH: Handle implicit antsImage -> numpy conversion
Allow images to be put into an array of objects, but not array of scalar types This prevents obscure errors when an operation is meant to operate on img.numpy() but is passed img instead.
1 parent 9e3f8ac commit 559a314

File tree

2 files changed

+60
-43
lines changed

2 files changed

+60
-43
lines changed

ants/core/ants_image.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838

3939
class ANTsImage(object):
4040

41+
# This makes sure that numpy operations (e.g. np.sum(img)) call our __array__ method so that users get a helpful
42+
# error message instead of the numpy fallback to an object array
43+
__array_priority__ = 10000
44+
45+
4146
def __init__(self, pointer):
4247
"""
4348
Initialize an ANTsImage.
@@ -571,8 +576,20 @@ def __setitem__(self, idx, value):
571576
def __iter__(self):
572577
# Do not allow iteration on ANTsImage. Builtin iteration, eg sum(), will generally be much slower
573578
# than using numpy methods. We need to explicitly disallow it to prevent breaking object state.
574-
raise TypeError("ANTsImage is not iterable. See docs for available functions, or use numpy.")
575-
579+
raise TypeError("ANTsImage is not iterable. See docs for available functions, or convert to numpy.")
580+
581+
def __array__(self, dtype=None):
582+
if dtype is not None and dtype is np.dtype('O'):
583+
# Allow conversion to an object array (eg np.array([img1, img2, img3], dtype=object))
584+
out = np.empty((), dtype=object)
585+
out[()] = self
586+
return out
587+
else:
588+
# Disallow implicit conversion to numeric numpy array. This prevents using ANTsImage objects in numpy functions
589+
# as that's complicated to handle correctly. Users should explicitly convert to numpy array using the .numpy()
590+
# function
591+
raise TypeError("ANTsImage cannot be implicitly converted to a numeric array. Use the .numpy() method to obtain a "
592+
"copy of the image data as a numpy array. If you want a numpy array of ANTsImage objects, use np.array(..., dtype=object).")
576593

577594
def __repr__(self):
578595
if self.dimension == 3:

tests/test_plotting.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
# Prevent displaying figures
1515
import matplotlib as mpl
16-
backend_ = mpl.get_backend()
17-
mpl.use("Agg")
16+
backend_ = mpl.get_backend()
17+
mpl.use("Agg")
1818

1919
import numpy as np
2020
import numpy.testing as nptest
@@ -39,7 +39,7 @@ def test_plot_example(self):
3939
ants.plot(img, overlay=img*2)
4040
ants.plot(img, overlay=img*2)
4141
ants.plot(img, filename=filename)
42-
42+
4343
def test_extra_plot(self):
4444
img = ants.image_read(ants.get_ants_data('r16'))
4545
ants.plot(img, overlay=img*2, domain_image_map=ants.image_read(ants.get_data('r64')))
@@ -48,18 +48,18 @@ def test_extra_plot(self):
4848
ants.plot(img, crop=True)
4949

5050
img = ants.image_read(ants.get_ants_data('mni'))
51-
ants.plot(img, overlay=img*2,
51+
ants.plot(img, overlay=img*2,
5252
domain_image_map=ants.image_read(ants.get_data('mni')).resample_image((4,4,4)))
5353

5454
img = ants.image_read(ants.get_ants_data('mni'))
5555
ants.plot(img, overlay=img*2, reorient=True, crop=True)
56-
56+
5757
def test_random(self):
5858
img = ants.image_read(ants.get_ants_data('r16'))
5959
img3 = ants.image_read(ants.get_data('r64'))
6060
img2 = ants.image_read(ants.get_ants_data('mni'))
6161
imgv = ants.merge_channels([img2])
62-
62+
6363
ants.plot(img2, axis='x', scale=True, ncol=1)
6464
ants.plot(img2, axis='y', scale=(0.05, 0.95))
6565
ants.plot(img2, axis='z', slices=[10,20,30], title='Test', cbar=True,
@@ -76,18 +76,18 @@ def test_random(self):
7676
ants.plot(ants.get_ants_data('r16'), overlay=ants.merge_channels([img,img]))
7777
ants.plot(ants.from_numpy(np.zeros((100,100))))
7878
ants.plot(img.clone('unsigned int'))
79-
79+
8080
ants.plot(img, domain_image_map=img3)
81-
81+
8282
with self.assertRaises(Exception):
8383
ants.plot(123)
84-
85-
86-
87-
88-
89-
90-
84+
85+
86+
87+
88+
89+
90+
9191
class TestModule_plot_ortho(unittest.TestCase):
9292

9393
def setUp(self):
@@ -102,15 +102,15 @@ def test_plot_example(self):
102102
for img in self.imgs:
103103
ants.plot_ortho(img)
104104
ants.plot_ortho(img, filename=filename)
105-
105+
106106
def test_plot_extra(self):
107107
img = ants.image_read(ants.get_ants_data('mni'))
108-
ants.plot_ortho(img, overlay=img*2,
108+
ants.plot_ortho(img, overlay=img*2,
109109
domain_image_map=ants.image_read(ants.get_data('mni')))
110110

111111
img = ants.image_read(ants.get_ants_data('mni'))
112112
ants.plot_ortho(img, overlay=img*2, reorient=True, crop=True)
113-
113+
114114
def test_random_params(self):
115115
img = ants.image_read(ants.get_ants_data('mni')).resample_image((4,4,4))
116116
img2 = ants.image_read(ants.get_data('r16'))
@@ -121,25 +121,25 @@ def test_random_params(self):
121121
ants.plot_ortho(img2)
122122
with self.assertRaises(Exception):
123123
ants.plot_ortho(img, overlay=img2)
124-
124+
125125
imgx = img.clone()
126126
imgx.set_spacing((3,3,3))
127127
ants.plot_ortho(img,overlay=imgx)
128128
ants.plot_ortho(img.clone('unsigned int'),overlay=img, blend=True)
129-
129+
130130
imgx = img.clone()
131131
imgx.set_spacing((10,1,1))
132132
ants.plot_ortho(imgx)
133-
133+
134134
ants.plot_ortho(img, flat=True, title='Test', text='This is a test')
135135
ants.plot_ortho(img, title='Test', text='This is a test', cbar=True)
136-
136+
137137
with self.assertRaises(Exception):
138138
ants.plot_ortho(img, domain_image_map=123)
139-
139+
140140
with self.assertRaises(Exception):
141141
ants.plot_orto(ants.merge_channels([img,img]))
142-
142+
143143

144144
class TestModule_plot_ortho_stack(unittest.TestCase):
145145

@@ -153,7 +153,7 @@ def test_plot_example(self):
153153
filename = mktemp(suffix='.png')
154154
ants.plot_ortho_stack([self.img, self.img])
155155
ants.plot_ortho_stack([self.img, self.img], filename=filename)
156-
156+
157157
def test_extra_ortho_stack(self):
158158
img = ants.image_read(ants.get_ants_data('mni'))
159159
ants.plot_ortho_stack([img, img], overlays=[img*2, img*2],
@@ -185,9 +185,9 @@ def setUp(self):
185185
mni3 = mni1.smooth_image(2.)
186186
mni4 = mni1.smooth_image(3.)
187187
self.images3d = np.asarray([[mni1, mni2],
188-
[mni3, mni4]])
188+
[mni3, mni4]], dtype='object')
189189
self.images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
190-
[mni3.slice_image(2,100), mni4.slice_image(2,100)]])
190+
[mni3.slice_image(2,100), mni4.slice_image(2,100)]],dtype='object')
191191

192192
def tearDown(self):
193193
pass
@@ -198,19 +198,19 @@ def test_plot_example(self):
198198
ants.plot_grid(self.images3d)
199199
# should work with 2d images
200200
ants.plot_grid(self.images2d)
201-
201+
202202
def test_examples(self):
203203
mni1 = ants.image_read(ants.get_data('mni'))
204204
mni2 = mni1.smooth_image(1.)
205205
mni3 = mni1.smooth_image(2.)
206206
mni4 = mni1.smooth_image(3.)
207207
images = np.asarray([[mni1, mni2],
208-
[mni3, mni4]])
208+
[mni3, mni4]], dtype='object')
209209
slices = np.asarray([[100, 100],
210-
[100, 100]])
210+
[100, 100]], dtype='object')
211211
ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
212212
images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
213-
[mni3.slice_image(2,100), mni4.slice_image(2,100)]])
213+
[mni3.slice_image(2,100), mni4.slice_image(2,100)]], dtype='object')
214214
ants.plot_grid(images=images2d, title='2x2 Grid Pre-Sliced')
215215
ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
216216
ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')
@@ -228,11 +228,11 @@ def test_examples(self):
228228

229229
# Making a publication-quality image
230230
images = np.asarray([[mni1, mni2, mni2],
231-
[mni3, mni4, mni4]])
231+
[mni3, mni4, mni4]], dtype='object')
232232
slices = np.asarray([[100, 100, 100],
233-
[100, 100, 100]])
233+
[100, 100, 100]], dtype='object')
234234
axes = np.asarray([[0, 1, 2],
235-
[0, 1, 2]])
235+
[0, 1, 2]], dtype='object')
236236
ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy',
237237
tfontsize=20, title_dy=0.03, title_dx=-0.04,
238238
rlabels=['Row 1', 'Row 2'],
@@ -246,15 +246,15 @@ def setUp(self):
246246
pass
247247
def tearDown(self):
248248
pass
249-
249+
250250
def test_random_ortho_stack_params(self):
251251
img = ants.image_read(ants.get_data('mni')).resample_image((4,4,4))
252252
img2 = ants.image_read(ants.get_data('r16')).resample_image((4,4))
253-
253+
254254
ants.plot_ortho_stack([ants.get_data('mni'), ants.get_data('mni')])
255255
ants.plot_ortho_stack([ants.get_data('mni'), ants.get_data('mni')],
256256
overlays=[ants.get_data('mni'), ants.get_data('mni')])
257-
257+
258258
with self.assertRaises(Exception):
259259
ants.plot_ortho_stack([1,2,3])
260260
with self.assertRaises(Exception):
@@ -263,18 +263,18 @@ def test_random_ortho_stack_params(self):
263263
ants.plot_ortho_stack([img,img], overlays=[img2,img2])
264264
with self.assertRaises(Exception):
265265
ants.plot_ortho_stack([img,img], overlays=[1,2])
266-
266+
267267
imgx = img.clone()
268268
imgx.set_spacing((2,2,2))
269269
ants.plot_ortho_stack([img,imgx])
270-
270+
271271
imgx.set_spacing((2,1,1))
272272
ants.plot_ortho_stack([imgx,img])
273-
273+
274274
ants.plot_ortho_stack([img,img], scale=True, transpose=True,
275275
title='Test', colpad=1, rowpad=1,
276276
xyz_lines=True)
277277
ants.plot_ortho_stack([img,img], scale=(0.05,0.95))
278-
278+
279279
if __name__ == '__main__':
280280
run_tests()

0 commit comments

Comments
 (0)