Skip to content

Add Unet KITS19 in PyTorch #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 38 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
1407fa2
first commit
MarcelWilnicki Mar 8, 2024
7683f0c
wip
MarcelWilnicki Mar 8, 2024
0dda31c
wip
MarcelWilnicki Mar 8, 2024
1605b95
wip
MarcelWilnicki Mar 8, 2024
8920897
wip
MarcelWilnicki Mar 8, 2024
9890071
wip
MarcelWilnicki Mar 12, 2024
fbc378c
wip
MarcelWilnicki Mar 12, 2024
085efb5
wip
MarcelWilnicki Mar 12, 2024
ea6d050
wip
MarcelWilnicki Mar 18, 2024
ae11b27
wip
MarcelWilnicki Mar 18, 2024
188280b
wip
MarcelWilnicki Mar 18, 2024
2d57f20
wip
MarcelWilnicki Mar 18, 2024
33cded2
wip
MarcelWilnicki Mar 19, 2024
ed80ebd
wip
MarcelWilnicki Mar 19, 2024
f612b95
wip
MarcelWilnicki Mar 19, 2024
2d51497
wip
MarcelWilnicki Mar 19, 2024
35cd3b1
wip
MarcelWilnicki Mar 22, 2024
613503f
Merge branch 'main' into marcel/unet-pytorch
MarcelWilnicki Mar 27, 2024
1ea6652
Merge branch 'main' into marcel/unet-pytorch
MarcelWilnicki Mar 27, 2024
bc075b3
wip
MarcelWilnicki Mar 27, 2024
c9aacd8
add tests
MarcelWilnicki Apr 4, 2024
e182f36
wip
MarcelWilnicki Apr 4, 2024
9461b9d
wip
MarcelWilnicki Apr 4, 2024
2ab67e0
wip
MarcelWilnicki Apr 4, 2024
b611466
wip
MarcelWilnicki Apr 4, 2024
f979209
wip
MarcelWilnicki Apr 4, 2024
ffadbf4
wip
MarcelWilnicki Apr 4, 2024
9973ec1
wip
MarcelWilnicki Apr 4, 2024
f74996b
wip
MarcelWilnicki Apr 4, 2024
47e5bce
wip
MarcelWilnicki Apr 4, 2024
946b2d9
Merge branch 'main' into marcel/unet-pytorch
MarcelWilnicki May 16, 2024
164ed88
wip
MarcelWilnicki May 16, 2024
fb7aaaf
wip
MarcelWilnicki May 16, 2024
f390e47
wip
MarcelWilnicki May 16, 2024
885d80b
wip
MarcelWilnicki May 16, 2024
5a30eec
wip
MarcelWilnicki May 16, 2024
a2d715f
wip
MarcelWilnicki May 17, 2024
ac799fc
wip
MarcelWilnicki May 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ jobs:
S3_URL_IMAGENET_DATASET_LABELS: ${{ secrets.S3_URL_IMAGENET_DATASET_LABELS }}
S3_URL_COCO_DATASET: ${{ secrets.S3_URL_COCO_DATASET }}
S3_URL_COCO_DATASET_ANNOTATIONS: ${{ secrets.S3_URL_COCO_DATASET_ANNOTATIONS }}
S3_URL_KITS19_REDUCED_DATASET: ${{ secrets.S3_URL_KITS19_REDUCED_DATASET }}
S3_URL_UNET_KITS_PYTORCH_FP32: ${{ secrets.S3_URL_UNET_KITS_PYTORCH_FP32 }}
S3_URL_COVOST2_DATASET: ${{ secrets.S3_URL_COVOST2_DATASET }}
HF_HUB_TOKEN: ${{ secrets.HF_HUB_TOKEN }}
steps:
Expand Down
75 changes: 36 additions & 39 deletions computer_vision/semantic_segmentation/unet_3d/kits_19/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Ampere Computing LLC


try:
from utils import misc # noqa
except ModuleNotFoundError:
Expand All @@ -18,31 +20,6 @@
sys.exit(1)


def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Run 3D Unet KiTS 2019 model.")
parser.add_argument("-m", "--model_path",
type=str,
help="path to the model")
parser.add_argument("-p", "--precision",
type=str, choices=["fp32"], required=True,
help="precision of the model provided")
parser.add_argument("-f", "--framework",
type=str, default="tf",
choices=["tf"],
help="specify the framework in which a model should be run")
parser.add_argument("--timeout",
type=float, default=60.0,
help="timeout in seconds")
parser.add_argument("--num_runs",
type=int,
help="number of passes through network to execute")
parser.add_argument("--kits_path",
type=str,
help="path to directory with KiTS19 dataset")
return parser.parse_args()


def run_tf_fp(model_path, num_runs, timeout, kits_path):
import numpy as np
import tensorflow as tf
Expand All @@ -64,27 +41,47 @@ def run_single_pass(tf_runner, kits):
return run_model(run_single_pass, runner, dataset, 1, num_runs, timeout)


def run_pytorch_fp(model_path, num_runs, timeout, kits_path):
import torch
import numpy as np
import tensorflow as tf
from utils.pytorch import PyTorchRunnerV2
from utils.cv.kits import KiTS19
from utils.benchmark import run_model

def run_single_pass(pytorch_runner, kits):
output = pytorch_runner.run(1, torch.from_numpy(np.expand_dims(kits.get_input_array(), axis=0)))
kits.submit_predictions(tf.convert_to_tensor(output.numpy()))

dataset = KiTS19(dataset_dir_path=kits_path)
model = torch.jit.load(model_path, map_location=torch.device('cpu')).eval()
model = torch.jit.freeze(model)
runner = PyTorchRunnerV2(model)

return run_model(run_single_pass, runner, dataset, 1, num_runs, timeout)


def run_tf_fp32(model_path, num_runs, timeout, kits_path, **kwargs):
return run_tf_fp(model_path, num_runs, timeout, kits_path)


def run_pytorch_fp32(model_path, num_runs, timeout, kits_path, **kwargs):
return run_pytorch_fp(model_path, num_runs, timeout, kits_path)


def main():
from utils.misc import print_goodbye_message_and_die
args = parse_args()
if args.framework == "tf":
if args.model_path is None:
print_goodbye_message_and_die(
"a path to model is unspecified!")

if args.precision == "fp32":
run_tf_fp32(**vars(args))
else:
print_goodbye_message_and_die(
"this model seems to be unsupported in a specified precision: " + args.precision)
from utils.helpers import DefaultArgParser
parser = DefaultArgParser(["tf", "pytorch"])
parser.require_model_path()
parser.add_argument("--kits_path",
type=str,
help="path to directory with KiTS19 dataset")

args = parser.parse()
if args.framework == 'tf':
run_tf_fp32(**vars(parser.parse()))
else:
print_goodbye_message_and_die(
"this model seems to be unsupported in a specified framework: " + args.framework)
run_pytorch_fp32(**vars(parser.parse()))


if __name__ == "__main__":
Expand Down
35 changes: 35 additions & 0 deletions tests/test_pytorch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,41 @@ def wrapper(**kwargs):
self.assertTrue(acc["f1"] / f1_ref > 0.95)


class UNET_KITS(unittest.TestCase):
def setUp(self):
self.dataset_path = pathlib.Path(get_downloads_path(), "kits19")
if not self.dataset_path.exists():
# url = os.environ.get("S3_URL_KITS19_REDUCED_DATASET")
url = "https://ampereaimodelzoo.s3.eu-central-1.amazonaws.com/kits19_reduced.tar.gz"
assert url is not None
subprocess.run(f"wget -P /tmp {url}".split(),
check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
subprocess.run(f"tar -xf /tmp/kits19_reduced.tar.gz -C {get_downloads_path()}".split(),
check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
subprocess.run("rm /tmp/kits19_reduced.tar.gz".split(),
check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

self.model_path = pathlib.Path(get_downloads_path(), "3d_unet_kits_pytorch_fp32.ptc")
if not self.model_path.exists():
# url = os.environ.get("S3_URL_UNET_KITS_PYTORCH_FP32")
url = "https://ampereaimodelzoo.s3.eu-central-1.amazonaws.com/3d_unet_kits_pytorch_fp32.ptc"
subprocess.run(f"wget -P {get_downloads_path()} {url}".split(),
check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

def test_unet_kits(self):
from computer_vision.semantic_segmentation.unet_3d.kits_19.run import run_pytorch_fp32

def wrapper(**kwargs):
kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])

mean_kidney_acc, mean_tumor_acc = 0.927, 0.837
acc = run_process(wrapper, {"model_path": self.model_path, "kits_path": self.dataset_path,
"batch_size": 1, "num_runs": 500, "timeout": 200, "debug": True})

self.assertTrue(acc["mean_kidney_acc"] / mean_kidney_acc > 0.90)
self.assertTrue(acc["mean_tumor_acc"] / mean_tumor_acc > 0.80)


def download_imagenet_maybe():
dataset_path = pathlib.Path(get_downloads_path(), "ILSVRC2012_onspecta")
if not dataset_path.exists():
Expand Down
Loading