Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python

__doc__ = """
Simple example script showing how to use the Sybil library locally to predict risk scores for a set of DICOM files.
"""

import sybil
from sybil import visualize_attentions

from utils import get_demo_data


def main():
    """Score the demo DICOM series with a local Sybil ensemble and save attention overlays.

    Prints the number of DICOM files processed, the predicted risk scores,
    and the directory the attention images are written to.
    """
    # Load a trained model
    sybil_model = sybil.Sybil("sybil_ensemble")

    dicom_files = get_demo_data()

    # Get risk scores (attentions are requested so they can be visualized below)
    demo_serie = sybil.Serie(dicom_files)
    print(f"Processing {len(dicom_files)} DICOM files")
    result = sybil_model.predict([demo_serie], return_attentions=True)
    scores = result.scores

    print(f"Risk scores: {scores}")

    # Visualize attention maps
    output_dir = "sybil_attention_output"

    print(f"Writing attention images to {output_dir}")
    # The return value is not needed here; the images are saved to `output_dir`.
    visualize_attentions(
        demo_serie,
        attentions=result.attentions,
        save_directory=output_dir,
        gain=3,
    )

    print(f"Finished writing attention images to {output_dir}")


if __name__ == "__main__":
    main()
77 changes: 77 additions & 0 deletions examples/remote_ark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python

__doc__ = """
This example shows how to use a client to access a
remote Sybil server (running Ark) to predict risk scores for a set of DICOM files.

The server must be started separately.

https://github.com/reginabarzilaygroup/Sybil/wiki
https://github.com/reginabarzilaygroup/ark/wiki
"""
import json
import os

import numpy as np
import requests

import sybil.utils.visualization

from utils import get_demo_data

if __name__ == "__main__":
    # Fetch (or reuse the cached copy of) the demo DICOM files.
    dicom_files = get_demo_data()
    serie = sybil.Serie(dicom_files)

    # Set the URL of the remote Sybil server
    ark_hostname = "localhost"
    ark_port = 5000
    ark_host = f"http://{ark_hostname}:{ark_port}"

    # Ask the server to return attention maps along with the scores.
    data_dict = {"return_attentions": True}
    payload = {"data": json.dumps(data_dict)}

    # Check if the server is running and reachable
    resp = requests.get(f"{ark_host}/info")
    if resp.status_code != 200:
        raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}")

    info_data = resp.json()["data"]
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    if info_data["modelName"].lower() != "sybil":
        raise ValueError("The ARK server is not running Sybil")
    print(f"ARK server info: {info_data}")

    # Submit prediction to ARK server.
    # Handles are opened one at a time inside the `try` so that every handle
    # opened so far is closed even if a later `open` or the POST itself raises.
    files = []
    try:
        for file_path in dicom_files:
            files.append(("dicom", open(file_path, "rb")))
        r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload)
    finally:
        for _, handle in files:
            handle.close()

    if r.status_code != 200:
        raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}")

    r_json = r.json()
    predictions = r_json["data"]["predictions"]

    # The first entry of the response holds the risk scores, the second the attentions.
    scores = predictions[0]
    print(f"Risk scores: {scores}")

    attentions = predictions[1]
    attentions = np.array(attentions)
    print(f"Ark received attention shape: {attentions.shape}")

    # Visualize attention maps
    save_directory = "remote_ark_sybil_attention_output"

    print(f"Writing attention images to {save_directory}")

    images = serie.get_raw_images()
    overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3)

    # Set `save_directory` to None above to skip writing images to disk.
    if save_directory is not None:
        serie_idx = 0
        save_path = os.path.join(save_directory, f"serie_{serie_idx}")
        sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}")

    print(f"Finished writing attention images to {save_directory}")

42 changes: 42 additions & 0 deletions examples/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from urllib.request import urlopen


def download_file(url, filepath):
    """Download `url` and write the response body to `filepath`.

    The parent directory of `filepath` is created if it does not exist.
    The body is written only when the response status is 200; otherwise a
    failure message is printed and nothing is written.

    Args:
        url: Address passed to `urllib.request.urlopen`.
        filepath: Destination path for the downloaded bytes.

    Returns:
        `filepath`, unchanged (also on the failure path).
    """
    # The context manager closes the connection even if writing fails.
    with urlopen(url) as response:
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Check if the request was successful
        if response.status == 200:
            with open(filepath, 'wb') as f:
                f.write(response.read())
        else:
            # urllib responses expose `status`; `status_code` is the `requests`
            # spelling and would raise AttributeError here.
            print(f"Failed to download file. Status code: {response.status}")

    return filepath

def get_demo_data():
    """Return the paths of the demo DICOM files, fetching them on first use.

    The demo archive is cached as ``~/.sybil/sybil_example.zip`` and unpacked
    into ``~/.sybil/sybil_example``. The download is skipped when the archive
    already exists, and the extraction is skipped when the extraction
    directory already exists.

    Returns:
        A list with one full path per entry of the extracted
        ``sybil_demo_data`` directory, in `os.listdir` order.
    """
    demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1"

    cache_dir = os.path.expanduser("~/.sybil")
    zip_file_path = os.path.join(cache_dir, "sybil_example.zip")
    os.makedirs(cache_dir, exist_ok=True)

    archive_is_cached = os.path.exists(zip_file_path)
    if not archive_is_cached:
        print(f"Downloading demo data to {zip_file_path}")
        download_file(demo_data_url, zip_file_path)

    demo_data_dir = os.path.join(cache_dir, "sybil_example")
    already_extracted = os.path.exists(demo_data_dir)
    if not already_extracted:
        print(f"Extracting demo data to {demo_data_dir}")
        from zipfile import ZipFile
        with ZipFile(zip_file_path, 'r') as archive:
            archive.extractall(demo_data_dir)

    # The archive holds the images in a `sybil_demo_data` sub-directory.
    image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")
    return [os.path.join(image_data_dir, entry) for entry in os.listdir(image_data_dir)]
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ python_requires = >=3.8,<3.11
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version>="3.8"
albumentations==1.1.0
imageio==2.34.1
numpy==1.24.1
opencv-python==4.5.4.60
opencv-python-headless==4.5.4.60
pillow>=10.2.0
pydicom==2.3.0
pylibjpeg[all]==2.0.0
scikit-learn==1.0.2
torch==1.13.1+cu117; platform_machine == "x86_64"
torch==1.13.1; platform_machine != "x86_64"
torchio==0.18.74
Expand All @@ -68,8 +67,10 @@ testing =
mypy
black
train =
albumentations==1.1.0
lifelines==0.26.4
pytorch_lightning==1.6.0
scikit-learn==1.0.2

[options.entry_points]
console_scripts =
Expand Down
4 changes: 2 additions & 2 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions
from sybil.utils.visualization import visualize_attentions, collate_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"]
__all__ = ["Sybil", "Serie", "visualize_attentions", "collate_attentions", "__version__"]
28 changes: 24 additions & 4 deletions sybil/augmentations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import cv2
import torch
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2

from typing import Literal
from abc import ABCMeta, abstractmethod
import numpy as np
import random

try:
import albumentations as A
except ImportError:
# albumentations is not installed, training with augmentations will not be possible
A = None


def get_augmentations(split: Literal["train", "dev", "test"], args):
if split == "train":
Expand Down Expand Up @@ -94,7 +100,6 @@ class ToTensor(Abstract_augmentation):

def __init__(self):
super(ToTensor, self).__init__()
self.transform = ToTensorV2()
self.name = "totensor"

def __call__(self, input_dict, sample=None):
Expand All @@ -104,6 +109,20 @@ def __call__(self, input_dict, sample=None):
return input_dict


class ResizeTransform:
    """Resize an image and/or mask to a fixed ``(width, height)`` with OpenCV.

    The call signature is keyword-based and the result is a dict with
    ``"image"`` and ``"mask"`` keys. An input that is not supplied comes
    back as ``None``.
    """

    def __init__(self, width, height):
        self.width, self.height = width, height

    def __call__(self, image=None, mask=None):
        # cv2.resize takes the target size as (width, height).
        target_size = (self.width, self.height)

        resized_image = None
        if image is not None:
            # Bilinear interpolation for the image.
            resized_image = cv2.resize(image, dsize=target_size, interpolation=cv2.INTER_LINEAR)

        resized_mask = None
        if mask is not None:
            # Nearest-neighbour interpolation for the mask.
            resized_mask = cv2.resize(mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)

        return {"image": resized_image, "mask": resized_mask}


class Scale_2d(Abstract_augmentation):
"""
Given PIL image, enforce its some set size
Expand All @@ -115,7 +134,7 @@ def __init__(self, args, kwargs):
assert len(kwargs.keys()) == 0
width, height = args.img_size
self.set_cachable(width, height)
self.transform = A.Resize(height, width)
self.transform = ResizeTransform(width, height)

def __call__(self, input_dict, sample=None):
out = self.transform(
Expand All @@ -138,6 +157,7 @@ def __init__(self, args, kwargs):
super(Rotate_Range, self).__init__()
assert len(kwargs.keys()) == 1
self.max_angle = int(kwargs["deg"])
assert A is not None, "albumentations is not installed"
self.transform = A.Rotate(limit=self.max_angle, p=0.5)

def __call__(self, input_dict, sample=None):
Expand Down
12 changes: 6 additions & 6 deletions sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sybil.models.calibrator import SimpleClassifierGroup
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
from sybil.utils.metrics import get_survival_metrics


# Leaving this here for a bit; these are IDs to download the models from Google Drive
Expand Down Expand Up @@ -67,7 +66,7 @@
},
}

CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://www.dropbox.com/scl/fi/45rtadfdci0bj8dbpotmr/sybil_checkpoints_v1.5.0.zip?rlkey=n8n7pvhb89pjoxgvm90mtbtuk&dl=1")
CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")


class Prediction(NamedTuple):
Expand Down Expand Up @@ -107,12 +106,12 @@ def download_sybil(name, cache) -> Tuple[List[str], str]:
return download_model_paths, download_calib_path


def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    """Download a zip archive and extract it into `local_dir`.

    The whole archive is read into memory before extraction.

    Args:
        remote_url: Address of the zip archive, passed to `urlopen`.
        local_dir: Directory to extract into; created if it does not exist.

    Returns:
        The archive's member names as reported by `ZipFile.namelist()`.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Close the connection as soon as the payload has been read.
    with urlopen(remote_url) as resp:
        archive_bytes = resp.read()
    with ZipFile(BytesIO(archive_bytes)) as zip_file:
        all_files_and_dirs = zip_file.namelist()
        zip_file.extractall(local_dir)
    return all_files_and_dirs


Expand Down Expand Up @@ -379,6 +378,7 @@ def evaluate(
Output evaluation. See details for :class:`~sybil.model.Evaluation`.

"""
from sybil.utils.metrics import get_survival_metrics
if isinstance(series, Serie):
series = [series]
elif not isinstance(series, list):
Expand Down
Loading
Loading