This is the companion repository to our paper: Knowledge Distillation for Semantic Segmentation: A Label Space Unification Approach.
- python >= 3.10
- pytorch >= 2.3
- accelerate >= 1.2
- transformers >= 4.45.0
$ pip install -r requirements.txt
Hugging Face Accelerate is a wrapper used mainly for multi-GPU and half-precision training. You can adjust its settings prior to training (recommended for faster training) or skip this step:
$ accelerate config
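The interactive wizard writes a YAML file that Accelerate reads at launch time. For reference, a typical result for 2-GPU fp16 training looks roughly like the fragment below (values are illustrative; generate your own with `accelerate config`):

```yaml
# Example ~/.cache/huggingface/accelerate/default_config.yaml
# for 2-GPU fp16 training (illustrative values).
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: fp16
num_processes: 2
num_machines: 1
machine_rank: 0
main_training_function: main
```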
| Model | Taxonomy | IoU |
|---|---|---|
| M2FB | GOOSE | 64.4 |
| M2FL | GOOSE | 67.9 |
| M2FB | Cityscapes | 75.5 |
| M2FL | Cityscapes | 78.3 |
| M2FL | Mapillary | 52.7 |
Train a standard Mask2Former on a source dataset.
$ accelerate launch train.py --config config.yaml --exper_name <experiment_name>
Track progress in TensorBoard:
$ tensorboard --logdir experiments/<experiment_name>/logs
Before you can generate pseudo-labels with priors, you need to define an ontology mapping between the target and extra datasets. You can find plenty of example mappings in datasets/<dataset_name>/lists/master_labels_<source_dataset>.csv. For a full explanation, see the datasets README.
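The exact CSV schema is documented in the datasets README; as a rough illustration, assuming a simple two-column `source_label,master_label` layout (hypothetical column names, not the repository's actual schema), such a mapping could be loaded like this:

```python
import csv
import io

# Hypothetical mapping CSV: each row maps a source-dataset label to a
# label in the unified master taxonomy (column names are illustrative).
example_csv = """source_label,master_label
road,drivable
sidewalk,non_drivable
car,vehicle
"""

def load_label_mapping(fh):
    """Build a {source_label: master_label} dict from a mapping CSV."""
    return {row["source_label"]: row["master_label"]
            for row in csv.DictReader(fh)}

mapping = load_label_mapping(io.StringIO(example_csv))
print(mapping["car"])  # -> vehicle
```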
Generate pseudo-labels using dataset priors. Use config_inference.yaml to set your labeling parameters.
$ accelerate launch inference.py
Now that you have generated labels with priors, you can adjust the config.yaml to include the pseudo-labels.
$ accelerate launch train.py --config config.yaml --exper_name <experiment_name>
Inference uses the same script as pseudo-label generation. Adjust config_inference.yaml to save or display images with or without priors.
$ accelerate launch inference.py
I have trained a model using a different framework and would only like to do some pseudo-labeling. How do I do that?
Currently, this repository only supports Huggingface's Mask2Former.
If you would like to use your own model, you will have to adjust the model loading and the handling of the model output up to the softmax in inference.py.
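For orientation: Mask2Former predicts per-query class logits and per-query mask logits, which are combined into a per-pixel class score map. The NumPy sketch below (tensor names are made up, and this is not the repository's actual code) shows the kind of per-pixel probability map a custom model would need to produce before the softmax-based labeling step:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def queries_to_semantic_scores(class_logits, mask_logits):
    """Combine Mask2Former-style outputs into per-pixel class scores.

    class_logits: (Q, C+1) per-query class logits (last class = "no object")
    mask_logits:  (Q, H, W) per-query mask logits
    Returns (C, H, W) per-pixel class scores.
    """
    class_probs = softmax(class_logits, axis=-1)[:, :-1]  # drop "no object"
    mask_probs = sigmoid(mask_logits)
    # Weight each query's mask by its class probabilities, sum over queries.
    return np.einsum("qc,qhw->chw", class_probs, mask_probs)

# Smoke test with random inputs: 5 queries, 3 classes, 4x4 image.
rng = np.random.default_rng(0)
scores = queries_to_semantic_scores(rng.normal(size=(5, 4)),
                                    rng.normal(size=(5, 4, 4)))
print(scores.shape)  # -> (3, 4, 4)
```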
You must add a Dataset class to DatasetClasses.py and make it available in the get_dataset method in DataLoader.py.
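As a rough sketch of that wiring (all names except get_dataset, DatasetClasses.py, and DataLoader.py are hypothetical; the real classes follow the repository's existing interface):

```python
# Hypothetical sketch of adding a dataset; the actual classes in
# DatasetClasses.py follow the repository's existing interface.

class MyDataset:
    """Minimal map-style dataset (mimics torch.utils.data.Dataset)."""

    def __init__(self, samples):
        # `samples` stands in for (image_path, label_path) pairs.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Real code would load and transform the image/label here.
        return self.samples[idx]

# Sketch of registering the new class in DataLoader.py's get_dataset:
_DATASETS = {"my_dataset": MyDataset}

def get_dataset(name, samples):
    try:
        return _DATASETS[name](samples)
    except KeyError:
        raise ValueError(f"Unknown dataset: {name}")

ds = get_dataset("my_dataset", [("img0.png", "lbl0.png")])
print(len(ds))  # -> 1
```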