Skip to content

Commit a1125b2

Browse files
committed
Adding model caching for human segmentation
1 parent db14b95 commit a1125b2

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

py/image.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,10 +1244,14 @@ def parsing(self, image, confidence, method, crop_multi, prompt=None, my_unique_
12441244

12451245
from functools import reduce
12461246

1247-
model_path = get_local_filepath(MEDIAPIPE_MODELS['selfie_multiclass_256x256']['model_url'], MEDIAPIPE_DIR)
1248-
model_asset_buffer = None
1249-
with open(model_path, "rb") as f:
1250-
model_asset_buffer = f.read()
1247+
if method in cache:
1248+
_, model_asset_buffer = cache["selfie_multiclass_256x256"][1]
1249+
else:
1250+
model_path = get_local_filepath(MEDIAPIPE_MODELS['selfie_multiclass_256x256']['model_url'], MEDIAPIPE_DIR)
1251+
model_asset_buffer = None
1252+
with open(model_path, "rb") as f:
1253+
model_asset_buffer = f.read()
1254+
update_cache(method, 'human_segmentation', (False, model_asset_buffer))
12511255
image_segmenter_base_options = mp.tasks.BaseOptions(model_asset_buffer=model_asset_buffer)
12521256
options = mp.tasks.vision.ImageSegmenterOptions(
12531257
base_options=image_segmenter_base_options,
@@ -1313,10 +1317,15 @@ def parsing(self, image, confidence, method, crop_multi, prompt=None, my_unique_
13131317
mask = torch.cat(ret_masks, dim=0)
13141318

13151319
elif method == "human_parsing_lip":
1316-
from .human_parsing.run_parsing import HumanParsing
1317-
onnx_path = os.path.join(folder_paths.models_dir, 'onnx')
1318-
model_path = get_local_filepath(HUMANPARSING_MODELS['parsing_lip']['model_url'], onnx_path)
1319-
parsing = HumanParsing(model_path=model_path)
1320+
if method in cache:
1321+
_, parsing = cache[method][1]
1322+
else:
1323+
from .human_parsing.run_parsing import HumanParsing
1324+
onnx_path = os.path.join(folder_paths.models_dir, 'onnx')
1325+
model_path = get_local_filepath(HUMANPARSING_MODELS['parsing_lip']['model_url'], onnx_path)
1326+
parsing = HumanParsing(model_path=model_path)
1327+
update_cache(method, 'human_segmentation', (False, parsing))
1328+
13201329
model_image = image.squeeze(0)
13211330
model_image = model_image.permute((2, 0, 1))
13221331
model_image = to_pil_image(model_image)
@@ -1330,11 +1339,15 @@ def parsing(self, image, confidence, method, crop_multi, prompt=None, my_unique_
13301339
output_image, = JoinImageWithAlpha().join_image_with_alpha(image, alpha)
13311340

13321341
elif method == "human_parts (deeplabv3p)":
1333-
from .human_parsing.run_parsing import HumanParts
1334-
onnx_path = os.path.join(folder_paths.models_dir, 'onnx')
1335-
human_parts_path = os.path.join(onnx_path, 'human-parts')
1336-
model_path = get_local_filepath(HUMANPARSING_MODELS['human-parts']['model_url'], human_parts_path)
1337-
parsing = HumanParts(model_path=model_path)
1342+
if method in cache:
1343+
_, parsing = cache[method][1]
1344+
else:
1345+
from .human_parsing.run_parsing import HumanParts
1346+
onnx_path = os.path.join(folder_paths.models_dir, 'onnx')
1347+
human_parts_path = os.path.join(onnx_path, 'human-parts')
1348+
model_path = get_local_filepath(HUMANPARSING_MODELS['human-parts']['model_url'], human_parts_path)
1349+
parsing = HumanParts(model_path=model_path)
1350+
update_cache(method, 'human_segmentation', (False, parsing))
13381351

13391352
ret_images = []
13401353
ret_masks = []

0 commit comments

Comments
 (0)