Prototype example #1

Open: wants to merge 93 commits into base: dev-0.1
93 commits
d7a03b1
Added readme file for the model folders
wli51 Feb 14, 2025
ef44247
Added model files
wli51 Feb 14, 2025
6a2b4c3
Added readme for dataset folder
wli51 Feb 14, 2025
bea6c02
Added dataset files
wli51 Feb 14, 2025
5bfaad6
Added readme for trainers
wli51 Feb 14, 2025
31f2f1b
Added trainer files
wli51 Feb 14, 2025
3496984
Added loss files
wli51 Feb 14, 2025
c187a31
Added metrics files
wli51 Feb 14, 2025
3e921b9
Added callback files
wli51 Feb 14, 2025
8af1475
Added transform files
wli51 Feb 14, 2025
744efcc
Added evaluation files
wli51 Feb 14, 2025
81f3ccd
Made some modifications to callback
wli51 Feb 14, 2025
c8a27ea
Added gitignore
wli51 Feb 14, 2025
b7c0df2
Added notebook that is a minimal example
wli51 Feb 14, 2025
a311fd7
Modified the way trainers are accessed by callbacks, instead of train…
wli51 Feb 17, 2025
e466533
Updated docstring and removed unneeded imports for the callbacks
wli51 Feb 17, 2025
b5a6891
Updated docstring and removed unneeded imports for the dataset classes
wli51 Feb 17, 2025
10a8765
Update loss class forward function variable name and added/modified d…
wli51 Feb 17, 2025
cde174c
Removed redundant metric name property and added documentation
wli51 Feb 17, 2025
6b3d914
Added documentation
wli51 Feb 17, 2025
022a46d
Re-ran notebook
wli51 Feb 17, 2025
74b7b02
Added wrapper class for torch modules that accumulates metric values …
wli51 Feb 18, 2025
12d5f45
Updated abstract trainer class's initialization and train function to…
wli51 Feb 18, 2025
d123b20
Update example notebook to demonstrate alternative early termination …
wli51 Feb 18, 2025
f33b906
Added environment file for cp loss
wli51 Feb 20, 2025
ae573dd
Updated existing type hint to comply with Python 3.9 standards
wli51 Feb 20, 2025
d57d80a
Re-ran example notebook
wli51 Feb 20, 2025
09b63d5
Added dataset class that does not rely on pe2loaddata generated file …
wli51 Feb 20, 2025
1d63fad
Modified gitignore to ignore produced files under additional example …
wli51 Feb 20, 2025
6568d01
Modified notebook description
wli51 Feb 20, 2025
755d162
Added helper function that computes metrics on a per image basis give…
wli51 Feb 26, 2025
b537c83
Added a new file which is a collection of helper functions for predic…
wli51 Feb 26, 2025
1f7db29
Added new file which is a collection of helper functions to visualize…
wli51 Feb 26, 2025
75d5adc
Fixed bug in patch dataset returning the wrong length of itself
wli51 Feb 26, 2025
071aa87
Modified callback to allow the plot frequency during training to be t…
wli51 Feb 26, 2025
d7f62b8
Fixed bug
wli51 Feb 27, 2025
cab2820
Merge branch 'dev-dataset' into prototype_example to apply a fix to t…
wli51 Feb 27, 2025
3f68f2f
Merge branch 'dev-callbacks' into prototype_example to apply modifica…
wli51 Feb 27, 2025
4ea24a9
Update trainer classes to remove best model attribute that is interna…
wli51 Feb 27, 2025
6d362b0
Update callbacks/IntermediatePlot.py for clearer documentation of wha…
wli51 Feb 27, 2025
b3e69a1
Merge changes committed through the PR
wli51 Feb 27, 2025
13cb298
Removed call to super class method that does not do anything
wli51 Feb 27, 2025
93b83f2
Modify MlflowLogger class so the default behavior is not to set new t…
wli51 Feb 28, 2025
5380ac7
Modify MlflowLogger class so the default behavior is not to log any p…
wli51 Feb 28, 2025
9a27c23
Modify comment for improved clarity
wli51 Feb 28, 2025
c284f05
Removed TODO item as it is not going to be useful
wli51 Feb 28, 2025
a6f40f6
Updated variable names and docstring for DiscriminatorLoss.py for imp…
wli51 Feb 28, 2025
778144b
Modified MlflowLogger so that the experiment name is also not configu…
wli51 Feb 28, 2025
b270630
Removed outdated TODO
wli51 Feb 28, 2025
2eefe5c
Removed commented out code that is no longer needed
wli51 Feb 28, 2025
756429c
Modified function name to make metrics aggregation more clear.
wli51 Feb 28, 2025
30f6b17
Update docstring
wli51 Feb 28, 2025
7ac89b8
Remove description of where the code is adapted from for consistency
wli51 Feb 28, 2025
ed7d940
fixed bug, log_params should not take keyword arguments, should just …
wli51 Mar 1, 2025
cec0df6
Changed comment to one line for cleanness
wli51 Mar 1, 2025
7f20d7a
Renamed cache related functions for clarity
wli51 Mar 1, 2025
825b745
Added comment to better describe the wGAN DiscriminatorLoss
wli51 Mar 1, 2025
ed47a06
Modified docstring for better clarity
wli51 Mar 1, 2025
63829d7
Changes the default behavior of early termination to be disabled when…
wli51 Mar 1, 2025
18dec2c
Added parameter to determine if batch normalization will be used in d…
wli51 Mar 1, 2025
c0b00b6
Do not retain graph when computing the gradient penalty loss for pote…
wli51 Mar 1, 2025
90660ba
Removed unnecessary if statement
wli51 Mar 1, 2025
80be3dd
Added reconstruction loss weight parameter that defaults to 1
wli51 Mar 1, 2025
aa7b531
Merge branch 'dev-eval-plot' into prototype_example
wli51 Mar 1, 2025
d7eb27e
Added raw_input and raw_target properties to PatchDataset to centrali…
wli51 Mar 1, 2025
5470fd4
Fixed bug of higher order derivative not being able to be computed du…
wli51 Mar 2, 2025
65556a3
Fixed early termination enable/disable logic to ensure when no early t…
wli51 Mar 2, 2025
9751437
Modified predict_image function so it returns the target tensor along…
wli51 Mar 3, 2025
1b5d7ec
Modified evlauation_per_image_metric function so metrics can be compu…
wli51 Mar 3, 2025
a052113
Modified plot_patches function for compatibility with updated predict…
wli51 Mar 3, 2025
2ef624f
Update return type hint
wli51 Mar 3, 2025
2e1532e
Added new functions to visualization_utils.py for visualization of se…
wli51 Mar 3, 2025
fda7de7
Added kwargs support for plot parameters and fixed metrics in title
wli51 Mar 3, 2025
f909e34
Updated IntermediatePlot callback class for compatibility with new pl…
wli51 Mar 3, 2025
1413280
Renamed callback name and modified type checking to reflect that it s…
wli51 Mar 3, 2025
0da8875
Update comment to better reflect what the functions are doing
wli51 Mar 3, 2025
8d5a616
Removed unneeded files and functions
wli51 Mar 3, 2025
baeb80b
Update example so it is consistent with the updated evaluation/plotti…
wli51 Mar 3, 2025
fc76155
Bug fix for early termination disable
wli51 Mar 3, 2025
d78beae
Bug fix for dataset type checking
wli51 Mar 3, 2025
e23dd55
Update example so it is consistent with the updated evaluation/plotti…
wli51 Mar 3, 2025
8e652b8
Renamed loss class for clarity and updated examples accordingly
wli51 Mar 4, 2025
e0cbf41
Removed unneeded functions
wli51 Mar 4, 2025
c6ce487
Removed TODO comment that will not be addressed at the moment
wli51 Mar 4, 2025
1b81e1e
Removed unneeded comment
wli51 Mar 4, 2025
ce391c4
Updated UNet and FNet to have consistent specification for output act…
wli51 Mar 4, 2025
741c121
Removed some type docstrings to reduce redundancy
wli51 Mar 4, 2025
220c561
Removed not planned TODOs
wli51 Mar 4, 2025
8c7e97c
Added docstring
wli51 Mar 4, 2025
c45fabf
Removed unwanted functions
wli51 Mar 4, 2025
207c75e
Renamed WGANTrainer and updated examples accordingly
wli51 Mar 4, 2025
9596e53
Removed unplanned TODO
wli51 Mar 4, 2025
ad93208
Removed unused attributes and refactored to use the depth attribute
wli51 Mar 4, 2025
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
# images and anything under mlflow
*.png
examples/example_train*/*

# pycache
*.pyc
62 changes: 62 additions & 0 deletions callbacks/AbstractCallback.py
@@ -0,0 +1,62 @@
from abc import ABC

class AbstractCallback(ABC):
    """
    Abstract class for callbacks in the training process.
    Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc.
    """

    def __init__(self, name: str):
        """
        :param name: Name of the callback.
        """
        self._name = name
        self._trainer = None

    @property
    def name(self):
        """
        Getter for the callback name.
        """
        return self._name

    @property
    def trainer(self):
        """
        Allows access to the trainer.
        """
        return self._trainer

    def _set_trainer(self, trainer):
        """
        Helper function called by the trainer class to initialize the trainer field.

        :param trainer: trainer object
        :type trainer: AbstractTrainer or subclass
[Member]
Would recommend either consistently using type hints or using the function docstring for types. I think type hints are generally better, though.

[wli51 (Collaborator Author), Feb 27, 2025]
I purposefully did not include type hints for the trainer objects here because the trainer classes already have type hints for the callback classes, and adding a trainer type hint here in the callback class would cause circular imports. This is probably a consequence of my sub-optimal class design and needs to be solved with some refactoring.

[Member]
I don't remember if this is the way to do it exactly, but you can do something like this:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from some_module import AbstractTrainer  # Only for type hints, avoids circular imports

    class CallbackClass:
        def _set_trainer(self, trainer: "AbstractTrainer") -> None:
            """
            Helper function called by trainer class to initialize trainer value field.
            """
            self._trainer = trainer

[MattsonCam (Member), Mar 3, 2025]
Although, maybe there is a better alternative here, such as changing the design.

        """
        self._trainer = trainer

    def on_train_start(self):
        """
        Called at the start of training.
        """
        pass

    def on_epoch_start(self):
        """
        Called at the start of each epoch.
        """
        pass

    def on_epoch_end(self):
        """
        Called at the end of each epoch.
        """
        pass

    def on_train_end(self):
        """
        Called at the end of training.
        """
        pass
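To make the hook contract above concrete, here is a minimal self-contained sketch of a concrete callback driven by a trainer loop. `EpochRecorder` and `DummyTrainer` are hypothetical names invented for this illustration; only the hook interface mirrors `AbstractCallback`:

```python
from abc import ABC

# Trimmed copy of the AbstractCallback interface above, so the sketch runs standalone.
class AbstractCallback(ABC):
    def __init__(self, name: str):
        self._name = name
        self._trainer = None

    @property
    def trainer(self):
        return self._trainer

    def _set_trainer(self, trainer):
        self._trainer = trainer

    def on_train_start(self): pass
    def on_epoch_start(self): pass
    def on_epoch_end(self): pass
    def on_train_end(self): pass

# Hypothetical concrete callback: records which epochs it observed.
class EpochRecorder(AbstractCallback):
    def __init__(self, name: str):
        super().__init__(name)
        self.seen = []

    def on_epoch_end(self):
        self.seen.append(self.trainer.epoch)

# Hypothetical minimal trainer showing how the hooks would be driven.
class DummyTrainer:
    def __init__(self, callbacks, n_epochs):
        self.epoch = 0
        self.callbacks = callbacks
        self.n_epochs = n_epochs
        for cb in callbacks:
            cb._set_trainer(self)  # same handshake as in the real trainer

    def train(self):
        for cb in self.callbacks:
            cb.on_train_start()
        for self.epoch in range(self.n_epochs):
            for cb in self.callbacks:
                cb.on_epoch_start()
            # ... forward/backward passes would go here ...
            for cb in self.callbacks:
                cb.on_epoch_end()
        for cb in self.callbacks:
            cb.on_train_end()

recorder = EpochRecorder("recorder")
DummyTrainer([recorder], n_epochs=3).train()
print(recorder.seen)  # [0, 1, 2]
```

The `_set_trainer` handshake is what gives each callback access to trainer state (such as `epoch`) without the callback importing the trainer class.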
117 changes: 117 additions & 0 deletions callbacks/IntermediatePlot.py
[Member]
Do you plan to include multiple plotting classes in the same file? If not, consider changing the name of this file to match the class name, since this class is specific to plotting patches. I could also see a scenario where the plotting callback classes are in a separate folder inside the callbacks folder.

[wli51 (Collaborator Author)]
I am working on a major overhaul of the evaluation and plotting suite outside of this PR that would have separate plotting helper functions for ImageDataset vs PatchDataset, living in a single py file under a single evaluation folder, with a single IntermediatePlot callback class determining which to use depending on the dataset it was given. I can merge the overhaul into this PR once I am done with it, or it could be a separate PR. What are your thoughts?

[Member]
That seems like a good idea to me, even if many of the helper functions are only for one callback. This is a design decision, which will depend on what is being plotted and how much complexity is involved. I think a separate PR would be a better choice for a change like this.

[wli51 (Collaborator Author)]
Just did, with the push last night. The new plot functions under visualization_utils detect the dataset type and plot three columns (input, target, and predicted image) for standard image datasets (ImageDataset and CachedDataset); for PatchDataset they plot an additional left-most column showing the raw image the patches are cropped from. There are also two versions of the plot function: one operates on inference and evaluation results (predictions from a previous forward pass and the corresponding computed metrics), while the other computes inference and evaluation internally, so everything is self-contained. For now the self-contained version is used by the IntermediatePlot callback class for quickness of implementation, which results in redundant evaluation and inference of n images per epoch, where n is the number of patches/images being plotted. In a future version we can rework the way the trainer communicates with the callbacks to remove this redundant computation.

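The dataset-type dispatch described in the thread above can be sketched as follows. `ImageDataset` and `PatchDataset` here are empty stand-ins for the real dataset classes, and `plot_columns` is a hypothetical helper; only the column-selection logic mirrors the discussion:

```python
# Stand-in dataset classes, for illustration only.
class ImageDataset:
    pass

class PatchDataset:
    pass

def plot_columns(dataset):
    # Standard image datasets get three columns: input, target, predicted.
    cols = ["input", "target", "predicted"]
    if isinstance(dataset, PatchDataset):
        # PatchDataset gets an extra left-most column showing the raw
        # image the patches were cropped from.
        cols = ["raw"] + cols
    return cols

print(plot_columns(ImageDataset()))  # ['input', 'target', 'predicted']
print(plot_columns(PatchDataset()))  # ['raw', 'input', 'target', 'predicted']
```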
@@ -0,0 +1,117 @@
import pathlib
from typing import List, Union
import random

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from .AbstractCallback import AbstractCallback
from ..datasets.PatchDataset import PatchDataset
from ..evaluation.visualization_utils import plot_predictions_grid_from_model

class IntermediatePlot(AbstractCallback):
    """
    Callback to plot model generated outputs, ground
    truth, and input stained image patches at the end of each epoch.
    """

    def __init__(self,
                 name: str,
                 path: Union[pathlib.Path, str],
                 dataset: Union[Dataset, PatchDataset],
                 plot_n_patches: int = 5,
                 indices: Union[List[int], None] = None,
                 plot_metrics: List[nn.Module] = None,
                 every_n_epochs: int = 5,
                 random_seed: int = 42,
                 **kwargs):
        """
        Initialize the IntermediatePlot callback.
        Allows plots of predictions to be generated during training to monitor training progress.
        Supports both PatchDataset and plain Dataset classes for plotting.
        This callback, when passed into the trainer, will plot the model predictions on a subset of the provided dataset at the end of each epoch.

        :param name: Name of the callback.
        :param path: Path to save the plots to.
        :type path: Union[pathlib.Path, str]
        :param dataset: Dataset to be used for plotting intermediate results.
        :type dataset: Union[Dataset, PatchDataset]
        :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5.
            The exact patches/images plotted may vary with the seed or dataset size.
            For best reproducibility and consistency, use a fixed dataset and the indices argument instead.
        :type plot_n_patches: int, optional
        :param indices: Optional list of specific indices to subset the dataset before inference.
            Overrides the plot_n_patches and random_seed arguments.
        :type indices: Union[List[int], None]
        :param plot_metrics: List of metrics to compute and display in the plot title, defaults to None.
        :type plot_metrics: List[nn.Module], optional
        :param every_n_epochs: How frequently intermediate plots should be generated, defaults to 5.
        :type every_n_epochs: int
        :param random_seed: Random seed for reproducible random patch/image selection, defaults to 42.
        :type random_seed: int
        :param kwargs: Additional keyword arguments passed to the plotting function.
        :type kwargs: dict
        :raises TypeError: If the dataset is not a Dataset or PatchDataset instance.
        """
        super().__init__(name)
        self._path = path
        if not isinstance(dataset, (Dataset, PatchDataset)):
            raise TypeError(f"Expected a Dataset or PatchDataset, got {type(dataset)}")

        self._dataset = dataset

        # Additional kwargs passed to the plotting function
        self.plot_metrics = plot_metrics
        self.every_n_epochs = every_n_epochs
        self.plot_kwargs = kwargs

        if indices is not None:
            # Check that indices are within bounds
            for i in indices:
                if i >= len(self._dataset):
                    raise ValueError(f"Index {i} out of bounds for dataset of size {len(self._dataset)}")
            self._dataset_subset_indices = indices
        else:
            # Generate random indices to subset, given the seed and plot_n_patches
            plot_n_patches = min(plot_n_patches, len(self._dataset))
            random.seed(random_seed)
            self._dataset_subset_indices = random.sample(range(len(self._dataset)), plot_n_patches)

    def on_epoch_end(self):
        """
        Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
        """
        # The final epoch is handled by on_train_end, so only the
        # every_n_epochs cadence needs to be checked here.
        if (self.trainer.epoch + 1) % self.every_n_epochs == 0:
            self._plot()

    def on_train_end(self):
        """
        Called at the end of training. Plots if not already done in the last epoch.
        """
        if (self.trainer.epoch + 1) % self.every_n_epochs != 0:
            self._plot()

    def _plot(self):
        """
        Helper method to generate and save plots.
        Plots model predictions on the selected subset of the dataset.
        Called by the on_epoch_end and on_train_end methods.
        """
        original_device = next(self.trainer.model.parameters()).device

        plot_predictions_grid_from_model(
            model=self.trainer.model,
            dataset=self._dataset,
            indices=self._dataset_subset_indices,
            metrics=self.plot_metrics,
            save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
            device=original_device,
            show=False,
            **self.plot_kwargs
        )
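The `on_epoch_end`/`on_train_end` pair above guarantees a plot on the `every_n_epochs` cadence, plus exactly one final plot when the last epoch does not fall on the cadence. A pure-Python sketch of that schedule (the `plot_epochs` helper is hypothetical, for illustration only):

```python
def plot_epochs(n_epochs: int, every_n_epochs: int):
    """Return the epoch indices at which a plot would be produced."""
    plotted = []
    for epoch in range(n_epochs):
        if (epoch + 1) % every_n_epochs == 0:   # on_epoch_end cadence check
            plotted.append(epoch)
    last = n_epochs - 1
    if (last + 1) % every_n_epochs != 0:        # on_train_end catch-up plot
        plotted.append(last)
    return plotted

print(plot_epochs(10, 5))  # [4, 9] -- final epoch falls on the cadence
print(plot_epochs(7, 5))   # [4, 6] -- epoch 4 on cadence, plus final epoch 6
```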
115 changes: 115 additions & 0 deletions callbacks/MlflowLogger.py
[Member]
Maybe I missed this in the code. I think we want to also save the images after each epoch as artifacts with mlflow. I do this with cropped nuclei by reserving a folder for each nucleus, where the images for each epoch are saved in that folder.

[Member]
There are a few ways you could go about doing this (e.g. including more code in this logger, or creating classes for each type of data saved). By types of data I mean: metrics, parameters, artifacts, etc.

[Member]
optuna also uses algorithms to understand hyperparameter importances, so I think we will also want to save the optuna object as an artifact for optimizing models.

[wli51 (Collaborator Author)]
In my current implementation, the mlflow logger, image plotting, and optuna do not go hand in hand. I am less familiar with mlflow than you are, and what I think you are proposing makes a lot more sense and improves the flow. I do think this amount of extra features is perhaps more suitable for a separate PR.

[Member]
I think so too. If you have questions, just let me know.

@@ -0,0 +1,115 @@
import os
import pathlib
import tempfile
from typing import Union, Dict, Optional

import mlflow
import torch

from .AbstractCallback import AbstractCallback

class MlflowLogger(AbstractCallback):
    """
    Callback to log metrics to MLflow.
    """

    def __init__(self,
                 name: str,
                 artifact_name: str = 'best_model_weights.pth',
                 mlflow_uri: Union[pathlib.Path, str] = None,
                 mlflow_experiment_name: Optional[str] = None,
                 mlflow_start_run_args: dict = None,
                 mlflow_log_params_args: dict = None,
                 ):
        """
        Initialize the MlflowLogger callback.

        :param name: Name of the callback.
        :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
        :param mlflow_uri: URI for the MLflow tracking server, defaults to None.
            If a path is supplied, the logger calls set_tracking_uri with it,
            thereby initiating a new tracking server.
            If None (default), the logger does not touch the MLflow tracking
            configuration, enabling logging to a global server initialized
            outside of this class.
        :type mlflow_uri: pathlib.Path or str, optional
        :param mlflow_experiment_name: Name of the MLflow experiment, defaults to None.
            If None, set_experiment is not called and whichever experiment name
            is globally configured is used. If a name is provided, the logger
            calls set_experiment with it.
        :type mlflow_experiment_name: str, optional
        :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None.
        :type mlflow_start_run_args: dict, optional
        :param mlflow_log_params_args: Parameters to log to MLflow, defaults to None.
        :type mlflow_log_params_args: dict, optional
        """
        super().__init__(name)

        if mlflow_uri is not None:
            try:
                mlflow.set_tracking_uri(mlflow_uri)
            except Exception as e:
                raise RuntimeError(f"Error setting MLflow tracking URI: {e}")

        if mlflow_experiment_name is not None:
            try:
                mlflow.set_experiment(mlflow_experiment_name)
            except Exception as e:
                raise RuntimeError(f"Error setting MLflow experiment: {e}")

        self._artifact_name = artifact_name
        self._mlflow_start_run_args = mlflow_start_run_args
        self._mlflow_log_params_args = mlflow_log_params_args

    def on_train_start(self):
        """
        Called at the start of training.

        Calls mlflow.start_run and logs params if provided.
        """
        if self._mlflow_start_run_args is not None:
            if not isinstance(self._mlflow_start_run_args, dict):
                raise TypeError("mlflow_start_run_args must be None or a dictionary.")
            mlflow.start_run(**self._mlflow_start_run_args)

        if self._mlflow_log_params_args is not None:
            if not isinstance(self._mlflow_log_params_args, dict):
                raise TypeError("mlflow_log_params_args must be None or a dictionary.")
            mlflow.log_params(self._mlflow_log_params_args)

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Iterates over the most recent log items in the trainer and logs each to MLflow.
        """
        for key, values in self.trainer.log.items():
            # Skip empty series; mlflow.log_metric does not accept None.
            if values:
                mlflow.log_metric(key, values[-1], step=self.trainer.epoch)

    def on_train_end(self):
        """
        Called at the end of training.

        Saves the trainer's best model to a temporary directory and calls
        mlflow log artifact, then ends the run.
        """
        # Save weights to a temporary directory and log artifacts
        with tempfile.TemporaryDirectory() as tmpdirname:
            weights_path = os.path.join(tmpdirname, self._artifact_name)
            torch.save(self.trainer.best_model, weights_path)
            mlflow.log_artifact(weights_path, artifact_path="models")

        mlflow.end_run()
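The `on_epoch_end` hook above logs only the most recent value of each metric series in `trainer.log`. A stdlib-only sketch of that selection step, with a literal dict standing in for `trainer.log` and a list standing in for the MLflow calls (both hypothetical stand-ins for illustration):

```python
# Hypothetical stand-in for trainer.log: metric name -> per-epoch values.
log = {
    "train_loss": [0.9, 0.7, 0.55],
    "val_loss": [1.0, 0.8, 0.6],
    "lr": [],  # nothing recorded yet; must be skipped (None is not loggable)
}

logged = []  # stand-in for mlflow.log_metric calls
for key, values in log.items():
    if values:  # skip empty series, mirroring the guard in on_epoch_end
        logged.append((key, values[-1]))

print(logged)  # [('train_loss', 0.55), ('val_loss', 0.6)]
```

Guarding against empty series matters because a metric registered in the trainer log but not yet populated would otherwise send `None` to `mlflow.log_metric`, which expects a numeric value.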
3 changes: 3 additions & 0 deletions callbacks/README.md
@@ -0,0 +1,3 @@
This folder contains the callback classes that are meant to be passed into trainers to perform tasks such as saving images every epoch and logging.

The callback classes must inherit from the abstract class.
39 changes: 39 additions & 0 deletions cp_gan_env.yml
@@ -0,0 +1,39 @@
name: cp_gan_env
channels:
- anaconda
- pytorch
- nvidia
- conda-forge
dependencies:
- conda-forge::python=3.9
- conda-forge::pip
- pytorch::pytorch
- pytorch::torchvision
- pytorch::torchaudio
- pytorch::pytorch-cuda=12.1
- conda-forge::seaborn
- conda-forge::matplotlib
- conda-forge::jupyter
- conda-forge::pre_commit
- conda-forge::pandas
- conda-forge::pillow
- conda-forge::numpy
- conda-forge::pathlib2
- conda-forge::scikit-learn
- conda-forge::opencv
- conda-forge::pyarrow
- conda-forge::ipython
- conda-forge::notebook
- conda-forge::albumentations
- conda-forge::optuna
- conda-forge::mysqlclient
- conda-forge::openjdk
- conda-forge::gtk2
- conda-forge::typing-extensions
- conda-forge::Jinja2
- conda-forge::inflect
- conda-forge::wxpython
- conda-forge::sentry-sdk
- pip:
- mlflow
- cellprofiler==4.2.8