Skip to content

Commit 1b5cf8a

Browse files
committed
fixup! nnunet post-processing times
1 parent 851e5ba commit 1b5cf8a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

data_processing/surface_fitting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def pointcloud_surface_fitting(points: ArrayLike, crop_to_bbox=False, mask: sitk
8181
return poisson_mesh
8282

8383

84-
def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image):
84+
def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image, return_times=False):
8585
print('Performing surface fitting via Poisson Reconstruction')
8686
# transforming labelmap to unit spacing
8787
# fissures = image_ops.resample_equal_spacing(fissures, target_spacing=1.)
@@ -92,6 +92,7 @@ def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image):
9292

9393
# fit plane to each separate fissure
9494
labels = fissures_tensor.unique()[1:]
95+
times = []
9596
for f in labels:
9697
print(f'Fitting fissure {f} ...')
9798
# extract the current fissure and construct independent image
@@ -107,6 +108,7 @@ def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image):
107108

108109
# extract point cloud from thinned fissures
109110
fissure_points = mask_to_points(label_tensor, spacing)
111+
times.append(time.time() - start)
110112
print(f'\tTook {time.time() - start:.4f} s')
111113

112114
# compute the mesh
@@ -119,6 +121,7 @@ def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image):
119121
remove_all_but_biggest_component(poisson_mesh, right=right,
120122
center_x=(fissures.GetSize()[0] * fissures.GetSpacing()[0]) / 2)
121123
fissure_meshes.append(poisson_mesh)
124+
times[-1] += time.time() - start
122125
print(f'\tTook {time.time() - start:.4f} s')
123126

124127
# convert mesh to labelmap by sampling points
@@ -128,7 +131,10 @@ def poisson_reconstruction(fissures: sitk.Image, mask: sitk.Image):
128131
regularized_fissures.CopyInformation(fissures)
129132

130133
print('DONE\n')
131-
return regularized_fissures, fissure_meshes
134+
if return_times:
135+
return regularized_fissures, fissure_meshes, times
136+
else:
137+
return regularized_fissures, fissure_meshes
132138

133139

134140
def o3d_mesh_to_labelmap(o3d_meshes: List[o3d.geometry.TriangleMesh], shape, spacing: Tuple[float], n_samples=10**7) -> torch.Tensor:

0 commit comments

Comments
 (0)