-
Notifications
You must be signed in to change notification settings - Fork 9
Description
First of all, thank you for your great work!
Describe the bug
Attempting to load the fine-tuned checkpoint provided in the repository causes `torch.load` to raise an `AttributeError`, because the checkpoint was pickled with a custom class (`ATLASPredictor`) that is not defined in the user's environment.
AttributeError: Can't get attribute 'ATLASPredictor' on <module '__main__'>
Steps to reproduce
import pathlib
import torch

# Reproduction: unpickling the fine-tuned checkpoint needs a class named
# ATLASPredictor, which is not defined in this script, so loading fails.
device = torch.device('cpu')
ckpt_path = pathlib.Path('checkpoints/finetune_best_val_loss.pt')
torch.load(ckpt_path, map_location=device)  # <- raises the AttributeError reported in this issue
Environment
Item | Version |
---|---|
OS | Windows 10 |
Python | 3.9.23 (Conda) |
PyTorch | 2.3.0 |
CUDA | 11.8 |
GPU | NVIDIA RTX 3080 |
Additional context
- The error persists even when passing `weights_only=True` to `torch.load`; that path fails because the checkpoint is a DistributedDataParallel-wrapped model.
- A workaround is to define a dummy `class ATLASPredictor(torch.nn.Module)`, but I am not sure whether this is correct.
Minimal workaround
import torch
class ATLASPredictor(torch.nn.Module):
    """Stand-in for the training-time ``ATLASPredictor`` so old checkpoints unpickle.

    Unpickling resolves classes by module and name (the reported failure is
    for ``__main__.ATLASPredictor``), so this stub only has to exist under
    that name. ``model`` is an optional wrapped network that ``forward``
    delegates to.
    """

    def __init__(self, model=None):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):  # pragma: no cover - legacy only
        wrapped = self.model
        if wrapped is not None:
            return wrapped(*args, **kwargs)
        raise RuntimeError("ATLASPredictor missing wrapped model")
# This snippet previously used ``device`` without defining it (it was only
# defined in the separate reproduction snippet), raising NameError on its own.
device = torch.device('cpu')
# weights_only=False lets pickle rebuild the ATLASPredictor object using the
# stub class above. Unpickling can execute arbitrary code, so only do this
# for checkpoints from a source you trust.
ckpt = torch.load(r'C:\Users\Danie\Downloads\repos\BrainSegFounder\checkpoints\finetune_best_val_loss.pt', map_location=device, weights_only=False)
print(ckpt)
How can I correctly load the model weights to run inference on a single T1 scan? Ideally, I would like the output to be a lesion segmentation mask.
What I also tried
I tried many scripts, all of which produce a faulty lesion map (see the attached pictures, which precede the code that generated the first one). Since the fine-tuned model expects 2 input channels, I added a dummy second channel just to test the script. I had to do this because I got the following error:
RuntimeError: Error(s) in loading state_dict for SwinUNETR:
size mismatch for swinViT.patch_embed.proj.weight: copying a param with shape torch.Size([48, 2, 2, 2, 2]) from checkpoint, the shape in current model is torch.Size([48, 1, 2, 2, 2]).
import torch
import monai
import numpy as np
import nibabel as nib
import argparse
import monai
from monai.networks.nets import SwinUNETR
import os
"""
To use this script, run the following as a single command (wrapped here for readability):
python infer_lesion_mask.py --checkpoint /path/to/finetune_best_val_loss.pt
--image /path/to/patient.nii.gz
--output /path/to/predicted_mask.nii.gz
--roi 96 96 96
"""
def load_image(image_path, roi):
    """Load a NIfTI volume, make it channel-first and resize it to ``roi``.

    Args:
        image_path: Path to a NIfTI file readable by ``nibabel``.
        roi: Target spatial size (three ints) passed to ``monai.transforms.Resize``.

    Returns:
        ``(array, affine)``: ``array`` is a channel-first ``float32`` NumPy
        array of shape ``(C, *roi)``. ``affine`` is the ORIGINAL image's
        affine, i.e. it describes the un-resized voxel grid.

    Raises:
        ValueError: if the volume is neither 3-D nor 4-D.
    """
    img = nib.load(image_path)
    data = img.get_fdata(dtype=np.float32)
    if data.ndim == 3:
        # Single-channel volume: add a leading channel axis -> (1, X, Y, Z).
        data = data[np.newaxis, ...]
    elif data.ndim == 4:
        # NIfTI keeps extra components on the last axis (X, Y, Z, C), while
        # MONAI's Resize treats axis 0 as the channel axis.
        data = np.moveaxis(data, -1, 0)
    else:
        raise ValueError(
            f"Expected a 3-D or 4-D NIfTI volume, got shape {data.shape}"
        )
    # NOTE(review): no intensity normalisation is applied here. If fine-tuning
    # normalised intensities (e.g. z-scoring), the same step is needed here --
    # TODO confirm against the training transforms.
    transform = monai.transforms.Compose([
        monai.transforms.ToTensor(),
        monai.transforms.Resize(tuple(roi)),
    ])
    resized = transform(data)
    # Return plain NumPy so callers can use np.concatenate / torch.from_numpy
    # regardless of whether MONAI produced a (Meta)Tensor.
    if isinstance(resized, torch.Tensor):
        resized = resized.detach().cpu().numpy()
    return np.asarray(resized, dtype=np.float32), img.affine
def save_mask(mask, affine, output_path):
    """Write ``mask`` to ``output_path`` as an unsigned 8-bit NIfTI image.

    Args:
        mask: Array to save; it is cast with ``astype(np.uint8)``.
        affine: Voxel-to-world matrix stored in the new image.
        output_path: Destination passed to ``nib.save``.
    """
    mask_u8 = mask.astype(np.uint8)
    nifti = nib.Nifti1Image(mask_u8, affine)
    nib.save(nifti, output_path)
class ATLASPredictor(torch.nn.Module):
    """Placeholder for the training-time wrapper so the checkpoint can be unpickled.

    ``torch.load`` with pickling looks the class up by module and name (the
    reported failure is for ``__main__.ATLASPredictor``). Defining it here,
    in the script run as ``__main__``, lets that lookup succeed. ``model`` is
    an optional network that ``forward`` simply delegates to.
    """

    def __init__(self, model=None):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        inner = self.model
        if inner is None:
            raise RuntimeError("ATLASPredictor missing wrapped model")
        return inner(*args, **kwargs)
from models.ssl_head import SSLHead
# Candidate ``(old_prefix, new_prefix)`` key rewrites tried when loading the
# checkpoint; the rewrite whose keys match the most model parameters wins.
_KEY_REWRITES = (
    ("", ""),                 # keys already match the model
    ("module.", ""),          # full model saved while wrapped in DistributedDataParallel
    ("module.", "swinViT."),  # encoder-only keys (the rule this script used before)
)


def _extract_state_dict(checkpoint):
    """Return a ``{name: tensor}`` mapping from whatever ``torch.load`` produced.

    Accepts a dict with a ``"state_dict"`` entry (a mapping or a module), a
    ``torch.nn.Module``, or a bare state dict.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    return checkpoint


def _rename_prefix(state_dict, old, new):
    """Return a copy of ``state_dict`` with a leading ``old`` replaced by ``new``."""
    return {
        (new + key[len(old):] if key.startswith(old) else key): value
        for key, value in state_dict.items()
    }


def _match_to_model(state_dict, model_keys):
    """Apply the ``_KEY_REWRITES`` entry that matches the most ``model_keys``.

    Returns ``(renamed_state_dict, number_of_matching_keys)``.
    """
    best, best_hits = dict(state_dict), 0
    for old, new in _KEY_REWRITES:
        renamed = _rename_prefix(state_dict, old, new)
        hits = sum(key in model_keys for key in renamed)
        if hits > best_hits:
            best, best_hits = renamed, hits
    return best, best_hits


def _build_model(roi):
    """Build the SwinUNETR with the hyper-parameters assumed for fine-tuning."""
    return monai.networks.nets.SwinUNETR(
        img_size=tuple(roi),
        # Checkpoint's patch_embed weight has shape (48, 2, 2, 2, 2): 2 inputs.
        in_channels=2,
        out_channels=1,
        feature_size=48,
        use_checkpoint=True,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        drop_rate=0.1,
    )


def main(args):
    """Segment lesions in one MRI volume with the fine-tuned SwinUNETR.

    Args:
        args: Parsed CLI namespace with ``checkpoint``, ``image``, ``output``
            and ``roi`` attributes.

    Side effects:
        Writes a uint8 NIfTI mask (on the input image's grid) to
        ``args.output`` and prints its path.

    Raises:
        RuntimeError: if no checkpoint parameter name matches the model.
        ValueError: if the image has neither 1 nor 2 channels.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # weights_only=False is required because the checkpoint pickles the custom
    # ATLASPredictor class (defined in this script so the lookup succeeds).
    # Unpickling can execute arbitrary code: only load checkpoints you trust.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    model = _build_model(args.roi)
    model_keys = set(model.state_dict())
    state_dict, n_matched = _match_to_model(_extract_state_dict(checkpoint), model_keys)
    if n_matched == 0:
        raise RuntimeError(
            "No checkpoint parameter names match the SwinUNETR model; "
            f"first checkpoint keys: {list(state_dict)[:5]}"
        )
    # strict=False tolerates extra entries, but any missing weight stays
    # randomly initialised -- report it instead of silently predicting noise.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"WARNING: {len(incompatible.missing_keys)} model parameters were not "
            f"loaded (e.g. {incompatible.missing_keys[:5]}); "
            f"{len(incompatible.unexpected_keys)} checkpoint entries were ignored "
            f"(e.g. {incompatible.unexpected_keys[:5]})."
        )
    model = model.to(device)
    model.eval()

    # --- Load and preprocess image ---
    img, affine = load_image(args.image, args.roi)  # channel-first (C, *roi)
    # Accept either a NumPy array or a (MONAI) tensor from load_image.
    img = torch.as_tensor(np.asarray(img), dtype=torch.float32)
    if img.shape[0] == 1:
        # The model expects 2 input channels; pad with a zero channel.
        # NOTE(review): an all-zero channel was likely never seen during
        # fine-tuning -- confirm what the second channel was at training time.
        img = torch.cat([img, torch.zeros_like(img)], dim=0)
    if img.shape[0] != 2:
        raise ValueError(f"Expected 1 or 2 input channels, got {img.shape[0]}")
    batch = img.unsqueeze(0).to(device)  # (1, 2, *roi)

    # load_image resizes to the ROI, but ``affine`` describes the original
    # grid; resample the prediction back so the mask lines up with the scan.
    original_shape = tuple(nib.load(args.image).shape[:3])

    # --- Inference ---
    with torch.no_grad():
        prob = torch.sigmoid(model(batch))
        prob = torch.nn.functional.interpolate(
            prob, size=original_shape, mode="trilinear", align_corners=False
        )
    pred_mask = (prob[0, 0].cpu().numpy() > 0.5).astype(np.uint8)

    # --- Save mask ---
    save_mask(pred_mask, affine, args.output)
    print(f"Saved predicted mask to {args.output}")
if __name__ == "__main__":
    # Command-line entry point: collect the paths and ROI, then run inference.
    cli = argparse.ArgumentParser(
        description="Infer lesion mask from MRI using finetuned model."
    )
    cli.add_argument('--checkpoint', type=str, required=True,
                     help='Path to finetuned model checkpoint (.pt)')
    cli.add_argument('--image', type=str, required=True,
                     help='Path to input MRI image (NIfTI)')
    cli.add_argument('--output', type=str, required=True,
                     help='Path to save predicted mask (NIfTI)')
    cli.add_argument('--roi', nargs=3, type=int, default=[96, 96, 96],
                     help='ROI size used during training')
    main(cli.parse_args())
Thank you in advance for your time!