Prototype example #1
base: dev-0.1
.gitignore (new file):

```gitignore
# images and anything under mlflow
*.png
examples/example_train*/*

# pycache
*.pyc
```
AbstractCallback.py (new file):

```python
from abc import ABC


class AbstractCallback(ABC):
    """
    Abstract class for callbacks in the training process.
    Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc.
    """

    def __init__(self, name: str):
        """
        :param name: Name of the callback.
        """
        self._name = name
        self._trainer = None

    @property
    def name(self):
        """
        Getter for the callback name.
        """
        return self._name

    @property
    def trainer(self):
        """
        Allows access to the trainer.
        """
        return self._trainer

    def _set_trainer(self, trainer):
        """
        Helper method called by the trainer class to initialize the trainer field.

        :param trainer: trainer object
        :type trainer: AbstractTrainer or subclass
        """
        self._trainer = trainer

    def on_train_start(self):
        """
        Called at the start of training.
        """
        pass

    def on_epoch_start(self):
        """
        Called at the start of each epoch.
        """
        pass

    def on_epoch_end(self):
        """
        Called at the end of each epoch.
        """
        pass

    def on_train_end(self):
        """
        Called at the end of training.
        """
        pass
```
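To make the hook contract concrete, here is a minimal hypothetical subclass. The `trainer.log` and `trainer.epoch` attributes it reads are assumptions inferred from how `MlflowLogger` below uses the trainer, not a documented interface.

```python
from .AbstractCallback import AbstractCallback


class PrintLossCallback(AbstractCallback):
    """Hypothetical callback that prints the most recent logged loss each epoch."""

    def on_epoch_end(self):
        # Assumes trainer.log maps metric names to lists of per-epoch values,
        # as MlflowLogger.on_epoch_end below implies.
        losses = self.trainer.log.get("loss", [])
        if losses:
            print(f"epoch {self.trainer.epoch}: loss={losses[-1]:.4f}")
```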
Review thread:

Reviewer: Do you plan to include multiple plotting classes in the same file? If not, consider changing the name of this file to match the class name, since this class is specific to plotting patches. I could also see a scenario where the plotting callback classes live in a separate folder inside the …

Author: I am working on a major overhaul of the evaluation and plotting suite outside of this PR. It would have separate plotting helper functions for ImageDataset vs. PatchDataset that live in a single .py file under a single evaluation folder, and a single IntermediatePlot callback class would determine which to use depending on the dataset it was given. I can merge the overhaul into this PR once I am done with it, or it could be a separate PR. What are your thoughts?

Reviewer: That seems like a good idea to me, even if many of the helper functions are only for one callback. This is a design decision, which will depend on what is being plotted and how much complexity is involved. I think a separate PR would be a better choice for a change like this.

Author: Just did, with the push last night. The new plot files under visualization_utils detect the dataset type and plot three columns (input, target, and predicted image) for standard image datasets (ImageDataset and CachedDataset); for PatchDataset they plot an additional left-most column showing the raw image the patches are cropped from. There are also two versions of the plot function: one operates on existing inference and evaluation results (predictions from a previous forward pass and the corresponding computed metrics), while the other computes inference and evaluation internally, so everything is self-contained. For now the self-contained version is used by the IntermediatePlot callback class for quickness of implementation, which results in redundant evaluation and inference of n images per epoch, where n is the number of patches/images being plotted. In a future version we can rework how the trainer communicates with the callbacks to remove this redundant computation.
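As a rough illustration of the dataset-type dispatch described above (a hypothetical sketch; the stand-in PatchDataset and the column layout are taken from the discussion, not from the actual visualization_utils code):

```python
from torch.utils.data import Dataset


class PatchDataset(Dataset):
    """Hypothetical stand-in so the dispatch below is runnable in isolation."""
    def __len__(self):
        return 0


def grid_columns(dataset: Dataset) -> list:
    """Return the plot columns for a dataset, following the design described above."""
    if isinstance(dataset, PatchDataset):
        # PatchDataset gets an extra left-most column showing the raw
        # image the patches are cropped from.
        return ["raw image", "input", "target", "predicted"]
    # Standard image datasets (ImageDataset, CachedDataset, or any Dataset)
    return ["input", "target", "predicted"]
```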
IntermediatePlot callback (new file):

```python
import pathlib
from typing import List, Union
import random

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from .AbstractCallback import AbstractCallback
from ..datasets.PatchDataset import PatchDataset
from ..evaluation.visualization_utils import plot_predictions_grid_from_model


class IntermediatePlot(AbstractCallback):
    """
    Callback to plot model-generated outputs, ground truth,
    and input stained image patches at the end of each epoch.
    """

    def __init__(self,
                 name: str,
                 path: Union[pathlib.Path, str],
                 dataset: Union[Dataset, PatchDataset],
                 plot_n_patches: int = 5,
                 indices: Union[List[int], None] = None,
                 plot_metrics: List[nn.Module] = None,
                 every_n_epochs: int = 5,
                 random_seed: int = 42,
                 **kwargs):
        """
        Initialize the IntermediatePlot callback.
        Allows plots of predictions to be generated during training to monitor training progress.
        Supports both PatchDataset and Dataset classes for plotting.
        When passed into the trainer, this callback plots the model predictions
        on a subset of the provided dataset at the end of each epoch.

        :param name: Name of the callback.
        :param path: Path to save the plots.
        :type path: Union[pathlib.Path, str]
        :param dataset: Dataset to be used for plotting intermediate results.
        :type dataset: Union[Dataset, PatchDataset]
        :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5.
            The exact patches/images plotted may vary with the seed or dataset size.
            For best reproducibility and consistency, use a fixed dataset and the indices argument instead.
        :type plot_n_patches: int, optional
        :param indices: Optional list of specific indices to subset the dataset before inference.
            Overrides the plot_n_patches and random_seed arguments and uses the indices list to subset.
        :type indices: Union[List[int], None]
        :param plot_metrics: List of metrics to compute and display in the plot title, defaults to None.
        :type plot_metrics: List[nn.Module], optional
        :param every_n_epochs: How frequently intermediate plots should be generated, defaults to 5.
        :type every_n_epochs: int
        :param random_seed: Random seed for reproducible random patch/image selection, defaults to 42.
        :type random_seed: int
        :param kwargs: Additional keyword arguments passed to plot_predictions_grid_from_model.
        :type kwargs: dict
        :raises TypeError: If the dataset is not an instance of Dataset or PatchDataset.
        """
        super().__init__(name)
        self._path = path
        if not isinstance(dataset, (Dataset, PatchDataset)):
            raise TypeError(f"Expected Dataset or PatchDataset, got {type(dataset)}")

        self._dataset = dataset

        # Additional kwargs passed to plot_predictions_grid_from_model
        self.plot_metrics = plot_metrics
        self.every_n_epochs = every_n_epochs
        self.plot_kwargs = kwargs

        if indices is not None:
            # Check that indices are within bounds
            for i in indices:
                if i >= len(self._dataset):
                    raise ValueError(f"Index {i} out of bounds for dataset of size {len(self._dataset)}")
            self._dataset_subset_indices = indices
        else:
            # Generate random indices to subset, given the seed and plot_n_patches
            plot_n_patches = min(plot_n_patches, len(self._dataset))
            random.seed(random_seed)
            self._dataset_subset_indices = random.sample(range(len(self._dataset)), plot_n_patches)

    def on_epoch_end(self):
        """
        Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
        """
        if (self.trainer.epoch + 1) % self.every_n_epochs == 0:
            self._plot()

    def on_train_end(self):
        """
        Called at the end of training. Plots if not already done in the last epoch.
        """
        if (self.trainer.epoch + 1) % self.every_n_epochs != 0:
            self._plot()

    def _plot(self):
        """
        Helper method to generate and save plots.
        Plots the dataset with model predictions on the selected subset of the dataset.
        Called by the on_epoch_end and on_train_end methods.
        """
        original_device = next(self.trainer.model.parameters()).device

        plot_predictions_grid_from_model(
            model=self.trainer.model,
            dataset=self._dataset,
            indices=self._dataset_subset_indices,
            metrics=self.plot_metrics,
            save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
            device=original_device,
            show=False,
            **self.plot_kwargs
        )
```
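A hedged usage sketch: the IntermediatePlot arguments come from the signature above, while `val_dataset`, `model`, and the `Trainer` API are assumptions for illustration.

```python
import torch.nn as nn

plot_callback = IntermediatePlot(
    name="intermediate_plot",
    path="plots",                  # PNGs are written as plots/epoch_<n>.png
    dataset=val_dataset,           # any Dataset or PatchDataset (assumed to exist)
    plot_n_patches=5,              # ignored when indices is provided
    plot_metrics=[nn.MSELoss()],   # shown in the plot titles
    every_n_epochs=10,
)

# Assumed trainer API; the trainer is expected to call on_epoch_end()
# and on_train_end() on each registered callback.
trainer = Trainer(model=model, callbacks=[plot_callback])
trainer.train()
```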
Review thread:

Reviewer: Maybe I missed this in the code, but I think we also want to save the images after each epoch as artifacts with mlflow. I do this with cropped nuclei by reserving a folder for each nucleus, where the images for each epoch are saved in that folder.

Reviewer: There are a few ways you could go about doing this (e.g. including more code in this logger, or creating classes for each type of data saved). By types of data I mean metrics, parameters, artifacts, etc.

Reviewer: Optuna also uses algorithms to understand hyperparameter importances, so I think we will also want to save the optuna object as an artifact when optimizing models.

Author: In my current implementation, the mlflow logger, image plotting, and optuna do not go hand in hand. I am less familiar with mlflow than you are, and what I think you are proposing makes a lot more sense and improves the flow. I do think this amount of extra functionality is perhaps more suitable for a separate PR.

Reviewer: I think so too. If you have questions, just let me know.
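For reference, a minimal sketch of the per-epoch artifact logging suggested above, assuming the plots are already written to local PNG files (the function name and folder layout are illustrative):

```python
import mlflow


def log_epoch_plot(png_path: str, image_id: str) -> None:
    """Hypothetical helper: store one epoch's plot under a per-image folder.

    mlflow keeps the original filename (e.g. epoch_3.png), so grouping by
    image_id yields one folder per image with all of its epochs inside,
    mirroring the reviewer's per-nucleus folder scheme.
    """
    mlflow.log_artifact(png_path, artifact_path=f"images/{image_id}")
```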
MlflowLogger callback (new file):

```python
import os
import pathlib
import tempfile
from typing import Union, Dict, Optional

import mlflow
import torch

from .AbstractCallback import AbstractCallback


class MlflowLogger(AbstractCallback):
    """
    Callback to log metrics to MLflow.
    """

    def __init__(self,
                 name: str,
                 artifact_name: str = 'best_model_weights.pth',
                 mlflow_uri: Union[pathlib.Path, str] = None,
                 mlflow_experiment_name: Optional[str] = None,
                 mlflow_start_run_args: dict = None,
                 mlflow_log_params_args: dict = None,
                 ):
        """
        Initialize the MlflowLogger callback.

        :param name: Name of the callback.
        :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
        :param mlflow_uri: URI for the MLflow tracking server, defaults to None.
            If a path is specified, the logger calls set_tracking_uri with that path,
            thereby initiating a new tracking server.
            If None (default), the logger does not touch the mlflow server, enabling
            logging to a global server initialized outside of this class.
        :type mlflow_uri: pathlib.Path or str, optional
        :param mlflow_experiment_name: Name of the MLflow experiment, defaults to None, in which
            case the set_experiment method of mlflow is not called and whichever experiment name
            is globally configured is used. If a name is provided, the logger calls
            set_experiment with that name.
        :type mlflow_experiment_name: str, optional
        :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None.
        :type mlflow_start_run_args: dict, optional
        :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to None.
        :type mlflow_log_params_args: dict, optional
        """
        super().__init__(name)

        if mlflow_uri is not None:
            try:
                mlflow.set_tracking_uri(mlflow_uri)
            except Exception as e:
                raise RuntimeError(f"Error setting MLflow tracking URI: {e}")

        if mlflow_experiment_name is not None:
            try:
                mlflow.set_experiment(mlflow_experiment_name)
            except Exception as e:
                raise RuntimeError(f"Error setting MLflow experiment: {e}")

        self._artifact_name = artifact_name
        self._mlflow_start_run_args = mlflow_start_run_args
        self._mlflow_log_params_args = mlflow_log_params_args

    def on_train_start(self):
        """
        Called at the start of training.

        Calls mlflow start_run and logs params if provided.
        """
        if self._mlflow_start_run_args is None:
            pass
        elif isinstance(self._mlflow_start_run_args, dict):
            mlflow.start_run(**self._mlflow_start_run_args)
        else:
            raise TypeError("mlflow_start_run_args must be None or a dictionary.")

        if self._mlflow_log_params_args is None:
            pass
        elif isinstance(self._mlflow_log_params_args, dict):
            mlflow.log_params(self._mlflow_log_params_args)
        else:
            raise TypeError("mlflow_log_params_args must be None or a dictionary.")

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Iterates over the most recent log items in the trainer and calls mlflow log_metric.
        """
        for key, values in self.trainer.log.items():
            if values is not None and len(values) > 0:
                mlflow.log_metric(key, values[-1], step=self.trainer.epoch)

    def on_train_end(self):
        """
        Called at the end of training.

        Saves the trainer's best model to a temporary directory, calls mlflow
        log_artifact, then ends the run.
        """
        # Save weights to a temporary directory and log artifacts
        with tempfile.TemporaryDirectory() as tmpdirname:
            weights_path = os.path.join(tmpdirname, self._artifact_name)
            torch.save(self.trainer.best_model, weights_path)
            mlflow.log_artifact(weights_path, artifact_path="models")

        mlflow.end_run()
```
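A hedged usage sketch: the MlflowLogger arguments below come from the class definition above, while the Trainer construction and train() call are assumptions about the surrounding trainer API.

```python
logger = MlflowLogger(
    name="mlflow_logger",
    artifact_name="best_model_weights.pth",
    mlflow_uri="mlruns",                          # local file-based tracking store
    mlflow_experiment_name="prototype_example",   # hypothetical experiment name
    mlflow_start_run_args={"run_name": "dev-0.1-prototype"},
    mlflow_log_params_args={"lr": 1e-3, "batch_size": 16},
)

# Assumed trainer API: metrics in trainer.log are logged per epoch,
# and trainer.best_model is saved as an artifact when training ends.
trainer = Trainer(model=model, callbacks=[logger])
trainer.train()
```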
Callbacks README (new file):

This folder contains the callback classes that are meant to be passed into trainers to do things like saving images every epoch and logging.

The callback classes must inherit from the abstract class.
Conda environment file (new file):

```yaml
name: cp_gan_env
channels:
  - anaconda
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - conda-forge::python=3.9
  - conda-forge::pip
  - pytorch::pytorch
  - pytorch::torchvision
  - pytorch::torchaudio
  - pytorch::pytorch-cuda=12.1
  - conda-forge::seaborn
  - conda-forge::matplotlib
  - conda-forge::jupyter
  - conda-forge::pre_commit
  - conda-forge::pandas
  - conda-forge::pillow
  - conda-forge::numpy
  - conda-forge::pathlib2
  - conda-forge::scikit-learn
  - conda-forge::opencv
  - conda-forge::pyarrow
  - conda-forge::ipython
  - conda-forge::notebook
  - conda-forge::albumentations
  - conda-forge::optuna
  - conda-forge::mysqlclient
  - conda-forge::openjdk
  - conda-forge::gtk2
  - conda-forge::typing-extensions
  - conda-forge::Jinja2
  - conda-forge::inflect
  - conda-forge::wxpython
  - conda-forge::sentry-sdk
  - pip:
      - mlflow
      - cellprofiler==4.2.8
```
Review thread:

Reviewer: I would recommend either consistently using type hints or using the function docstring for types. I think type hints are generally better, though.

Author: I purposefully did not include type hints for the trainer objects here because the trainer classes already have type hints referencing the callback classes, and adding a trainer type hint here in the callback class would cause circular imports. This is probably a consequence of my sub-optimal class design, which needs to be solved with some refactoring.

Reviewer: I don't remember if this is the way to do it exactly, but you can do something like this:
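A sketch of the pattern the reviewer is likely referring to: deferring the import behind typing.TYPE_CHECKING with postponed annotations, so the import is seen only during static analysis (the trainer module path here is an assumption, not the repo's actual layout):

```python
from __future__ import annotations  # annotations become lazy strings (PEP 563)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, never executed at runtime, so the
    # callbacks module no longer imports the trainer module at import time.
    from ..trainers.AbstractTrainer import AbstractTrainer


class ExampleCallback:
    def _set_trainer(self, trainer: AbstractTrainer) -> None:
        self._trainer = trainer
```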
Reviewer: Although maybe there is a better alternative here, such as changing the design.