This repository builds on SegVol-for-SegFM and extends it with a stateful GRU-based correction mechanism for sequential refinement in interactive medical image segmentation.
Our framework consists of three core components:
- Box-Initialized Segmentation
- Uncertainty-Aware Interaction Sampling
- GRU-Based Sequential Correction
Figure: Three-stage pipeline. (a) SegVol generates initial mask from box prompt; (b) Uncertainty and error regions are sampled to build sequential state tensor for GRU; (c) GRU corrects predictions using sequential states.
Given a volumetric input, SegVol first produces an initial segmentation mask from the bounding-box prompt.
We sample voxels requiring correction from:
- Uncertainty regions: voxels with ambiguous probability ($|p - 0.5| < \tau$).
- Most likely error regions: discrepancies around user clicks.
Uncertainty points:
$$ \mathcal{U}_t = \{ (i,j,k) \mid |p_t(i,j,k) - 0.5| < \tau \} $$
Error points (simulated from user clicks $(x_n, y_n, z_n)$ with labels $l_n$):
$$ \mathcal{E}_t = \bigcup_{n} \{ (i,j,k) \mid \|(i,j,k) - (x_n, y_n, z_n)\| \leq r,\; \hat{y}_t(i,j,k) \neq l_n \} $$
Both sets are ranked and truncated or padded to a fixed size $K$.
Figure: Sampling workflow for uncertainty and error points.
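Below is a minimal sketch of the uncertainty-point sampling step, assuming NumPy and a hypothetical helper name `sample_uncertainty_points`; the ranking, padding, and the error-point counterpart follow the same pattern.

```python
import numpy as np

def sample_uncertainty_points(prob: np.ndarray, tau: float = 0.1, K: int = 64):
    """Return up to K voxel coordinates with |p - 0.5| < tau, ranked by ambiguity."""
    coords = np.argwhere(np.abs(prob - 0.5) < tau)        # (N, 3) candidate voxels
    if len(coords) == 0:
        return np.zeros((K, 3), dtype=np.int64), np.zeros(K, dtype=bool)
    # Rank by how close the probability is to 0.5 (most ambiguous first).
    ambiguity = np.abs(prob[tuple(coords.T)] - 0.5)
    coords = coords[np.argsort(ambiguity)][:K]
    # Pad with zeros and record a validity mask so padded rows can be ignored.
    valid = np.zeros(K, dtype=bool)
    valid[: len(coords)] = True
    padded = np.zeros((K, 3), dtype=np.int64)
    padded[: len(coords)] = coords
    return padded, valid
```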
Each sampled point is represented by an 8D feature vector including:
- Normalized coordinates
- Probability or user label
- Padding mask
- First-click global context
These vectors form two feature matrices:
$\mathbf{F}_t^{\text{unc}} \in \mathbb{R}^{K \times 8}$ and $\mathbf{F}_t^{\text{err}} \in \mathbb{R}^{K \times 8}$
They are concatenated into: $$ \mathbf{X}_t = [\mathbf{F}_t^{\text{unc}}; \mathbf{F}_t^{\text{err}}] \in \mathbb{R}^{2K \times 8} $$
The original coordinates are stored in: $$ \mathbf{C}_t \in \mathbb{R}^{2K \times 3} $$
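For illustration, one possible assembly of the 8-D per-point features, assuming the layout (normalized x, y, z; probability or user label; padding mask; first-click context); the exact feature ordering in the repository may differ.

```python
import numpy as np

def build_features(coords, values, valid, first_click, shape):
    """coords: (K, 3) voxel indices; values: (K,) probability or user label;
    valid: (K,) padding mask; first_click: (3,) global context; shape: volume dims."""
    norm_xyz = coords / np.asarray(shape, dtype=np.float32)   # normalized coordinates
    click = np.broadcast_to(first_click / np.asarray(shape), norm_xyz.shape)
    return np.concatenate(
        [norm_xyz, values[:, None], valid[:, None].astype(np.float32), click], axis=1
    )  # (K, 8)

# X_t stacks the two feature matrices: np.concatenate([F_unc, F_err], axis=0) -> (2K, 8)
```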
The GRU processes the sequence of interaction states:
$$ \Delta p_t, \Delta \mathbf{C}_t = \text{GRU}_\theta(\{\mathbf{State}_0, \ldots, \mathbf{State}_T\}, \mathbf{h}_{t-1}) $$
- $\Delta p_t \in \mathbb{R}^{B \times 2K}$: predicted probability changes
- $\Delta \mathbf{C}_t \in \mathbb{R}^{B \times 2K \times 3}$: coordinate refinements
Each round of correction includes:
- Predict the initial mask from the box prompt
- Sample $\mathcal{U}_t$ and $\mathcal{E}_t$ from the prediction error
- Construct GRU input features $\mathbf{X}_t$ and coordinates $\mathbf{C}_t$
- GRU outputs updated logits $\Delta p_t$ and refined coordinates $\Delta \mathbf{C}_t$
Then, update the mask logits at the corrected locations:
$$ p_{t+1}(\mathbf{C}_t + \Delta \mathbf{C}_t) = p_t(\mathbf{C}_t + \Delta \mathbf{C}_t) + \Delta p_t $$
where $p_t$ denotes the mask logits at interaction step $t$; voxels outside the corrected locations are left unchanged.
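A hedged sketch of this update in PyTorch, assuming the refined coordinates are rounded back to voxel indices before scattering (the function name and clamping are illustrative):

```python
import torch

def apply_correction(logits, coords, delta_coords, delta_p):
    """logits: (D, H, W); coords: (2K, 3); delta_coords: (2K, 3); delta_p: (2K,)."""
    refined = (coords.float() + delta_coords).round().long()
    # Clamp so refined points stay inside the volume.
    for dim, size in enumerate(logits.shape):
        refined[:, dim].clamp_(0, size - 1)
    # Add the predicted logit changes at the refined voxel positions.
    logits[refined[:, 0], refined[:, 1], refined[:, 2]] += delta_p
    return logits
```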
- Base Network: The base segmentation model adopts the SegVol volumetric encoder-decoder architecture, which generates initial segmentation masks from bounding-box prompts.
- Correction Module: A GRU-based sequential refinement module is introduced to correct boundary errors over multiple interaction steps.
- Input: At each interaction step, the GRU receives a feature tensor of shape $B \times 2 \times (K \times 8)$, where each of the $2K$ points encodes 8-dimensional features including normalized coordinates, prediction probabilities, binary labels, and user-context information.
- GRU Unit: The GRU processes this temporal sequence and maintains a hidden state that captures the history of interactions.
- Coordinate Fusion: The GRU output is concatenated with flattened coordinate tensors of shape $B \times 6K$ to retain spatial alignment information.
- Output Layer: A fully connected layer produces $2K$ refined probability values corresponding to the selected voxel positions.
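The following self-contained PyTorch sketch wires these components together; the class name and defaults are illustrative, not the repository's actual module.

```python
import torch
import torch.nn as nn

class GRUCorrection(nn.Module):
    def __init__(self, K: int = 64, hidden_size: int = 256):
        super().__init__()
        # Input per time step: K points x 8 features, with a sequence of length 2.
        self.gru = nn.GRU(input_size=K * 8, hidden_size=hidden_size, batch_first=True)
        # Fuse the GRU output with the flattened coordinates (B x 6K).
        self.fc = nn.Linear(hidden_size + 6 * K, 2 * K)

    def forward(self, x, coords, h=None):
        """x: (B, 2, K*8) interaction state; coords: (B, 2K, 3); h: hidden state."""
        out, h = self.gru(x, h)                          # out: (B, 2, hidden_size)
        fused = torch.cat([out[:, -1], coords.flatten(1)], dim=1)
        return self.fc(fused), h                         # (B, 2K) refined values

# usage: delta_p, h = GRUCorrection()(torch.randn(4, 2, 64 * 8), torch.randn(4, 128, 3))
```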
Total Number of Parameters:
- GRU parameters depend on `input_size = K × 8` and a defined `hidden_size` (e.g., 256).
- The FC layer includes parameters for transforming from `hidden_size + 6K` to `2K` outputs.
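As a back-of-the-envelope check under the example setting (`K = 64`, `hidden_size = 256`), counting PyTorch's two bias vectors per GRU gate set:

```python
K, H = 64, 256
gru = 3 * H * (K * 8) + 3 * H * H + 6 * H   # W_ih + W_hh + b_ih + b_hh
fc = (H + 6 * K) * 2 * K + 2 * K            # weights + bias
print(gru + fc)                              # 393216 + 196608 + 1536 + 81920 + 128 = 673408
```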
We adopt the following metrics to evaluate segmentation performance:
- Dice Similarity Coefficient (DSC)
- Hausdorff Distance (95th percentile)
- Average Surface Distance (ASD)
👉 [Code reference for metrics: link-to-code or metrics.py]
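For reference, a minimal Dice implementation on binary masks; HD95 and ASD are typically computed with surface-distance utilities (e.g., those in MONAI) rather than reimplemented by hand.

```python
import numpy as np

def dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """Dice Similarity Coefficient between two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    return 2.0 * np.logical_and(pred, gt).sum() / (pred.sum() + gt.sum() + eps)
```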
- Dataset Name: CVPR2025 SegFM3D
- Source:
- Public Dataset:
Preprocessing follows the original SegVol pipeline.
Our training strategy builds on the SegVol pipeline, introducing enhanced augmentation, sampling, and interaction-aware optimization.
To improve generalization, we apply spatial and intensity-based augmentations:
- Spatial Augmentation:
  - Random flipping along sagittal, coronal, and axial planes (probability = 0.2 per axis).
  - Mixed input strategy: full-volume resizing or patch extraction (3:1 preference for patches).
  - Patch extraction via `RandCropByPosNegLabeld` with a 3:1 positive-to-negative sample ratio.
- Intensity Augmentation:
  - Random intensity scaling (±20%) and shifting (±20%), each applied with a 0.2 probability.
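A sketch of this pipeline with MONAI transforms; the dictionary keys and patch size are placeholders, not the repository's actual configuration.

```python
from monai.transforms import (
    Compose, RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
    RandCropByPosNegLabeld,
)

train_transforms = Compose([
    # Random flips along each anatomical axis, p = 0.2 per axis.
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    # Patch extraction with a 3:1 positive-to-negative ratio.
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label",
        spatial_size=(128, 128, 128), pos=3, neg=1, num_samples=1,
    ),
    # Intensity scaling/shifting by up to ±20%, each with p = 0.2.
    RandScaleIntensityd(keys="image", factors=0.2, prob=0.2),
    RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2),
])
```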
We adopt a class-balanced and memory-efficient sampling approach:
- Foreground-aware cropping ensures all available classes are represented.
- Patch-based sampling prioritizes regions with segmentation targets (positive:negative = 3:1).
- Ground truth masks are stored in sparse format to handle large 3D volumes.
- Inference is performed on full volumes, with optional sliding windows for large scans.
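As an illustration of the sparse storage idea, using PyTorch's COO format (the repository may use a different sparse layout):

```python
import torch

dense = torch.zeros(256, 256, 256, dtype=torch.uint8)
dense[100:120, 100:120, 100:120] = 1       # toy foreground region

sparse = dense.to_sparse()                  # stores only nonzero voxels
restored = sparse.to_dense()                # materialize when needed
assert torch.equal(dense, restored)
```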
A two-stage training approach is designed to mimic real-world user interactions:
- Stage 1: Train with box prompts using global similarity losses to capture coarse structure.
- Stage 2: Add point prompts and fine-tune with boundary-sensitive losses (e.g., distance-based metrics).
To simulate clinical refinement behavior:
- Prioritize 2–3 GRU iterations during training.
- Assign lower probabilities to 1 or 4 steps, and rarely use 0 or 5 steps.
- Use sequential GRU updates with memory retention for realistic interaction modeling.
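One way to realize this weighting, with illustrative probabilities chosen to favor 2–3 iterations:

```python
import numpy as np

steps = np.arange(6)                                   # 0..5 refinement rounds
probs = np.array([0.02, 0.15, 0.33, 0.33, 0.15, 0.02]) # illustrative weights
n_steps = np.random.choice(steps, p=probs)             # steps for this sample
```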
The optimization combines:
- Global structural similarity loss
- Voxel-wise prediction loss
- Consistency loss between initial and refined predictions
This ensures both stability during training and adaptability to varied clinical use cases.
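A hedged sketch of one possible combination, assuming Dice for the structural term, BCE for the voxel-wise term, and an MSE consistency term; the weights are placeholders.

```python
import torch
import torch.nn.functional as F

def total_loss(pred, refined, target, w=(1.0, 1.0, 0.1)):
    """pred/refined: logits; target: binary mask of the same shape."""
    p, r = torch.sigmoid(pred), torch.sigmoid(refined)
    dice = 1 - (2 * (r * target).sum() + 1) / (r.sum() + target.sum() + 1)
    bce = F.binary_cross_entropy_with_logits(refined, target.float())
    consistency = F.mse_loss(r, p.detach())   # keep refinement near the initial map
    return w[0] * dice + w[1] * bce + w[2] * consistency
```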
For important setup instructions including pretrained weights download and testing procedures, please see:
weights_and_testing.md
Postprocessing follows the original SegVol-for-SegFM pipeline.
We thank all the data owners for making the medical images publicly available and CodaLab for hosting the challenge platform.