Skip to content

Improve paganin filter method's memory estimation #454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9a5b5eb
Remove unnecessary `enumerate()`
yousefmoazzam Sep 19, 2024
b223f55
Rename `element` var for improved clarity
yousefmoazzam Sep 19, 2024
9bac3dd
Move unpadded input size calculation to top of estimator
yousefmoazzam Sep 19, 2024
6c48494
Rename var to explicitly mention it accounts for padding
yousefmoazzam Sep 19, 2024
a358662
Move var accounting for cast to `complex64` to match ordering in meth…
yousefmoazzam Sep 19, 2024
e370bec
Estimate 2D FFT plan using cufft plan estimator
yousefmoazzam Sep 19, 2024
2d98c07
Remove unnecessary out slice size in estimation
yousefmoazzam Sep 20, 2024
88e9d14
Remove unnecessary complex slice doubling in estimation
yousefmoazzam Sep 20, 2024
d509cc4
Account for deallocated padded float32 array in memory estimation
yousefmoazzam Sep 20, 2024
fdf58a4
Include padding in reciprocal grid size calculation
yousefmoazzam Sep 20, 2024
135c0c2
Modify value + rename var accounting for cropped float32 IFFT result
yousefmoazzam Sep 20, 2024
f23a062
Add extra var for clarity on where another allocation occurs
yousefmoazzam Sep 20, 2024
f145f51
Rename var to explicitly mention the absence of padding
yousefmoazzam Sep 23, 2024
f1660e3
Rename var to explicitly mention the presence of padding
yousefmoazzam Sep 23, 2024
8e4aefc
Separate estimation of cast to `complex64` from FFT plan estimation
yousefmoazzam Sep 23, 2024
df3d9b4
Estimate 2D FFT plan using cufft plan estimator
yousefmoazzam Sep 23, 2024
a090d1e
Rename var accounting for cropped float32 IFFT result
yousefmoazzam Sep 23, 2024
6c7d076
Add code comment for filter size estimated value
yousefmoazzam Sep 23, 2024
b32a2c8
Account for negligible FFT plan size affecting peak memory usage
yousefmoazzam Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from typing import Tuple
import numpy as np

from httomo.cufft import CufftType, cufft_estimate_2d

__all__ = [
"_calc_memory_bytes_paganin_filter_savu",
"_calc_memory_bytes_paganin_filter_tomopy",
Expand All @@ -37,20 +39,55 @@ def _calc_memory_bytes_paganin_filter_savu(
) -> Tuple[int, int]:
pad_x = kwargs["pad_x"]
pad_y = kwargs["pad_y"]
input_size = np.prod(non_slice_dims_shape) * dtype.itemsize
in_slice_size = (
(non_slice_dims_shape[0] + 2 * pad_y)
* (non_slice_dims_shape[1] + 2 * pad_x)
* dtype.itemsize

# Input (unpadded)
unpadded_in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize

# Padded input
padded_non_slice_dims_shape = (
non_slice_dims_shape[0] + 2 * pad_y,
non_slice_dims_shape[1] + 2 * pad_x,
)
# FFT needs complex inputs, so copy to complex happens first
complex_slice = in_slice_size / dtype.itemsize * np.complex64().nbytes
fftplan_slice = complex_slice
filter_size = complex_slice
res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes
tot_memory_bytes = (
input_size + in_slice_size + complex_slice + fftplan_slice + res_slice
padded_in_slice_size = (
padded_non_slice_dims_shape[0] * padded_non_slice_dims_shape[1] * dtype.itemsize
)

# Padded input cast to `complex64`
complex_slice = padded_in_slice_size / dtype.itemsize * np.complex64().nbytes

# Plan size for 2D FFT
fftplan_slice_size = cufft_estimate_2d(
nx=padded_non_slice_dims_shape[1],
ny=padded_non_slice_dims_shape[0],
fft_type=CufftType.CUFFT_C2C,
)

# Shape of 2D filter is the same as the padded `complex64` slice shape, so the size will be
# the same
filter_size = complex_slice

# Size of cropped/unpadded + cast to float32 result of 2D IFFT
cropped_float32_res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes

# If the FFT plan size is negligible for some reason, this changes where the peak GPU
# memory usage occurs. Hence, the if/else branching below for calculating the total bytes.
NEGLIGIBLE_FFT_PLAN_SIZE = 16
if fftplan_slice_size < NEGLIGIBLE_FFT_PLAN_SIZE:
tot_memory_bytes = int(
unpadded_in_slice_size + padded_in_slice_size + complex_slice
)
else:
tot_memory_bytes = int(
unpadded_in_slice_size
+ padded_in_slice_size
+ complex_slice
# The padded float32 array is deallocated when a copy is made when casting to complex64
# and the variable `padded_tomo` is reassigned to the complex64 version
- padded_in_slice_size
+ fftplan_slice_size
+ cropped_float32_res_slice
)

return (tot_memory_bytes, filter_size)


Expand All @@ -61,11 +98,14 @@ def _calc_memory_bytes_paganin_filter_tomopy(
) -> Tuple[int, int]:
from httomolibgpu.prep.phase import _shift_bit_length

# Input (unpadded)
unpadded_in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize

# estimate padding size here based on non_slice dimensions
pad_tup = []
for index, element in enumerate(non_slice_dims_shape):
diff = _shift_bit_length(element + 1) - element
if element % 2 == 0:
for dim_len in non_slice_dims_shape:
diff = _shift_bit_length(dim_len + 1) - dim_len
if dim_len % 2 == 0:
pad_width = diff // 2
pad_width = (pad_width, pad_width)
else:
Expand All @@ -75,34 +115,55 @@ def _calc_memory_bytes_paganin_filter_tomopy(
pad_width = (left_pad, right_pad)
pad_tup.append(pad_width)

input_size = np.prod(non_slice_dims_shape) * dtype.itemsize

in_slice_size = (
# Padded input
padded_in_slice_size = (
(non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1])
* (non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1])
* dtype.itemsize
)
out_slice_size = (
(non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1])
* (non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1])
* dtype.itemsize

# Padded input cast to `complex64`
complex_slice = padded_in_slice_size / dtype.itemsize * np.complex64().nbytes

# Plan size for 2D FFT
ny = non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1]
nx = non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1]
fftplan_slice_size = cufft_estimate_2d(
nx=nx,
ny=ny,
fft_type=CufftType.CUFFT_C2C,
)

# FFT needs complex inputs, so copy to complex happens first
complex_slice = in_slice_size / dtype.itemsize * np.complex64().nbytes
fftplan_slice = complex_slice
grid_size = np.prod(non_slice_dims_shape) * np.float32().nbytes
# Size of "reciprocal grid" generated, based on padded projections shape
grid_size = np.prod((ny, nx)) * np.float32().nbytes
filter_size = grid_size
res_slice = grid_size

tot_memory_bytes = int(
input_size
+ in_slice_size
+ out_slice_size
+ 2 * complex_slice
+ 0.5 * fftplan_slice
+ 2 * res_slice
)

# Size of cropped/unpadded + cast to float32 result of 2D IFFT
cropped_float32_res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes

# Size of negative log of cropped float32 result of 2D IFFT
negative_log_slice = cropped_float32_res_slice

# If the FFT plan size is negligible for some reason, this changes where the peak GPU
# memory usage occurs. Hence, the if/else branching below for calculating the total bytes.
NEGLIGIBLE_FFT_PLAN_SIZE = 16
if fftplan_slice_size < NEGLIGIBLE_FFT_PLAN_SIZE:
tot_memory_bytes = int(
unpadded_in_slice_size + padded_in_slice_size + complex_slice
)
else:
tot_memory_bytes = int(
unpadded_in_slice_size
+ padded_in_slice_size
+ complex_slice
# The padded float32 array is deallocated when a copy is made when casting to complex64
# and the variable `padded_tomo` is reassigned to the complex64 version
- padded_in_slice_size
+ fftplan_slice_size
+ cropped_float32_res_slice
+ negative_log_slice
)

subtract_bytes = int(filter_size + grid_size)

return (tot_memory_bytes, subtract_bytes)
Loading