boundary mask for unsupervised training #132


Merged · 6 commits merged into computational-cell-analytics:main on Jul 18, 2025

Conversation

stmartineau99 (Contributor)

  • get_unsupervised_loader now accepts a boundary mask and a sampler, so that patches are only generated inside the mask
  • the raw data and mask are stacked and written out to an .h5 file, since RawDataset expects a single data file path
  • to avoid altering the source code of RawDataset, I came up with the following solutions:
  1. the ChannelSplitterSampler class splits the stacked data back into raw and mask, which are then passed to the MinForegroundSampler (see the sketch after this list)
  2. raw_transform is updated using the ComposedTransform class to drop the mask channel once it is no longer needed
  • added a GPU ID argument to mean_teacher_adaptation
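
A minimal sketch of the splitting idea, for illustration only: ChannelSplitterSampler and MinForegroundSampler are the names used in this PR, but the constructor and the (raw, mask) call convention shown here are assumptions based on how torch_em samplers are typically invoked.

```python
class ChannelSplitterSampler:
    """Hypothetical sketch: split a stacked (raw, mask) patch again and forward
    the two channels to an inner sampler, e.g. MinForegroundSampler, which then
    accepts or rejects the patch based on the mask content."""

    def __init__(self, inner_sampler):
        self.inner_sampler = inner_sampler

    def __call__(self, stacked_patch):
        raw, mask = stacked_patch[0], stacked_patch[1]
        # the inner sampler treats the mask as foreground, so patches outside the mask are rejected
        return self.inner_sampler(raw, mask)
```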

@constantinpape (Contributor) left a comment

This looks very good, and should achieve exactly what we discussed! I only have a minor change request on supporting both kinds of samplers; see comments for details.

@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
@constantinpape (Contributor)

Minor semantic comment: I think that this is not necessarily a boundary mask. I think just calling it a mask is more precise.

@stmartineau99 (Contributor, Author), Jul 18, 2025

Jonathan's lamella masker uses the term "boundary mask" so that is why I used it. It makes sense because the mask defines the boundary of the signal.

@stmartineau99 (Contributor, Author)

Since we are now using three different masks in this pipeline (gradient mask, boundary mask, membrane mask), we need different words to describe them. Correct me if there is a clearer way to refer to it.

@constantinpape (Contributor)

In the context here:

  • The "gradient mask" is computed internally only, so we don't need to expose parameters related to it here. But if you want to refer to it for some explanations then calling it "gradient mask" is good.
  • "boundary mask" I would call different, as we use this for accepting / rejecting samples. It does not necessarily have to be on the (spatial) boundary. (And I find the 'boundary of the signal' notion not so intuitive). I would call it "sample mask".
  • I would call the other mask, which you called "membrane mask", "background mask", as we use it to enforce background label in the pseudo labels. In our case this is indeed for membranes, but it could also be for other structures.

@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
sampler: Accept or reject patches based on a condition.
@constantinpape (Contributor)

The samplers for the datasets and the mean teacher trainer have slightly different meaning. See also comment below. I think the best approach here is to expose and document two different sampler arguments:

  • patch_sampler: is passed to get_unsupervised_loader
  • pseudo_label_sampler: is passed to MeanTeacherTrainer

Feel free to suggest better names ;).
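
For illustration, the wiring could look roughly like this. The two new argument names follow the suggestion above; everything else (the abridged signature, the keyword names of get_unsupervised_loader and MeanTeacherTrainer, and the omitted required trainer arguments) is an assumption, not taken from the actual code.

```python
from torch_em.self_training import MeanTeacherTrainer
# get_unsupervised_loader is this repository's loader helper; its import is omitted here.


def mean_teacher_adaptation(
    name,
    unsupervised_train_paths,
    patch_shape,
    patch_sampler=None,         # accepts / rejects patches based on the input data (and mask)
    pseudo_label_sampler=None,  # accepts / rejects pseudo-labels predicted by the teacher
):
    # the patch sampler only affects which patches the unsupervised loader yields
    train_loader = get_unsupervised_loader(
        unsupervised_train_paths, patch_shape, sampler=patch_sampler,
    )
    # the pseudo-label sampler is applied to the teacher predictions during training;
    # other required trainer arguments (model, optimizer, loss, ...) are omitted for brevity
    trainer = MeanTeacherTrainer(
        name=name,
        unsupervised_train_loader=train_loader,
        sampler=pseudo_label_sampler,
    )
    return trainer
```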

@@ -155,7 +172,7 @@ def mean_teacher_adaptation(
device=device,
reinit_teacher=reinit_teacher,
save_root=save_root,
sampler=sampler,
sampler=None, # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
@constantinpape (Contributor)

The sampler here is applied to the pseudo-labels predicted by the teacher, to give a criterion for rejecting pseudo labels. In contrast, the sampler passed to the loaders rejects patches based on some criterion applied to the data. It makes sense to support both and to pass them with different names; see comment above.
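
To make the distinction concrete, a hypothetical pseudo-label sampler could look like the sketch below. The class name, and the assumption that the trainer calls the sampler with the teacher's predictions and expects a boolean back, are illustrative and not taken from torch_em.

```python
class MinPseudoForegroundSampler:
    """Hypothetical example: reject a pseudo-labeled batch if the teacher's
    predictions (a tensor of foreground probabilities) contain too little
    confident foreground."""

    def __init__(self, min_fraction=0.01, threshold=0.5):
        self.min_fraction = min_fraction
        self.threshold = threshold

    def __call__(self, pseudo_labels):
        # fraction of voxels the teacher labels as foreground with sufficient confidence
        foreground_fraction = (pseudo_labels > self.threshold).float().mean().item()
        return foreground_fraction > self.min_fraction
```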


# update variables for RawDataset()
data_paths = tuple(stacked_paths)
base_transform = torch_em.transform.get_raw_transform()
@constantinpape (Contributor)

This should be adapted to only act on channel 0 (the actual raw data).
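
One possible way to do that, as a sketch: the wrapper class is hypothetical, only torch_em.transform.get_raw_transform is taken from the diff above.

```python
import torch_em


class ChannelWiseRawTransform:
    """Hypothetical wrapper: apply a raw transform (e.g. normalization) only to
    channel 0, so the stacked mask channel(s) are left untouched."""

    def __init__(self, transform, channel=0):
        self.transform = transform
        self.channel = channel

    def __call__(self, data):
        transformed = self.transform(data[self.channel])
        # cast the full stack to the transformed dtype before writing channel 0 back
        out = data.astype(transformed.dtype, copy=True)
        out[self.channel] = transformed
        return out


base_transform = ChannelWiseRawTransform(torch_em.transform.get_raw_transform())
```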

@constantinpape (Contributor)

Let's leave this for the next PR.



@constantinpape merged commit 43eff47 into computational-cell-analytics:main on Jul 18, 2025
3 checks passed
@constantinpape (Contributor)

I merged the changes, @stmartineau99. Good job on these!

For the next PR, implementing the background masking for excluding boundaries from the pseudo-labeling, you should address the following:

  • Add the new mask paths for the background masks and name these paths accordingly.
  • Refactor the code here into a separate function, so that you can also cover the case where the background masks are given, and put that case into its own function as well (a rough sketch follows below).
  • Update the base transform here so that it only acts on the first channel; otherwise the masks will be normalized as well, which doesn't make sense.
  • If the background mask is given, then update the augmentations here so that they are only applied to the first channel.
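
A rough sketch of such a stacking helper, to make the refactoring concrete. The function name, dataset keys, and the use of h5py are illustrative assumptions; only the idea of stacking raw data and masks into a single file for RawDataset comes from this PR.

```python
import h5py
import numpy as np


def stack_volumes_for_raw_dataset(raw_path, sample_mask_path, out_path,
                                  background_mask_path=None,
                                  raw_key="raw", mask_key="mask", out_key="data"):
    """Hypothetical helper: stack the raw data, the sample mask and, if given,
    the background mask along a new channel axis and write them to one .h5 file,
    so that RawDataset can read a single data path. Channel 0 stays the raw data,
    so transforms and augmentations can be restricted to it."""
    def _read(path, key):
        with h5py.File(path, "r") as f:
            return f[key][:]

    raw = _read(raw_path, raw_key)
    channels = [raw, _read(sample_mask_path, mask_key).astype(raw.dtype)]
    if background_mask_path is not None:
        channels.append(_read(background_mask_path, mask_key).astype(raw.dtype))

    with h5py.File(out_path, "w") as f:
        f.create_dataset(out_key, data=np.stack(channels, axis=0), compression="gzip")
    return out_path
```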
