Skip to content

Commit ec083cf

Browse files
author
Allard Hendriksen
committed
fdk: Fix non-centered source position
When the source position was not located exactly on the line through the detector center parallel to the detector normal, the pre-weighting could be slightly off. This has now been fixed. Also, some changes could be made to the pre-weighting that did not result in test failure. This possibility has been further minimized.
1 parent 49cb28d commit ec083cf

File tree

2 files changed

+158
-39
lines changed

2 files changed

+158
-39
lines changed

tests/test_fdk.py

Lines changed: 126 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import tomosipo as ts
44
from ts_algorithms import fdk
5+
from ts_algorithms.fdk import fdk_weigh_projections
56
import numpy as np
67

78

@@ -12,30 +13,100 @@ def make_box_phantom():
1213
return x
1314

1415

15-
def test_fdk_reconstruction():
16-
vg = ts.volume(shape=64, size=1)
17-
pg = ts.cone(angles=32, shape=(64, 64), size=(2, 2), src_det_dist=2, src_orig_dist=2)
16+
def astra_fdk(A, y):
17+
vg, pg = A.astra_compat_vg, A.astra_compat_pg
18+
19+
vd = ts.data(vg)
20+
pd = ts.data(pg, y.cpu().numpy())
21+
ts.fdk(vd, pd)
22+
# XXX: disgregard clean up of vd and pd (tests are short-lived and
23+
# small)
24+
return torch.from_numpy(vd.data.copy()).to(y.device)
25+
26+
27+
# Standard parameters
28+
vg64 = [
29+
ts.volume(shape=64), # voxel size == 1
30+
ts.volume(shape=64, size=1), # volume size == 1
31+
ts.volume(shape=64, size=1), # idem
32+
]
33+
pg64 = [
34+
ts.cone(angles=96, shape=96, src_det_dist=192), # pixel size == 1
35+
ts.cone(angles=96, shape=96, size=1.5, src_det_dist=3), # detector size == 1 / 64
36+
ts.cone(angles=96, shape=96, size=3, src_det_dist=3, src_orig_dist=3), # magnification 2
37+
]
38+
phantom64 = [
39+
make_box_phantom(),
40+
make_box_phantom(),
41+
make_box_phantom(),
42+
]
43+
44+
45+
@pytest.mark.parametrize("vg, pg, x", zip(vg64, pg64, phantom64))
46+
def test_astra_compatibility(vg, pg, x):
47+
A = ts.operator(vg, pg)
48+
y = A(x)
49+
rec_ts = fdk(A, y)
50+
rec_astra = astra_fdk(A, y)
51+
52+
print(abs(rec_ts - rec_astra).max())
53+
assert torch.allclose(rec_ts, rec_astra, atol=5e-4)
54+
55+
56+
def test_fdk_flipped_cone_geometry():
57+
vg = ts.volume(shape=64)
58+
angles = np.linspace(0, 2 * np.pi, 96)
59+
R = ts.rotate(pos=0, axis=(1, 0, 0), rad=angles)
60+
pg = ts.cone_vec(
61+
shape=(96, 96),
62+
src_pos=[[0, 130, 0]], # usually -130
63+
det_pos=[[0, 0, 0]],
64+
det_v=[[1, 0, 0]],
65+
det_u=[[0, 0, 1]],
66+
)
67+
A = ts.operator(vg, R * pg)
1868

69+
fdk(A, torch.ones(A.range_shape))
70+
71+
72+
@pytest.mark.parametrize("vg, pg, x", zip(vg64, pg64, phantom64))
73+
def test_fdk_inverse(vg, pg, x):
74+
"""Rough test if reconstruction is close to original volume.
75+
76+
The mean error must be less than 10%. The sharp edges of the box
77+
phantom make this a difficult test case.
78+
"""
1979
A = ts.operator(vg, pg)
20-
x = make_box_phantom()
2180
y = A(x)
2281

23-
# rough test if reconstruction is close to original volume
2482
rec = fdk(A, y)
25-
assert torch.mean(torch.abs(rec - x)) < 0.15
26-
rec_nonPadded = fdk(A, y, padded=False)
27-
assert torch.mean(torch.abs(rec_nonPadded - x)) < 0.15
83+
assert torch.mean(torch.abs(rec - x)) < 0.1
84+
2885

29-
# test whether cone and cone_vec geometries yield the same result
86+
@pytest.mark.parametrize("vg, pg, x", zip(vg64, pg64, phantom64))
87+
def test_fdk_cone_vec(vg, pg, x):
88+
""" Test that cone and cone_vec yield same result."""
89+
A = ts.operator(vg, pg)
3090
A_vec = ts.operator(vg, pg.to_vec())
91+
y = A(x)
92+
93+
rec = fdk(A, y)
3194
rec_vec = fdk(A_vec, y)
32-
assert torch.allclose(rec_vec, rec, atol=1e-3, rtol=1e-2)
33-
assert torch.mean(torch.abs(rec_vec - rec)) < 1e-6
95+
assert torch.allclose(rec, rec_vec, atol=5e-4)
96+
97+
98+
@pytest.mark.parametrize("vg, pg, x", zip(vg64, pg64, phantom64))
99+
def test_fdk_gpu(vg, pg, x):
100+
""" Test that cuda and cpu tensors yield same result."""
101+
A = ts.operator(vg, pg)
102+
y = A(x)
34103

35-
# test whether GPU and CPU calculations yield the same result
104+
rec_cpu = fdk(A, y)
36105
rec_cuda = fdk(A, y.cuda()).cpu()
37-
assert torch.allclose(rec_cuda, rec, atol=1e-3, rtol=1e-2)
38-
assert torch.mean(torch.abs(rec_cuda - rec)) < 1e-6
106+
107+
# The atol is necessary because the ASTRA backprojection appears
108+
# to differ slightly when given cpu and gpu arguments...
109+
assert torch.allclose(rec_cpu, rec_cuda, atol=5e-4)
39110

40111

41112
def test_fdk_off_center_cor():
@@ -150,6 +221,45 @@ def test_fdk_off_center_cor_subsets():
150221
assert torch.allclose(r[sub_slice], r_sub, atol=1e-1, rtol=1e-6)
151222

152223

224+
225+
@pytest.mark.parametrize("vg, pg, x", zip(vg64, pg64, phantom64))
226+
def test_fdk_split_detector(vg, pg, x):
227+
"""Split detector in four quarters
228+
229+
Test that pre-weighting each quarter individually is the same as
230+
pre-weighting the full detector at once.
231+
"""
232+
233+
pg = pg.to_vec()
234+
235+
# determine the half-length of the detector shape:
236+
n, m = np.array(pg.det_shape) // 2
237+
238+
# Generate slices to split the detector of a projection geometry
239+
# into four slices.
240+
pg_slices = [
241+
np.s_[:, :n, :m],
242+
np.s_[:, :n, m:],
243+
np.s_[:, n:, :m],
244+
np.s_[:, n:, m:],
245+
]
246+
# Change slices to be in 'sinogram' form with angles in the middle.
247+
sino_slices = [(slice_v, slice_angles, slice_u) for (slice_angles, slice_v, slice_u) in pg_slices]
248+
249+
A = ts.operator(vg, pg)
250+
y = A(x)
251+
252+
As = [ts.operator(vg, pg[pg_slice]) for pg_slice in pg_slices]
253+
254+
w = fdk_weigh_projections(A, y)
255+
sub_ws = [fdk_weigh_projections(A_sub, y[sino_slice].contiguous()) for A_sub, sino_slice in zip(As, sino_slices)]
256+
257+
for sub_w, sino_slice in zip(sub_ws, sino_slices):
258+
abs_diff = abs(w[sino_slice] - sub_w)
259+
print(sub_w.max(), abs_diff.max().item(), abs_diff.mean().item())
260+
assert torch.allclose(w[sino_slice], sub_w, rtol=1e-2)
261+
262+
153263
def test_fdk_rotating_volume():
154264
"""Test that fdk handles volume_vec geometries correctly
155265
@@ -341,11 +451,11 @@ def test_fdk_errors():
341451

342452
# 4. Rotation center behind source position
343453
vg = ts.volume(pos=(0, -64, 0), shape=64).to_vec()
344-
pg = ts.cone(shape=96, angles=1, src_det_dist=128)
454+
pg = ts.cone(shape=96, angles=1, src_det_dist=128).to_vec()
345455
angles = np.linspace(0, 2 * np.pi, 90)
346456
R = ts.rotate(pos=(0, -129, 0), axis=(1, 0, 0), rad=angles)
347457

348458
A = ts.operator(R * vg, pg)
349459

350-
with pytest.raises(ValueError):
460+
with pytest.warns(UserWarning):
351461
fdk(A, torch.ones(A.range_shape))

ts_algorithms/fdk.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,34 @@ def fdk_weigh_projections(A, y):
1515
vg, pg = A.domain, A.range
1616

1717
T = ts.from_perspective(
18-
pos=pg.det_pos,
18+
pos=pg.src_pos,
1919
w=pg.det_v/np.linalg.norm(pg.det_v[0]),
2020
v=pg.det_normal/np.linalg.norm(pg.det_normal[0]),
2121
u=pg.det_u/np.linalg.norm(pg.det_u[0])
2222
)
23-
23+
# We have:
24+
# - the source position is placed at the origin
25+
# - the z axis is parallel to the detector v axis
26+
# - the y axis is orthogonal to the detecor plane
27+
# - the x axis is parallel to the detector u axis
2428
vg_fixed, pg_fixed = T * vg.to_vec(), T * pg.to_vec()
2529

2630
###########################################################################
2731
# Determine source-detector distance #
2832
###########################################################################
29-
# Read of source-detector distance in the y coordinate of the
30-
# transformed source position.
31-
src_det_dists = pg_fixed.src_pos[:, 1]
32-
src_det_dist = src_det_dists.mean()
33-
34-
if np.ptp(src_det_dists) > ts.epsilon:
33+
# The source is located on the origin. So the source-detector
34+
# distance can be read of from the detector position. We warn if
35+
# the detector position is not constant (and use the mean detector
36+
# position regardless).
37+
det_positions = pg_fixed.det_pos
38+
det_pos = det_positions.mean(axis=0)
39+
40+
if np.ptp(det_positions, axis=0).max() > ts.epsilon:
3541
warnings.warn(
3642
f"The source to detector distance is not constant. "
37-
f"It has a variation of {np.ptp(src_det_dists): 0.2e}. "
43+
f"It has a variation of {np.ptp(det_positions, axis=0)}. "
3844
f"This may cause unexpected results in the reconstruction. "
39-
f"The mean source to detector distance ({src_det_dist: 0.2e}) "
45+
f"The mean source to detector distance ({det_pos}) "
4046
"has been used to compute the reconstruction. "
4147
)
4248

@@ -45,15 +51,20 @@ def fdk_weigh_projections(A, y):
4551
###########################################################################
4652
# Read of source-object distance in the y coordinate of the
4753
# transformed volume position.
48-
src_obj_dists = vg_fixed.pos[:, 1] - src_det_dists
54+
obj_positions = vg_fixed.pos[:, 1]
4955

5056
# Take the rotation center as the mean of the volume positions.
51-
src_rot_center_dist = src_obj_dists.mean()
52-
53-
if src_rot_center_dist < 0.0:
54-
raise ValueError(
55-
"Rotation center is behind source position. "
56-
"Consider adjusting your geometry to obtain a reconstruction. "
57+
rot_center_pos = obj_positions.mean()
58+
src_rot_center_dist = abs(rot_center_pos)
59+
60+
# Check that the center of rotation is "in front" of the source
61+
# beam. Warn otherwise. We want to avoid the situation:
62+
# rot_center src ----> det
63+
# ⊙ . ----> | or | <---- . ⊙
64+
if np.sign(rot_center_pos) != np.sign(det_pos[1]):
65+
warnings.warn(
66+
"Rotation center of volume is behind source position. "
67+
"Adjust your geometry to obtain a better reconstruction. "
5768
)
5869

5970
###########################################################################
@@ -64,21 +75,19 @@ def fdk_weigh_projections(A, y):
6475

6576
v_range = torch.arange(num_v, dtype=torch.float64) - (num_v - 1) / 2
6677
u_range = torch.arange(num_u, dtype=torch.float64) - (num_u - 1) / 2
67-
u_pos_squared = (u_size * u_range) ** 2
68-
v_pos_squared = (v_size * v_range) ** 2
78+
u_pos_squared = (det_pos[2] + u_size * u_range) ** 2
79+
v_pos_squared = (det_pos[0] + v_size * v_range) ** 2
6980

7081
# Determine source-pixel distance for each pixel on the detector.
7182
src_pixel_dist = torch.sqrt(
72-
u_pos_squared[None, :] + v_pos_squared[:, None] + src_det_dist**2
83+
u_pos_squared[None, :] + v_pos_squared[:, None] + det_pos[1]**2
7384
)
7485

7586
###########################################################################
7687
# Determine weighting #
7788
###########################################################################
78-
weights_mat = src_det_dist / src_pixel_dist
79-
8089
# Multiply with extra scaling factor to account for detector distance
81-
weights_mat *= (src_rot_center_dist / src_det_dist)
90+
weights_mat = src_rot_center_dist / src_pixel_dist
8291
weights_mat = weights_mat.float().to(y.device)
8392

8493
return y * weights_mat[:, None, :]

0 commit comments

Comments
 (0)