boundary mask for unsupervised training #132
Conversation
This looks very good, and should achieve exactly what we discussed! I only have a minor change request on supporting both kinds of samplers; see comments for details.
```diff
@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
+        train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
```
Minor semantic comment: I think that this is not necessarily a boundary mask. I think just calling it a mask is more precise.
Jonathan's lamella masker uses the term "boundary mask" so that is why I used it. It makes sense because the mask defines the boundary of the signal.
Since we are using three different masks in this pipeline now (gradient mask, boundary mask, membrane mask), we need different words to describe them. Correct me if there is a clearer way to refer to it.
In the context here:
- The "gradient mask" is computed internally only, so we don't need to expose parameters related to it here. But if you want to refer to it for some explanations then calling it "gradient mask" is good.
- "boundary mask" I would call different, as we use this for accepting / rejecting samples. It does not necessarily have to be on the (spatial) boundary. (And I find the 'boundary of the signal' notion not so intuitive). I would call it "sample mask".
- I would call the other mask, which you called "membrane mask", "background mask", as we use it to enforce background label in the pseudo labels. In our case this is indeed for membranes, but it could also be for other structures.
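For illustration, here is a minimal sketch of how such a "background mask" could be applied to pseudo-labels. The function name, the boolean-mask convention, and the use of 0 as the background label are assumptions for clarity, not code from this PR.

```python
import torch


def apply_background_mask(pseudo_labels: torch.Tensor, background_mask: torch.Tensor) -> torch.Tensor:
    """Force pseudo-labels to background wherever the background mask is active.

    Hypothetical helper: assumes `background_mask` is a boolean tensor with the
    same shape as `pseudo_labels`, where True marks voxels that should receive
    the background label (membranes in our case, but it could be any structure).
    """
    masked = pseudo_labels.clone()
    masked[background_mask] = 0  # 0 is assumed to be the background label
    return masked
```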
```diff
@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
+        train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
+        val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
+        sampler: Accept or reject patches based on a condition.
```
The samplers for the datasets and the mean teacher trainer have slightly different meanings. See also the comment below. I think the best approach here is to expose and document two different sampler arguments:
- `patch_sampler`: is passed to `get_unsupervised_loader`
- `pseudo_label_sampler`: is passed to `MeanTeacherTrainer`

Feel free to suggest better names ;).
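A sketch of how the two arguments could look in `mean_teacher_adaptation`. The keyword names follow the suggestion above; the exact signatures of `get_unsupervised_loader` and `MeanTeacherTrainer` are assumed for illustration only.

```python
from typing import Callable, Optional


def mean_teacher_adaptation(
    # ... existing arguments of mean_teacher_adaptation ...
    patch_sampler: Optional[Callable] = None,        # applied to the raw data / mask when drawing patches
    pseudo_label_sampler: Optional[Callable] = None,  # applied to the teacher's pseudo-labels during training
):
    ...
    # Inside the function (sketch):
    # train_loader = get_unsupervised_loader(..., sampler=patch_sampler)
    # trainer = MeanTeacherTrainer(..., sampler=pseudo_label_sampler)
```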
```diff
@@ -155,7 +172,7 @@ def mean_teacher_adaptation(
         device=device,
         reinit_teacher=reinit_teacher,
         save_root=save_root,
-        sampler=sampler,
+        sampler=None,  # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
```
The sampler here is applied to the pseudo-labels predicted by the teacher, to give a criterion for rejecting pseudo labels. In contrast, the sampler passed to the loaders rejects patches based on some criterion applied to the data. It makes sense to support both and to pass them with different names; see comment above.
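For illustration, a pseudo-label rejection criterion might look like the sketch below. The call signature (a single tensor of pseudo-labels returning a bool) is an assumption about what `MeanTeacherTrainer` passes to its sampler, and the thresholds are placeholders, not the actual torch_em API.

```python
import torch


class MinConfidencePseudoLabelSampler:
    """Illustrative sampler: accept a batch of pseudo-labels only if a minimum
    fraction of voxels has a confident teacher prediction. Interface and
    thresholds are assumptions, not code from this PR.
    """

    def __init__(self, confidence_threshold: float = 0.9, min_fraction: float = 0.05):
        self.confidence_threshold = confidence_threshold
        self.min_fraction = min_fraction

    def __call__(self, pseudo_labels: torch.Tensor) -> bool:
        # Fraction of voxels whose prediction exceeds the confidence threshold.
        confident_fraction = (pseudo_labels > self.confidence_threshold).float().mean()
        return confident_fraction.item() >= self.min_fraction
```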
```python
    # update variables for RawDataset()
    data_paths = tuple(stacked_paths)
    base_transform = torch_em.transform.get_raw_transform()
```
This should be adapted to only act on channel 0 (the actual raw data).
Let's leave this for the next PR.
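One way to restrict the raw transform to channel 0 in the next PR might be a small wrapper like the one below. This is a sketch only, assuming channel-first arrays with the mask stacked as an extra channel; the class name is hypothetical.

```python
import numpy as np


class RawChannelOnlyTransform:
    """Apply the raw transform only to channel 0 (the actual raw data) and
    leave additional channels, e.g. the stacked mask, untouched.
    Sketch only, not code from this PR.
    """

    def __init__(self, raw_transform):
        self.raw_transform = raw_transform

    def __call__(self, data: np.ndarray) -> np.ndarray:
        # Work on a float copy so the transformed channel can hold normalized values.
        data = data.astype("float32", copy=True)
        data[0] = self.raw_transform(data[0])
        return data


# Possible usage in place of the plain raw transform:
# base_transform = RawChannelOnlyTransform(torch_em.transform.get_raw_transform())
```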
Merged commit 43eff47 into computational-cell-analytics:main.
I merged the changes, @stmartineau99. Good job on these changes! For the next PR, implementing the background masking for excluding boundaries from the pseudo-labeling, you should address the following:
- `get_unsupervised_loader` now accepts a boundary mask and a sampler to generate patches inside the mask
- `RawDataset` expects a single data file path
- `MinForegroundSampler`
- `ComposedTransform` class to drop the mask channel after it is no longer needed (see the sketch below)
- `mean_teacher_adaptation`
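For the `ComposedTransform` item above, a rough sketch of what dropping the mask channel could look like. The class name comes from the list, but the interface and the assumption that the mask is the last channel are illustrative, not an agreed design.

```python
import numpy as np


class ComposedTransform:
    """Sketch: apply the given transforms in order, then drop the mask channel
    (assumed to be the last channel) once it is no longer needed, so that
    downstream code only sees the raw data. Illustrative only.
    """

    def __init__(self, *transforms, drop_mask_channel: bool = True):
        self.transforms = transforms
        self.drop_mask_channel = drop_mask_channel

    def __call__(self, data: np.ndarray) -> np.ndarray:
        for transform in self.transforms:
            data = transform(data)
        if self.drop_mask_channel:
            data = data[:-1]  # remove the mask channel after the sampler has used it
        return data
```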