@@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
         return {"tile": tile, "halo": halo}

     if torch.cuda.is_available():
-        # We always use the same default halo.
-        halo = {"x": 64, "y": 64, "z": 16}  # before 64,64,8
+        # The default halo size.
+        halo = {"x": 64, "y": 64, "z": 16}

         # Determine the GPU RAM and derive a suitable tiling.
         vram = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
             tile = {"x": 512, "y": 512, "z": 64}
         elif vram >= 20:
             tile = {"x": 352, "y": 352, "z": 48}
+        elif vram >= 10:
+            tile = {"x": 256, "y": 256, "z": 32}
+            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
         else:
-            # TODO determine tilings for smaller VRAM
-            raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.")
+            raise NotImplementedError(f"Inference on a GPU with {vram} GB VRAM is not supported.")

         print(f"Determined tile size: {tile}")
         tiling = {"tile": tile, "halo": halo}