Skip to content

adding average_radius parameter into the rotation wrapper #534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions httomo/method_wrappers/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def _build_kwargs(
dict_params["ind"] == "mid" or dict_params["ind"] is None
):
updated_params = {**dict_params, "ind": (dataset.shape[1] - 1) // 2}
if "average_radius" not in dict_params:
updated_params.update({"average_radius": 0})
return super()._build_kwargs(updated_params, dataset)

def _gather_sino_slice(self, global_shape: Tuple[int, int, int]):
Expand Down Expand Up @@ -144,6 +146,10 @@ def _run_method(self, block: T, args: Dict[str, Any]) -> T:
else:
assert "ind" in args
slice_for_cor = args["ind"]
if "average_radius" not in args:
average_radius = 0
else:
average_radius = args["average_radius"]
# append to internal sinogram, until we have the last block
if self.sino is None:
self.sino = np.empty(
Expand All @@ -157,8 +163,25 @@ def _run_method(self, block: T, args: Dict[str, Any]) -> T:
if block.is_padded:
core_angles_start = block.padding[0]
core_angles_stop = core_angles_start + block.shape_unpadded[0]

data = block.data[core_angles_start:core_angles_stop, slice_for_cor, :]
if average_radius == 0:
data = block.data[core_angles_start:core_angles_stop, slice_for_cor, :]
else:
if 2 * average_radius <= block.data.shape[1]:
# averaging few sinograms to improve SNR and centering method accuracy
data = xp.mean(
block.data[
core_angles_start:core_angles_stop,
slice_for_cor
- average_radius : slice_for_cor
+ average_radius,
:,
],
1,
)
else:
raise ValueError(
f"The given average_radius = {average_radius} in the centering method is larger than the half size of the block = {block.data.shape[1]//2}. Please make it smaller or 0."
)

if block.is_gpu:
with catchtime() as t:
Expand Down
43 changes: 39 additions & 4 deletions tests/method_wrappers/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_rotation_accumulates_blocks(mocker: MockerFixture, padding: Tuple[int,
)

class FakeModule:
def rotation_tester(data, ind=None):
def rotation_tester(data, ind=None, average_radius=None):
assert data.ndim == 2 # for 1 slice only
np.testing.assert_array_equal(
data, global_data[:, (GLOBAL_SHAPE[1] - 1) // 2, :]
Expand Down Expand Up @@ -145,10 +145,11 @@ def test_rotation_gathers_single_sino_slice(
)

class FakeModule:
def rotation_tester(data, ind=None):
def rotation_tester(data, ind=None, average_radius=None):
assert rank == 0 # for rank 1, it shouldn't be called
assert data.ndim == 2 # for 1 slice only
assert ind == 0
assert average_radius == 0
if ind_par == "mid" or ind_par is None:
xp.testing.assert_array_equal(
global_data[:, (GLOBAL_SHAPE[1] - 1) // 2, :],
Expand Down Expand Up @@ -341,7 +342,7 @@ def test_rotation_normalize_sino_different_darks_flats_gpu():

def test_rotation_180(mocker: MockerFixture):
class FakeModule:
def rotation_tester(data, ind):
def rotation_tester(data, ind, average_radius):
return 42.0 # center of rotation

mocker.patch(
Expand All @@ -355,6 +356,7 @@ def rotation_tester(data, ind):
make_mock_preview_config(mocker),
output_mapping={"cor": "center"},
ind=5,
average_radius=0,
)

block = DataSetBlock(
Expand All @@ -367,6 +369,38 @@ def rotation_tester(data, ind):
assert new_block == block # note: not a deep comparison


def test_rotation_180_raise_average_radius(mocker: MockerFixture):
class FakeModule:
def rotation_tester(data, ind, average_radius):
return 42.0 # center of rotation

mocker.patch(
"httomo.method_wrappers.generic.import_module", return_value=FakeModule
)
wrp = make_method_wrapper(
make_mock_repo(mocker, pattern=Pattern.projection),
"mocked_module_path.rotation",
"rotation_tester",
MPI.COMM_WORLD,
make_mock_preview_config(mocker),
output_mapping={"cor": "center"},
ind=5,
average_radius=6,
)

block = DataSetBlock(
data=np.ones((10, 10, 10), dtype=np.float32),
aux_data=AuxiliaryData(angles=np.ones(10, dtype=np.float32)),
)

with pytest.raises(ValueError) as e:
_ = wrp.execute(block)
assert (
"The given average_radius = 6 in the centering method is larger than the half size of the block = 5."
in str(e)
)


def test_rotation_pc_180(mocker: MockerFixture):
class FakeModule:
def find_center_pc(proj1, proj2=None):
Expand Down Expand Up @@ -396,7 +430,7 @@ def find_center_pc(proj1, proj2=None):

def test_rotation_360(mocker: MockerFixture):
class FakeModule:
def rotation_tester(data, ind):
def rotation_tester(data, ind, average_radius):
# cor, overlap, side, overlap_position - from find_center_360
return 42.0, 3.0, 1, 10.0

Expand All @@ -415,6 +449,7 @@ def rotation_tester(data, ind):
"overlap_position": "pos",
},
ind=5,
average_radius=0,
)
block = DataSetBlock(
data=np.ones((10, 10, 10), dtype=np.float32),
Expand Down
Loading