
Loading model/checkpoints finetune_best_val_loss.pt fails with AttributeError: Can't get attribute 'ATLASPredictor' #9

@leonardozaggia

Description


First of all, thank you for your great work!

Describe the bug

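Attempting to load the fine-tuned checkpoint provided in the repository makes torch.load raise an AttributeError. The checkpoint was pickled with a reference to a custom class, ATLASPredictor, which was defined in the training script's __main__ module. When the file is loaded from any other script, that class does not exist there, so unpickling fails.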

AttributeError: Can't get attribute 'ATLASPredictor' on <module '__main__'>
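The message names `__main__` because pickle does not store class definitions. It stores only a module-plus-name reference, and that reference is resolved again at load time. The self-contained sketch below reproduces the mechanism with plain `pickle`. The module name `train_script` and the stand-in class are hypothetical. They are chosen so the example behaves the same no matter which script runs it.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the training script's __main__.
train_script = types.ModuleType("train_script")
sys.modules["train_script"] = train_script


class ATLASPredictor:
    """Stand-in for the custom class that was pickled with the checkpoint."""

    def __init__(self, weights):
        self.weights = weights


# Register the class as if it had been defined inside train_script.
ATLASPredictor.__module__ = "train_script"
train_script.ATLASPredictor = ATLASPredictor

# pickle records only the reference "train_script.ATLASPredictor" plus the
# instance's attributes; the class body itself is not in the payload.
payload = pickle.dumps(ATLASPredictor({"w": 1.0}))

# Simulate loading from a different script where the class is not defined.
del train_script.ATLASPredictor

error_message = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```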

Steps to reproduce

import torch, pathlib
ckpt_path = pathlib.Path('checkpoints/finetune_best_val_loss.pt')
device = torch.device('cpu')
torch.load(ckpt_path, map_location=device)  # raises the AttributeError above (weights_only defaults to False in PyTorch 2.3)
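You can check which classes a checkpoint will try to import before loading it, by reading the pickle stream without executing it. This is a minimal sketch under two assumptions. First, the file uses the zip-based format that `torch.save` has written by default since PyTorch 1.6. Second, it uses torch's default pickle protocol 2, which references classes with the `GLOBAL` opcode. The helper names `referenced_globals` and `checkpoint_globals` are hypothetical, and legacy non-zip files are not handled.

```python
import pickletools
import zipfile


def referenced_globals(pickle_bytes):
    """Collect the (module, name) pairs a protocol 0-3 pickle imports on load.

    Classes and functions are referenced with the GLOBAL opcode, whose
    argument pickletools reports as "module name". Protocol 4+ pickles use
    STACK_GLOBAL instead, which this sketch does not handle.
    """
    found = set()
    # genops only decodes opcodes; nothing in the pickle is executed.
    for opcode, arg, _pos in pickletools.genops(pickle_bytes):
        if opcode.name == "GLOBAL":
            module, name = arg.split(" ", 1)
            found.add((module, name))
    return found


def checkpoint_globals(path_or_file):
    """Scan the data.pkl record inside a zip-format torch.save() file."""
    with zipfile.ZipFile(path_or_file) as archive:
        pkl_names = [n for n in archive.namelist() if n.endswith("data.pkl")]
        if not pkl_names:
            raise ValueError("no data.pkl record found; not a zip-format checkpoint?")
        return referenced_globals(archive.read(pkl_names[0]))
```

For example, `sorted(checkpoint_globals('checkpoints/finetune_best_val_loss.pt'))` should include `('__main__', 'ATLASPredictor')`, given the error above. It would also list any other classes that must be importable.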

Environment

Item      Version
OS        Windows 10
Python    3.9.23 (Conda)
PyTorch   2.3.0
CUDA      11.8
GPU       NVIDIA RTX 3080

Additional context

  • Forcing weights_only=True on torch.load does not solve it either. That path fails because the checkpoint contains pickled Python objects (a DistributedDataParallel‑wrapped model) rather than only tensors.
  • A workaround is to define a dummy class ATLASPredictor(torch.nn.Module) before loading, but I am not sure whether this is correct.
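On the first point: `DistributedDataParallel` stores the wrapped network as `.module`, so its `state_dict` keys start with `module.`. Once the checkpoint has been loaded (for example with the stub below), one option is to keep only the tensors, with that prefix stripped. Later loads then no longer need the custom class. This sketch assumes the weights live under `checkpoint['state_dict']`, as in the script further down, and that the keys actually carry the DDP prefix. `strip_prefix` is a hypothetical helper.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys without the prefix are kept unchanged, and the original key order
    is preserved (torch state_dicts are OrderedDicts).
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

Saving the result with `torch.save(strip_prefix(ckpt['state_dict']), 'weights.pt')` gives a file of plain tensors that `torch.load('weights.pt', weights_only=True)` can read. The file name is hypothetical. Compare the resulting keys against `model.state_dict().keys()` before loading, because the remaining key names depend on how the model was wrapped during training.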
Minimal workaround
import torch

device = torch.device('cpu')

class ATLASPredictor(torch.nn.Module):
    """Minimal stub used to unpickle old checkpoints."""

    def __init__(self, model=None):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):  # pragma: no cover - legacy only
        if self.model is None:
            raise RuntimeError("ATLASPredictor missing wrapped model")
        return self.model(*args, **kwargs)


ckpt = torch.load(r'C:\Users\Danie\Downloads\repos\BrainSegFounder\checkpoints\finetune_best_val_loss.pt', map_location=device, weights_only=False)
print(ckpt)
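Defining a stub with the same name, as above, is a common way to satisfy the unpickler. It only works when the stub lives in the module named in the error: `__main__`, that is, the script being run. An alternative is to override `find_class` on a `pickle.Unpickler` subclass. The Python `pickle` documentation describes this hook for controlling which globals are loaded. The sketch below redirects `__main__.ATLASPredictor` to a placeholder class. `ATLASPredictorStub`, `RedirectingUnpickler`, and `load_with_redirects` are hypothetical names. The placeholder is a plain class here to keep the sketch free of a torch dependency. For a real `.pt` file it would typically subclass `torch.nn.Module`, as the stub above does.

```python
import io
import pickle


class ATLASPredictorStub:
    """Hypothetical placeholder that receives the unpickled attributes."""


class RedirectingUnpickler(pickle.Unpickler):
    """Unpickler that resolves selected missing classes to stand-ins."""

    # Maps (module, name) as recorded in the pickle to a replacement class.
    REDIRECTS = {("__main__", "ATLASPredictor"): ATLASPredictorStub}

    def find_class(self, module, name):
        replacement = self.REDIRECTS.get((module, name))
        if replacement is not None:
            return replacement
        # Everything else is resolved normally (imported by module path).
        return super().find_class(module, name)


def load_with_redirects(data: bytes):
    """Unpickle raw bytes, applying the REDIRECTS table."""
    return RedirectingUnpickler(io.BytesIO(data)).load()
```

`torch.load` also accepts a `pickle_module` argument. Whether this class can be plugged in there unchanged has not been verified here, so the stub-in-`__main__` approach may remain the simpler option for `.pt` files.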

How should I proceed to load the model weights correctly and run inference on a single T1 scan? Ideally, the output would be a lesion segmentation mask.

What I also tried: several scripts, all of which produce a faulty lesion map (see the attached pictures; the code that generated the first one follows them). Since the fine-tuned model expects 2 input channels, I added a dummy second channel just to test the script.

I had to do that because, otherwise, loading the weights failed with the following error:

RuntimeError: Error(s) in loading state_dict for SwinUNETR:
size mismatch for swinViT.patch_embed.proj.weight: copying a param with shape torch.Size([48, 2, 2, 2, 2]) from checkpoint, the shape in current model is torch.Size([48, 1, 2, 2, 2]).
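The shape in this message also shows what the checkpoint expects. A `Conv3d` weight is laid out as (out_channels, in_channels, kD, kH, kW), so `[48, 2, 2, 2, 2]` means 2 input channels and 48 output channels. In MONAI's SwinUNETR patch embedding, 48 output channels corresponds to `feature_size=48`. Below is a small sketch that reads both values from a loaded `state_dict`. `infer_swin_channels` is a hypothetical helper, and the key name is taken from the error above. It does not say what the two channels contained during fine-tuning; that would have to be checked against the repository's training code.

```python
import numpy as np


def infer_swin_channels(state_dict, key="swinViT.patch_embed.proj.weight"):
    """Return (in_channels, feature_size) from the patch-embedding conv weight.

    Works with anything exposing .shape (torch tensors or numpy arrays).
    """
    weight = state_dict[key]
    feature_size, in_channels = weight.shape[0], weight.shape[1]
    return int(in_channels), int(feature_size)


# Example with a dummy array shaped like the checkpoint's weight.
print(infer_swin_channels({"swinViT.patch_embed.proj.weight": np.zeros((48, 2, 2, 2, 2))}))
```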

[Attached images: faulty lesion maps produced by the scripts (not reproduced)]

import argparse
import os

import monai
import nibabel as nib
import numpy as np
import torch
from monai.networks.nets import SwinUNETR

"""
To use this script run:

python infer_lesion_mask.py --checkpoint /path/to/finetune_best_val_loss.pt 
    --image /path/to/patient.nii.gz 
    --output /path/to/predicted_mask.nii.gz 
    --roi 96 96 96

"""

def load_image(image_path, roi):
    img = nib.load(image_path)
    data = img.get_fdata(dtype=np.float32)
    # Add a channel dimension if needed: MONAI transforms expect (C, H, W, D)
    if data.ndim == 3:
        data = data[np.newaxis, ...]
    # Resize the whole volume to the ROI; this changes the voxel grid,
    # so img.affine no longer describes the resized array
    transform = monai.transforms.Compose([
        monai.transforms.ToTensor(),
        monai.transforms.Resize(roi)
    ])
    tensor = transform(data)
    return tensor, img.affine

def save_mask(mask, affine, output_path):
    mask = mask.astype(np.uint8)
    nib.save(nib.Nifti1Image(mask, affine), output_path)

class ATLASPredictor(torch.nn.Module):
    """Minimal stub used to unpickle old checkpoints."""
    def __init__(self, model=None):
        super().__init__()
        self.model = model
    def forward(self, *args, **kwargs):
        if self.model is None:
            raise RuntimeError("ATLASPredictor missing wrapped model")
        return self.model(*args, **kwargs)

from models.ssl_head import SSLHead

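def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # weights_only=False is needed to unpickle ATLASPredictor (defined above);
    # it is the default in PyTorch 2.3 but not from PyTorch 2.6 onwards
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)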
    # Build SwinUNETR with the same parameters as in finetune.py
    model = monai.networks.nets.SwinUNETR(
        img_size=tuple(args.roi),
        in_channels=2,           # Match your training setup!
        out_channels=1,
        feature_size=48,
        use_checkpoint=True,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        drop_rate=0.1
    )
    state_dict = checkpoint['state_dict']
    if "module." in list(state_dict.keys())[0]:
        for key in list(state_dict.keys()):
            state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
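    # strict=False silently skips missing/unexpected keys; report them so that
    # layers left at their random initialisation are visible
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")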
    model = model.to(device)
    model.eval()

    # --- Load and preprocess image ---
    img, affine = load_image(args.image, args.roi)  # shape: (C, H, W, D)
    if img.shape[0] == 1:
        # Add a dummy second channel (zeros), only to test the script
        img = torch.cat([img, torch.zeros_like(img)], dim=0)
    # img is already a tensor (ToTensor in load_image), so no numpy conversion
    img_tensor = img.unsqueeze(0).float().to(device)  # shape: (1, 2, H, W, D)

    # --- Inference ---
    with torch.no_grad():
        pred = model(img_tensor)
        pred_mask = torch.sigmoid(pred).cpu().numpy()[0, 0]  # (H, W, D)
        pred_mask = (pred_mask > 0.5).astype(np.uint8)

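    # --- Save mask ---
    # NOTE: pred_mask has the ROI shape, not the scan's original shape, so
    # pairing it with the original affine misaligns it with the input image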
    save_mask(pred_mask, affine, args.output)
    print(f"Saved predicted mask to {args.output}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Infer lesion mask from MRI using finetuned model.")
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to finetuned model checkpoint (.pt)')
    parser.add_argument('--image', type=str, required=True, help='Path to input MRI image (NIfTI)')
    parser.add_argument('--output', type=str, required=True, help='Path to save predicted mask (NIfTI)')
    parser.add_argument('--roi', nargs=3, type=int, default=[96, 96, 96], help='ROI size used during training')
    args = parser.parse_args()
    main(args)
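One likely source of misalignment in this script, independent of the weights, is the resizing step. `Resize` rescales the whole scan to the ROI size, but the predicted mask is saved at that ROI size with the scan's original affine. The mask therefore no longer lines up voxel for voxel with the input image. Below is a minimal numpy sketch that maps a label array back to the original grid using nearest-neighbour lookup. It assumes half-voxel-centred sampling, which approximates, but is not guaranteed to exactly invert, MONAI's `Resize` interpolation. `resize_mask_nearest` is a hypothetical helper.

```python
import numpy as np


def resize_mask_nearest(mask, out_shape):
    """Nearest-neighbour resize of a label array to out_shape.

    Each output voxel takes the label of the input voxel whose centre is
    closest, treating voxel i as covering [i, i + 1) on both grids.
    """
    indices = []
    for in_size, out_size in zip(mask.shape, out_shape):
        # Centre of each output voxel, expressed in input-voxel coordinates.
        centres = (np.arange(out_size) + 0.5) * in_size / out_size - 0.5
        idx = np.clip(np.rint(centres), 0, in_size - 1).astype(int)
        indices.append(idx)
    # np.ix_ builds an open mesh so each axis is indexed independently.
    return mask[np.ix_(*indices)]
```

In the script, it would be called just before `save_mask`, for example `pred_mask = resize_mask_nearest(pred_mask, nib.load(args.image).shape[:3])`. Resizing the whole volume may also differ from how the model saw data during training. MONAI's `sliding_window_inference` (in `monai.inferers`) is the usual way to run ROI-sized models on full scans, but which approach the repository uses would need checking.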

Thank you in advance for your time!
