Skip to content

Commit d0700db

Browse files
committed
ENH: More nifti tests and fixes
1 parent f98cbce commit d0700db

File tree

2 files changed

+63
-16
lines changed

2 files changed

+63
-16
lines changed

ants/utils/nifti_utils.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,9 @@ def get_nifti_sform_spatial_info(metadata, shear_threshold=1e-6, max_angle_devia
266266

267267
def get_nifti_qform_spatial_info(metadata):
268268
"""
269-
Extract qform-derived spacing, origin, direction. Uses 'qto_xyz' from the metadata dict.
269+
Extract qform-derived spacing, origin, direction. Uses the 4x4 'qto_xyz' from the metadata dict. This is the rotation
270+
matrix derived from quaternion parameters in the NIfTI header, multiplied by the pixdim scales and with the qoffset
271+
translation.
270272
271273
Note: output is in ITK LPS coordinates
272274
@@ -358,9 +360,9 @@ def get_nifti_spatial_transform_from_metadata(metadata, prefer_sform=True, shear
358360
if verbose:
359361
print(f"[sform] spacing={info_s['spacing']} origin={info_s['origin']} (desheared={info_s['desheared']})")
360362
# Verify spacing vs pixdim
361-
if not _spacing_matches(info_s["spacing"], pixdims):
363+
if not _spacing_matches(info_s['transform_spacing'], pixdims):
362364
warnings.warn(
363-
f"sform-derived spacing {info_s['spacing']} does not match NIfTI pixdim {pixdims}; "
365+
f"sform-derived spacing {info_s['transform_spacing']} does not match NIfTI pixdim {pixdims}; "
364366
"ignoring sform and trying qform.",
365367
RuntimeWarning,
366368
)
@@ -376,22 +378,23 @@ def get_nifti_spatial_transform_from_metadata(metadata, prefer_sform=True, shear
376378
warnings.warn("No usable sform or qform present in metadata; image not modified", RuntimeWarning)
377379
return
378380
info_q = get_nifti_qform_spatial_info(metadata)
379-
if not _spacing_matches(info_q["spacing"], pixdims):
380-
raise ValueError(f"qform-derived spacing {info_q['spacing']} does not match NIfTI pixdim {pixdims}")
381+
if not _spacing_matches(info_q['transform_spacing'], pixdims):
382+
raise ValueError(f"qform-derived spacing {info_q['transform_spacing']} does not match NIfTI pixdim {pixdims}")
381383
if verbose:
382-
print(f"[qform] spacing={info_q['spacing']} origin={info_q['origin']}")
384+
print(f"[qform] spacing={info_q['transform_spacing']} origin={info_q['origin']}")
385+
print(f"Setting spacing to pixdims={pixdims}")
383386
return dict(
384-
origin=info_q["origin"],
387+
origin=info_q['origin'],
385388
spacing=pixdims,
386-
direction=info_q["direction"],
389+
direction=info_q['direction'],
387390
transform_source="qform",
388391
)
389392
else:
390393
# Use sform
391394
return dict(
392-
origin=info_s["origin"],
395+
origin=info_s['origin'],
393396
spacing=pixdims,
394-
direction=info_s["direction"],
397+
direction=info_s['direction'],
395398
transform_source="sform",
396399
)
397400

@@ -439,11 +442,13 @@ def set_nifti_spatial_transform_from_metadata(image, metadata, prefer_sform=True
439442

440443
if image.dimension == 4:
441444
# Use original definition of time spacing and origin
442-
image.set_origin(info["origin"].append(image.origin[-1]))
445+
origin_4d = info['origin'].copy()
446+
origin_4d.append(image.origin[-1])
447+
image.set_origin(origin_4d)
443448
# direction is 4x4 with identity time
444449
dir4 = np.eye(4)
445-
dir4[:3, :3] = np.array(info["direction"])
450+
dir4[:3, :3] = np.array(info['direction'])
446451
image.set_direction(dir4.tolist())
447452
else:
448-
image.set_origin(info["origin"])
449-
image.set_direction(info["direction"])
453+
image.set_origin(info['origin'])
454+
image.set_direction(info['direction'])

tests/test_utils.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tempfile import TemporaryDirectory
88

99
from common import run_tests
10+
from copy import deepcopy
1011

1112
import math
1213
import numpy.testing as nptest
@@ -1351,19 +1352,21 @@ def get_nifti_spatial_transform_from_metadata(self):
13511352
self.assertTrue(xform['transform_source'] == 'sform')
13521353

13531354
# Flip sform just to be different
1354-
mni_metadata_flip = self.mni_metadata.deepcopy()
1355+
mni_metadata_flip = deepcopy(self.mni_metadata)
13551356
srow_x = self.mni_metadata['srow_x'].split()
13561357
srow_x = [str(-float(v)) for v in srow_x]
13571358
mni_metadata_flip['srow_x'] = ' '.join(srow_x)
13581359

13591360
xform = ants.get_nifti_spatial_transform_from_metadata(mni_metadata_flip)
13601361
self.assertTrue(np.allclose(xform['direction'][0], -self.mni.direction[0], atol=1e-5))
13611362
self.assertTrue(np.allclose(xform['origin'][0], -self.mni.origin[0], atol=1e-5))
1363+
1364+
# Prefer qform should give original
13621365
xform_q = ants.get_nifti_spatial_transform_from_metadata(mni_metadata_flip, prefer_sform=False)
13631366
self.assertTrue(np.allclose(xform_q['direction'], self.mni.direction, atol=1e-5))
13641367

13651368
# Check that shear gets removed
1366-
mni_metadata_shear = self.mni_metadata.deepcopy()
1369+
mni_metadata_shear = deepcopy(self.mni_metadata)
13671370
srow_x = self.mni_metadata['srow_x'].split()
13681371
srow_x[1] = str('-0.0001')
13691372
mni_metadata_shear['srow_x'] = ' '.join(srow_x)
@@ -1383,5 +1386,44 @@ def get_nifti_spatial_transform_from_metadata(self):
13831386
self.assertTrue(xform['transform_source'] == 'qform')
13841387

13851388

1389+
def test_set_nifti_spatial_transform_from_metadata(self):
1390+
# Flip sform
1391+
mni_metadata_flip = deepcopy(self.mni_metadata)
1392+
srow_x = self.mni_metadata['srow_x'].split()
1393+
srow_x = [str(-float(v)) for v in srow_x]
1394+
mni_metadata_flip['srow_x'] = ' '.join(srow_x)
1395+
1396+
mni_copy = ants.image_clone(self.mni)
1397+
ants.set_nifti_spatial_transform_from_metadata(mni_copy, mni_metadata_flip)
1398+
self.assertTrue(np.allclose(mni_copy.direction[0], -self.mni.direction[0], atol=1e-5))
1399+
self.assertTrue(np.allclose(mni_copy.direction[1], self.mni.direction[1], atol=1e-5))
1400+
self.assertTrue(np.allclose(mni_copy.direction[2], self.mni.direction[2], atol=1e-5))
1401+
self.assertTrue(np.allclose(mni_copy.origin[0], -self.mni.origin[0], atol=1e-5))
1402+
self.assertTrue(np.allclose(mni_copy.origin[1], self.mni.origin[1], atol=1e-5))
1403+
self.assertTrue(np.allclose(mni_copy.origin[2], self.mni.origin[2], atol=1e-5))
1404+
1405+
# Make a time series image
1406+
mni_metadata_ts = deepcopy(mni_metadata_flip)
1407+
mni_metadata_ts['dim[0]'] = '4'
1408+
mni_metadata_ts['dim[4]'] = '3'
1409+
mni_metadata_ts['pixdim[4]'] = '2.0'
1410+
array_list = [img.numpy() for img in [self.mni, self.mni, self.mni]]
1411+
stacked_array = np.stack(array_list, axis=-1)
1412+
mni_ts = ants.from_numpy(
1413+
stacked_array,
1414+
has_components=False,
1415+
spacing=self.mni.spacing + (2.0,),
1416+
origin=self.mni.origin + (0.0,),
1417+
direction=np.eye(4)
1418+
)
1419+
ants.set_nifti_spatial_transform_from_metadata(mni_ts, mni_metadata_ts)
1420+
1421+
mni_ts_expected_direction = np.eye(4)
1422+
mni_ts_expected_direction[:3,:3] = mni_copy.direction
1423+
1424+
self.assertTrue(np.allclose(mni_ts.direction, mni_ts_expected_direction, atol=1e-5))
1425+
self.assertTrue(np.allclose(mni_ts.origin, mni_copy.origin + (0.0,), atol=1e-5))
1426+
self.assertTrue(np.allclose(mni_ts.spacing, mni_copy.spacing + (2.0,), atol=1e-5))
1427+
13861428
if __name__ == "__main__":
13871429
run_tests()

0 commit comments

Comments
 (0)