Skip to content

Commit 03d42a8

Browse files
Fix issues with inference code
1 parent 9ceb256 commit 03d42a8

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

synapse_net/inference/util.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
414414
return {"tile": tile, "halo": halo}
415415

416416
if torch.cuda.is_available():
417-
# We always use the same default halo.
418-
halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8
417+
# The default halo size.
418+
halo = {"x": 64, "y": 64, "z": 16}
419419

420420
# Determine the GPU RAM and derive a suitable tiling.
421421
vram = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
426426
tile = {"x": 512, "y": 512, "z": 64}
427427
elif vram >= 20:
428428
tile = {"x": 352, "y": 352, "z": 48}
429+
elif vram >= 10:
430+
tile = {"x": 256, "y": 256, "z": 32}
431+
halo = {"x": 64, "y": 64, "z": 8} # Choose a smaller halo in z.
429432
else:
430-
# TODO determine tilings for smaller VRAM
431-
raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.")
433+
raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")
432434

433435
print(f"Determined tile size: {tile}")
434436
tiling = {"tile": tile, "halo": halo}

synapse_net/tools/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ def segmentation_cli():
124124
)
125125
parser.add_argument(
126126
"--tile_shape", type=int, nargs=3,
127-
help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
127+
help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
128128
)
129129
parser.add_argument(
130130
"--halo", type=int, nargs=3,
131-
help="The halo for prediction. Increase the halo to minimize boundary artifacts."
131+
help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
132132
)
133133
parser.add_argument(
134134
"--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."

0 commit comments

Comments
 (0)