Training hyperparameters definable in the config Python file #105
rhoadesScholar started this conversation in Show and tell
The included training pipeline takes a `config_path` argument pointing to a configuration file to use for training the model. This file should be a Python file that defines the hyperparameters and other configuration for training. It may include:

- `model_save_path`: Path to save the model checkpoints. Default is `'checkpoints/{model_name}_{epoch}.pth'`.
- `logs_save_path`: Path to save the logs for TensorBoard. Default is `'tensorboard/{model_name}'`. Training progress can be monitored by running `tensorboard --logdir <logs_save_path>` in the terminal.
- `datasplit_path`: Path to the datasplit file that defines the train/val split the dataloader should use. Default is `'datasplit.csv'`.
- `validation_prob`: Proportion of the datasets to use for validation. Only used if the datasplit CSV specified by `datasplit_path` does not already exist. Default is `0.15`.
- `learning_rate`: Learning rate for the optimizer. Default is `0.0001`.
- `batch_size`: Batch size for the dataloader. Default is `8`.
- `input_array_info`: Dictionary containing the shape and scale of the input data. Default is `{'shape': (1, 128, 128), 'scale': (8, 8, 8)}`.
- `target_array_info`: Dictionary containing the shape and scale of the target data. Default is to use `input_array_info`.
- `epochs`: Number of epochs to train the model for. Default is `1000`.
- `iterations_per_epoch`: Number of iterations per epoch. Each iteration draws an independently generated random batch from the training set. Default is `1000`.
- `random_seed`: Random seed for reproducibility. Default is `42`.
- `classes`: List of classes to train the model to predict. If the datasplit is generated de novo when this script is run, the data it includes will reflect these classes. Default is `['nuc', 'er']`.
- `model_name`: Name of the model to use. If the config file constructs the PyTorch model, this name can be anything. If it does not, `model_name` must specify which included architecture to use: `'2d_unet'`, `'2d_resnet'`, `'3d_unet'`, `'3d_resnet'`, or `'vitnet'`. Default is `'2d_unet'`. See the `models` module `README.md` for more information.
- `model_to_load`: Name of the pre-trained model to load. Default is the same as `model_name`.
- `model_kwargs`: Dictionary of keyword arguments to pass to the model constructor. Default is `{}`. Ignored if a PyTorch `model` is passed. See the `models` module `README.md` for more information.
- `model`: PyTorch model to use for training. If this is provided, `model_name` and `model_to_load` can be any string. Default is `None`.
- `load_model`: Which model checkpoint to load, if one exists. Options are `'latest'` or `'best'`. If no checkpoint exists, the already-initialized model is used silently. Default is `'latest'`.
- `spatial_transforms`: Dictionary of spatial transformations to apply to the training data. Default is `{'mirror': {'axes': {'x': 0.5, 'y': 0.5}}, 'transpose': {'axes': ['x', 'y']}, 'rotate': {'axes': {'x': [-180, 180], 'y': [-180, 180]}}}`. See the `dataloader` module documentation for more information.
- `validation_time_limit`: Maximum time to spend on validation, in seconds. If `None`, there is no time limit. Default is `None`.
- `validation_batch_limit`: Maximum number of validation batches to process. If `None`, there is no limit. Default is `None`.
- `device`: Device to use for training. If `None`, uses `'cuda'` if available, then `'mps'` if available, and `'cpu'` otherwise. Default is `None`.
- `use_s3`: Whether to use the S3 bucket for the datasplit. Default is `False`.
- `optimizer`: PyTorch optimizer to use for training. Default is `torch.optim.RAdam(model.parameters(), lr=learning_rate, decoupled_weight_decay=True)`.
- `criterion`: Uninstantiated PyTorch loss function (the class, not an instance) to use for training. Default is `torch.nn.BCEWithLogitsLoss`.
- `criterion_kwargs`: Dictionary of keyword arguments to pass to the loss function constructor. Default is `{}`.
- `weight_loss`: Whether to weight the loss function by the class counts found in the datasets. Default is `True`.
- `use_mutual_exclusion`: Whether to use mutual exclusion to infer labels for unannotated pixels. Default is `False`.
- `weighted_sampler`: Whether to use a sampler weighted by class counts for the dataloader. Default is `True`.
- `train_raw_value_transforms`: Transform to apply to the raw values for training. Default is `T.Compose([Normalize(), T.ToDtype(torch.float, scale=True), NaNtoNum({"nan": 0, "posinf": None, "neginf": None})])`, which normalizes the input data, converts it to `float32`, and replaces NaNs with `0`. This can be used to add augmentations such as random erasing, blur, or noise.
- `val_raw_value_transforms`: Transform to apply to the raw values for validation, analogous to `train_raw_value_transforms`. Default is the same as `train_raw_value_transforms`.
- `target_value_transforms`: Transform to apply to the target values. Default is `T.Compose([T.ToDtype(torch.float), Binarize()])`, which converts the input masks to `float32` and thresholds them at `0` (turning object IDs into binary masks for use with binary cross-entropy loss). This can be used to specify other targets, such as distance transforms.
- `max_grad_norm`: Maximum gradient norm for clipping. If `None`, no clipping is performed. Default is `None`. Clipping can help prevent exploding gradients, which can lead to NaNs in the weights.
- `force_all_classes`: Whether to force all classes to be present in each batch provided by the dataloaders. Can be `True` to enforce this for both the training and validation dataloaders, `False` to enforce it for neither, or `'train'`/`'validate'` to restrict it to training or validation, respectively. Default is `'validate'`.
- `scheduler`: PyTorch learning rate scheduler (instance or uninstantiated class) to use for training. Default is `None`. If provided, the scheduler is called at the end of each epoch.
- `scheduler_kwargs`: Dictionary of keyword arguments to pass to the scheduler constructor. Default is `{}`. Ignored if an already-instantiated `scheduler` is provided.

Anything else others would like to see added?
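For illustration, a minimal config file might look like the sketch below. It overrides a handful of the documented options with plain Python values and leaves everything else at its default. The file name `my_train_config.py` and all specific values (the extra `'mito'` class, the 256×256 shape, `max_grad_norm = 1.0`) are hypothetical examples, not recommendations.

```python
# my_train_config.py  (hypothetical file name; pass its path as config_path)

# Where checkpoints and TensorBoard logs are written
model_save_path = "checkpoints/{model_name}_{epoch}.pth"
logs_save_path = "tensorboard/{model_name}"

# Data split: validation_prob is only used if datasplit_path does not exist yet
datasplit_path = "datasplit.csv"
validation_prob = 0.2
random_seed = 42

# Optimization basics
learning_rate = 1e-4
batch_size = 16
epochs = 200
iterations_per_epoch = 500

# What to predict, and the shape/scale of the arrays fed to the model
classes = ["nuc", "er", "mito"]  # 'mito' is an illustrative extra class
input_array_info = {"shape": (1, 256, 256), "scale": (8, 8, 8)}
target_array_info = input_array_info  # equivalent to the default behaviour

# No `model` object is defined here, so model_name must name an included architecture
model_name = "2d_unet"
model_kwargs = {}

# Enforce class presence only in validation batches (the default), and clip gradients
force_all_classes = "validate"
max_grad_norm = 1.0
```

Any option not assigned in the file keeps the default listed above. Assuming the `{model_name}` and `{epoch}` placeholders are filled with `str.format`-style substitution (not confirmed here), `model_save_path.format(model_name=model_name, epoch=7)` would give `'checkpoints/2d_unet_7.pth'`.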
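The four accepted values of `force_all_classes` map onto the training and validation dataloaders as sketched below. The helper `resolve_force_all_classes` is hypothetical: it is not part of the pipeline and only restates the documented behaviour as code, not the pipeline's own implementation.

```python
from typing import Literal, Union

ForceAllClasses = Union[bool, Literal["train", "validate"]]


def resolve_force_all_classes(value: ForceAllClasses) -> dict[str, bool]:
    """Hypothetical helper: expand a force_all_classes setting into
    per-dataloader flags, following the option description."""
    if value is True:
        return {"train": True, "validate": True}    # both dataloaders
    if value is False:
        return {"train": False, "validate": False}  # neither dataloader
    if value in ("train", "validate"):
        # Restrict the constraint to a single dataloader
        return {"train": value == "train", "validate": value == "validate"}
    raise ValueError(
        f"force_all_classes must be True, False, 'train' or 'validate', got {value!r}"
    )
```

With the default, `resolve_force_all_classes("validate")` returns `{"train": False, "validate": True}`: only validation batches are required to contain every class.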
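Because each training iteration draws an independently generated random batch, `epochs`, `iterations_per_epoch`, and `batch_size` together determine how many samples the model sees. The small calculation below works through the documented defaults. It is plain arithmetic on those defaults, assuming exactly one batch per iteration and ignoring validation, not output from the pipeline; `training_budget` is a hypothetical name.

```python
def training_budget(epochs: int, iterations_per_epoch: int, batch_size: int) -> tuple[int, int]:
    """Hypothetical helper: return (total training iterations, total random
    samples drawn), assuming one batch per iteration."""
    total_iterations = epochs * iterations_per_epoch
    total_samples = total_iterations * batch_size
    return total_iterations, total_samples


# Documented defaults: epochs=1000, iterations_per_epoch=1000, batch_size=8
iters, samples = training_budget(1000, 1000, 8)
print(f"{iters:,} iterations, {samples:,} samples")  # 1,000,000 iterations, 8,000,000 samples
```

Halving `iterations_per_epoch` while doubling `epochs` leaves the total budget unchanged, but it doubles how often per-epoch work (validation, checkpointing, and any `scheduler` call) happens.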