Prototype example #1

Open
wants to merge 93 commits into base: dev-0.1
Commits (93)
d7a03b1
Added readme file for the model folders
wli51 Feb 14, 2025
ef44247
Added model files
wli51 Feb 14, 2025
6a2b4c3
Added readme for dataset folder
wli51 Feb 14, 2025
bea6c02
Added dataset files
wli51 Feb 14, 2025
5bfaad6
Added readme for trainers
wli51 Feb 14, 2025
31f2f1b
Added trainer files
wli51 Feb 14, 2025
3496984
Added loss files
wli51 Feb 14, 2025
c187a31
Added metrics files
wli51 Feb 14, 2025
3e921b9
Added callback files
wli51 Feb 14, 2025
8af1475
Added transform files
wli51 Feb 14, 2025
744efcc
Added evaluation files
wli51 Feb 14, 2025
81f3ccd
Made some modifications to callback
wli51 Feb 14, 2025
c8a27ea
Added gitignore
wli51 Feb 14, 2025
b7c0df2
Added notebook that is a minimal example
wli51 Feb 14, 2025
a311fd7
Modified the way trainers are accessed by callbacks, instead of train…
wli51 Feb 17, 2025
e466533
Updated docstring and removed unneeded imports for the callbacks
wli51 Feb 17, 2025
b5a6891
Updated docstring and removed unneeded imports for the dataset classes
wli51 Feb 17, 2025
10a8765
Update loss class forward function variable name and added/modified d…
wli51 Feb 17, 2025
cde174c
Removed redundant metric name property and added documentation
wli51 Feb 17, 2025
6b3d914
Added documentation
wli51 Feb 17, 2025
022a46d
Re-ran notebook
wli51 Feb 17, 2025
74b7b02
Added wrapper class for torch modules that accumulates metric values …
wli51 Feb 18, 2025
12d5f45
Updated abstract trainer class's initialization and train function to…
wli51 Feb 18, 2025
d123b20
Update example notebook to demonstrate alternative early termination …
wli51 Feb 18, 2025
f33b906
Added environment file for cp loss
wli51 Feb 20, 2025
ae573dd
Updated existing type hint to comply to python 3.9 standards
wli51 Feb 20, 2025
d57d80a
Re-ran example notebook
wli51 Feb 20, 2025
09b63d5
Added dataset class that does not rely on pe2loaddata generated file …
wli51 Feb 20, 2025
1d63fad
Modified gitignore to ignore produced files under additional example …
wli51 Feb 20, 2025
6568d01
Modified notebook description
wli51 Feb 20, 2025
755d162
Added helper function that computes metrics on a per image basis give…
wli51 Feb 26, 2025
b537c83
Added a new file which is a collection of helper functions for predic…
wli51 Feb 26, 2025
1f7db29
Added new file which is a collection of helper functions to visualize…
wli51 Feb 26, 2025
75d5adc
Fixed bug in patch dataset returning the wrong length of itself
wli51 Feb 26, 2025
071aa87
Modified callback to allow the plot frequency during training to be t…
wli51 Feb 26, 2025
d7f62b8
Fixed bug
wli51 Feb 27, 2025
cab2820
Merge branch 'dev-dataset' into prototype_example to apply a fix to t…
wli51 Feb 27, 2025
3f68f2f
Merge branch 'dev-callbacks' into prototype_example to apply modifica…
wli51 Feb 27, 2025
4ea24a9
Update trainer classes to remove best model attribute that is interna…
wli51 Feb 27, 2025
6d362b0
Update callbacks/IntermediatePlot.py for clearer documentation of wha…
wli51 Feb 27, 2025
b3e69a1
Merge changes committed through the PR
wli51 Feb 27, 2025
13cb298
Removed call to super class method that does not do anything
wli51 Feb 27, 2025
93b83f2
Modify MlflowLogger class so the default behavior is not to set new t…
wli51 Feb 28, 2025
5380ac7
Modify MlflowLogger class so the default behavior is not to log any p…
wli51 Feb 28, 2025
9a27c23
Modify comment for improved clarity
wli51 Feb 28, 2025
c284f05
Removed TODO item as it is not going to be useful
wli51 Feb 28, 2025
a6f40f6
Updated variable names and docstring for DiscriminatorLoss.py for imp…
wli51 Feb 28, 2025
778144b
Modified MlflowLogger so that the experiment name is also not configu…
wli51 Feb 28, 2025
b270630
Removed outdated TODO
wli51 Feb 28, 2025
2eefe5c
Removed commented out code that is no longer needed
wli51 Feb 28, 2025
756429c
Modified function name to make metrics aggregation more clear.
wli51 Feb 28, 2025
30f6b17
Update docstring
wli51 Feb 28, 2025
7ac89b8
Remove description of where the code is adapted from for consistency
wli51 Feb 28, 2025
ed7d940
fixed bug, log_params should not take keyword arguments, should just …
wli51 Mar 1, 2025
cec0df6
Changed comment to one line for cleanness
wli51 Mar 1, 2025
7f20d7a
Renamed cache related functions for clarity
wli51 Mar 1, 2025
825b745
Added comment to better describe the wGAN DiscriminatorLoss
wli51 Mar 1, 2025
ed47a06
Modified docstring for better clarity
wli51 Mar 1, 2025
63829d7
Changes the default behavior of early termination to be disabled when…
wli51 Mar 1, 2025
18dec2c
Added parameter to determine if batch normalization will be used in d…
wli51 Mar 1, 2025
c0b00b6
Do not retain graph when computing the gradient penalty loss for pote…
wli51 Mar 1, 2025
90660ba
Removed unnecessary if statement
wli51 Mar 1, 2025
80be3dd
Added reconstruction loss weight parameter that defaults to 1
wli51 Mar 1, 2025
aa7b531
Merge branch 'dev-eval-plot' into prototype_example
wli51 Mar 1, 2025
d7eb27e
Added raw_input and raw_target properties to PatchDataset to centrali…
wli51 Mar 1, 2025
5470fd4
Fixed bug of higher order derivative not being able to be computed du…
wli51 Mar 2, 2025
65556a3
Fixed early termination enable/disable logic to ensure when no early t…
wli51 Mar 2, 2025
9751437
Modified predict_image function so it returns the target tensor along…
wli51 Mar 3, 2025
1b5d7ec
Modified evlauation_per_image_metric function so metrics can be compu…
wli51 Mar 3, 2025
a052113
Modified plot_patches function for compatibility with updated predict…
wli51 Mar 3, 2025
2ef624f
Update return type hint
wli51 Mar 3, 2025
2e1532e
Added new functions to visualization_utils.py for visualization of se…
wli51 Mar 3, 2025
fda7de7
Added kwargs support for plot parameters and fixed metrics in title
wli51 Mar 3, 2025
f909e34
Updated IntermediatePlot callback class for compatibility with new pl…
wli51 Mar 3, 2025
1413280
Renamed callback name and modified type checking to reflect that it s…
wli51 Mar 3, 2025
0da8875
Update comment to better reflect what the functions are doing
wli51 Mar 3, 2025
8d5a616
Removed unneeded files and functions
wli51 Mar 3, 2025
baeb80b
Update example so it is consistent with the updated evaluation/plotti…
wli51 Mar 3, 2025
fc76155
Bug fix for early termination disable
wli51 Mar 3, 2025
d78beae
Bug fix for dataset type checking
wli51 Mar 3, 2025
e23dd55
Update example so it is consistent with the updated evaluation/plotti…
wli51 Mar 3, 2025
8e652b8
Renamed loss class for clarity and updated examples accordingly
wli51 Mar 4, 2025
e0cbf41
Removed unneeded functions
wli51 Mar 4, 2025
c6ce487
Removed TODO comment that will not be addressed at the moment
wli51 Mar 4, 2025
1b81e1e
Removed unneeded comment
wli51 Mar 4, 2025
ce391c4
Updated UNet and FNet to have consistent specification for output act…
wli51 Mar 4, 2025
741c121
Removed some type docstrings to reduce redundancy
wli51 Mar 4, 2025
220c561
Removed not planned TODOs
wli51 Mar 4, 2025
8c7e97c
Added docstring
wli51 Mar 4, 2025
c45fabf
Removed unwanted functions
wli51 Mar 4, 2025
207c75e
Renamed WGANTrainer and updated examples accordingly
wli51 Mar 4, 2025
9596e53
Removed unplanned TODO
wli51 Mar 4, 2025
ad93208
Removed unused attributes and refactored to use the depth attribute
wli51 Mar 4, 2025
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
# images and anything under mlflow
*.png
examples/example_train*/*

# pycache
*.pyc
62 changes: 62 additions & 0 deletions callbacks/AbstractCallback.py
@@ -0,0 +1,62 @@
from abc import ABC

class AbstractCallback(ABC):
    """
    Abstract class for callbacks in the training process.
    Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc.
    """

    def __init__(self, name: str):
        """
        :param name: Name of the callback.
        """
        self._name = name
        self._trainer = None

    @property
    def name(self):
        """
        Getter for callback name
        """
        return self._name

    @property
    def trainer(self):
        """
        Allows for access of trainer
        """
        return self._trainer

    def _set_trainer(self, trainer):
        """
        Helper function called by trainer class to initialize trainer value field

        :param trainer: trainer object
        :type trainer: AbstractTrainer or subclass

Member:
Would recommend either consistently using type hints or using the function docstring for types. I think type hints are generally better though

Collaborator Author (@wli51, Feb 27, 2025):
I purposefully did not include type hints for the trainer objects here because the trainer classes already have type hints for the callback classes, and adding a trainer type hint here in the callback class would cause a circular import. This is probably a consequence of my sub-optimal class design that needs to be solved with some refactoring.

Member:
I don't remember if this is the way to do it exactly, but you can do something like this:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from some_module import AbstractTrainer  # Only for type hints, avoids circular imports

class CallbackClass:
    def _set_trainer(self, trainer: "AbstractTrainer") -> None:
        """
        Helper function called by trainer class to initialize trainer value field.
        """
        self._trainer = trainer

Member (@MattsonCam, Mar 3, 2025):
Although, maybe there is a better alternative here such as changing the design

"""

self._trainer = trainer

def on_train_start(self):
"""
Called at the start of training.
"""
pass

def on_epoch_start(self):
"""
Called at the start of each epoch.
"""
pass

def on_epoch_end(self):
"""
Called at the end of each epoch.
"""
pass

def on_train_end(self):
"""
Called at the end of training.
"""
pass
69 changes: 69 additions & 0 deletions callbacks/IntermediatePlot.py

Member:
Do you plan to include multiple plotting classes in the same file? If not, consider changing the name of this file to match the class name, since this class is specific to plotting patches. I could also see a scenario where the plotting callback classes are in a separate folder inside the callbacks folder.

Collaborator Author:
I am working on a major overhaul of the evaluation and plotting suite outside of this PR. It would have separate plotting helper functions for ImageDataset vs. PatchDataset that live in a single .py file under a single evaluation folder, and a single IntermediatePlot callback class would determine which one to use depending on the dataset it was given. I can merge the overhaul into this PR once I am done with it, or it could be a separate PR. What are your thoughts?

Member:
That seems like a good idea to me, even if many of the helper functions are only for one callback. This is a design decision, which will depend on what is being plotted and how much complexity is involved. I think a separate PR would be a better choice for a change like this.

Collaborator Author:
Just did that with the push last night. The new plot functions under visualization_utils detect the dataset type: for standard image datasets (ImageDataset and CachedDataset) they plot three columns (input, target, and predicted image), and for PatchDataset they plot an additional left-most column showing the raw image the patches are cropped from. There are also two versions of the plot function: one operates on existing inference and evaluation results (predictions from a previous forward pass and the corresponding computed metrics), while the other computes the inference and evaluation metrics internally so everything is self-contained. For now the self-contained version is used by the IntermediatePlot callback class for speed of implementation, which results in redundant evaluation and inference of n images per epoch, where n is the number of patches/images being plotted. In a future version we can rework the way the trainer communicates with the callbacks to remove this redundant computation.
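A rough sketch of that dispatch, for illustration only (the helper below and its attribute access, e.g. dataset.raw_input, are assumptions and not the actual visualization_utils API):

import matplotlib.pyplot as plt
import torch

def plot_examples(dataset, model, n_examples=5, save_path=None, device="cpu"):
    # Sketch only: the real helpers in evaluation/visualization_utils.py may differ.
    # Assumes dataset[i] returns an (input, target) pair of tensors.
    is_patch_dataset = type(dataset).__name__ == "PatchDataset"
    n_cols = 4 if is_patch_dataset else 3  # extra left-most column for the raw source image

    fig, axes = plt.subplots(n_examples, n_cols, figsize=(3 * n_cols, 3 * n_examples), squeeze=False)
    model.eval()
    for row in range(n_examples):
        input_img, target_img = dataset[row]
        with torch.no_grad():
            pred_img = model(input_img.unsqueeze(0).to(device)).squeeze(0).cpu()

        panels = [(input_img, "input"), (target_img, "target"), (pred_img, "predicted")]
        if is_patch_dataset:
            # Assumed property holding the raw image each patch was cropped from
            panels.insert(0, (dataset.raw_input[row], "raw image"))

        for col, (img, title) in enumerate(panels):
            axes[row][col].imshow(img.squeeze(), cmap="gray")
            axes[row][col].set_title(title)
            axes[row][col].axis("off")

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)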

@@ -0,0 +1,69 @@
from typing import List, Union

import torch
import torch.nn as nn

from .AbstractCallback import AbstractCallback
from ..datasets.PatchDataset import PatchDataset

from ..evaluation.visualization_utils import plot_patches

class IntermediatePatchPlot(AbstractCallback):
    """
    Callback to plot model-generated outputs alongside ground
    truth and input at the end of each epoch.
    """

    def __init__(self,
                 name: str,
                 path: str,
                 dataset: PatchDataset,
                 plot_n_patches: int = 5,
                 plot_metrics: List[nn.Module] = None,
                 **kwargs):
        """
        Initialize the IntermediatePlot callback.

        :param name: Name of the callback.
        :type name: str
        :param path: Path to save the intermediate plots.
        :type path: str
        :param dataset: Dataset to be used for plotting intermediate results.
        :type dataset: PatchDataset
        :param plot_n_patches: Number of patches to plot, defaults to 5.
        :type plot_n_patches: int, optional
        :param plot_metrics: List of metrics to compute and display in plot title, defaults to None.
        :type plot_metrics: List[nn.Module], optional
        :param kwargs: Additional keyword arguments to be passed to plot_patches.
        :type kwargs: dict
        :raises TypeError: If the dataset is not an instance of PatchDataset.
        """
        super().__init__(name)
        self._path = path
        if not isinstance(dataset, PatchDataset):
            raise TypeError(f"Expected PatchDataset, got {type(dataset)}")

Member:
This is a good idea. Consider adding more input validation. In particular, I think it would be worth checking that the user enters correct inputs (such as a positive int for epochs and update frequency). You could include this functionality in a place like the AbstractTrainer or wherever the user's inputs are read, since this will need to be specified for each model.
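A minimal sketch of the kind of check being suggested (the helper and parameter names are hypothetical, not taken from this PR):

def _require_positive_int(value, name: str) -> None:
    # Hypothetical validation helper; reject bools, non-ints, and non-positive values.
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value!r}")

# e.g. in a trainer or callback __init__:
# _require_positive_int(epochs, "epochs")
# _require_positive_int(plot_every_n_epochs, "plot_every_n_epochs")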

        self._dataset = dataset

        # Additional kwargs passed to plot_patches
        self.plot_n_patches = plot_n_patches
        self.plot_metrics = plot_metrics
        self.plot_kwargs = kwargs

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Plot dataset with model predictions on n random images from dataset at the end of each epoch.
        """

        original_device = next(self.trainer.model.parameters()).device

        plot_patches(
            _dataset=self._dataset,
            _n_patches=self.plot_n_patches,
            _model=self.trainer.model,
            _metrics=self.plot_metrics,
            save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
            device=original_device,
            **self.plot_kwargs
        )
93 changes: 93 additions & 0 deletions callbacks/MlflowLogger.py

Member:
Maybe I missed this in the code. I think we want to also save the images after each epoch as artifacts with mlflow. I do this with cropped nuclei by reserving a folder for each nucleus, where the images for each epoch are saved in that folder.
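A rough sketch of that pattern (the folder layout and names below are illustrative, not from this PR):

import mlflow

def log_epoch_image(image_path: str, sample_id: str) -> None:
    # Illustrative only: attach a per-epoch image to the active MLflow run,
    # keeping one artifact folder per sample (e.g. per nucleus).
    # e.g. image_path = "plots/nucleus_042/epoch_010.png"
    mlflow.log_artifact(image_path, artifact_path=f"epoch_images/{sample_id}")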

Member:
There are a few ways you could go about doing this (e.g. including more code in this logger or creating classes for each type of data saved). By types of data I mean: metrics, parameters, artifacts, etc.

Member:
Optuna also uses algorithms to understand hyperparameter importances, so I think we will also want to save the optuna study object as an artifact when optimizing models.
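For example, one possible way to do that (a sketch, assuming a study created with optuna.create_study and serialized with joblib):

import joblib
import mlflow
import optuna

# Sketch: persist the Optuna study and attach it to the current MLflow run.
study = optuna.create_study(direction="minimize")
# study.optimize(objective, n_trials=...)  # hyperparameter search would run here

joblib.dump(study, "optuna_study.pkl")  # serialize the study object to disk
mlflow.log_artifact("optuna_study.pkl", artifact_path="optuna")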

Collaborator Author:
In my current implementation, the mlflow logger, image plotting, and optuna do not go hand in hand. I am less familiar with mlflow than you are, and I think what you are proposing makes a lot more sense and improves the flow. I do think this amount of extra functionality is perhaps more suitable for a separate PR.

Member:
I think so too. If you have questions, just let me know

@@ -0,0 +1,93 @@
import os
import pathlib
import tempfile
from typing import Union

import mlflow
import torch

from .AbstractCallback import AbstractCallback

class MlflowLogger(AbstractCallback):
    """
    Callback to log metrics to MLflow.
    """

    def __init__(self,
                 name: str,
                 artifact_name: str = 'best_model_weights.pth',
                 mlflow_uri: Union[pathlib.Path, str] = 'mlruns',
                 mlflow_experiment_name: str = 'Default',
                 mlflow_start_run_args: dict = {},
                 mlflow_log_params_args: dict = {},
                 ):
        """
        Initialize the MlflowLogger callback.

        :param name: Name of the callback.
        :type name: str
        :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
        :type artifact_name: str, optional
        :param mlflow_uri: URI for the MLflow tracking server, defaults to 'mlruns' under current wd.
        :type mlflow_uri: pathlib.Path or str, optional
        :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'.
        :type mlflow_experiment_name: str, optional
        :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to {}.
        :type mlflow_start_run_args: dict, optional
        :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to {}.
        :type mlflow_log_params_args: dict, optional
        """
        super().__init__(name)

        try:
            mlflow.set_tracking_uri(mlflow_uri)
            mlflow.set_experiment(mlflow_experiment_name)
        except Exception as e:
            print(f"Error setting MLflow tracking URI: {e}")

        self._artifact_name = artifact_name
        self._mlflow_start_run_args = mlflow_start_run_args
        self._mlflow_log_params_args = mlflow_log_params_args

    def on_train_start(self):
        """
        Called at the start of training.

        Calls mlflow start run and logs params if provided
        """
        mlflow.start_run(
            **self._mlflow_start_run_args
        )
        mlflow.log_params(
            self._mlflow_log_params_args
        )

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Iterate over the most recent log items in trainer and call mlflow log metric
        """
        for key, values in self.trainer.log.items():
            if values is not None and len(values) > 0:
                value = values[-1]
            else:
                value = None
            mlflow.log_metric(key, value, step=self.trainer.epoch)

    def on_train_end(self):
        """
        Called at the end of training.

        Saves trainer best model to a temporary directory and calls mlflow log artifact
        Then ends run
        """
        # Save weights to a temporary directory and log artifacts
        with tempfile.TemporaryDirectory() as tmpdirname:
            weights_path = os.path.join(tmpdirname, self._artifact_name)
            torch.save(self.trainer.best_model, weights_path)
            mlflow.log_artifact(weights_path, artifact_path="models")

        mlflow.end_run()
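A brief usage sketch for reference; only the MlflowLogger constructor comes from the diff above, and the trainer wiring is a hypothetical example:

logger = MlflowLogger(
    name="mlflow_logger",
    artifact_name="best_model_weights.pth",
    mlflow_uri="mlruns",
    mlflow_experiment_name="Default",
    mlflow_start_run_args={"run_name": "prototype_example"},   # forwarded to mlflow.start_run
    mlflow_log_params_args={"lr": 1e-4, "batch_size": 16},     # forwarded to mlflow.log_params
)

# trainer = SomeTrainer(..., callbacks=[logger])  # hypothetical trainer API
# trainer.train()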
3 changes: 3 additions & 0 deletions callbacks/README.md

Member:
It may be useful to include more info in these readmes in the future to further guide the user

@@ -0,0 +1,3 @@
Here live the callback classes that are meant to be fed into trainers to do things like saving images every epoch and logging.

The callback classes must inherit from the abstract class.
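For illustration, a minimal callback conforming to that contract could look like the sketch below (not part of this PR; the logic is arbitrary):

from .AbstractCallback import AbstractCallback

class PrintEpochCallback(AbstractCallback):
    """Print a short message at the end of every epoch."""

    def on_epoch_end(self):
        # self.trainer is populated by the trainer via _set_trainer()
        print(f"Finished epoch {self.trainer.epoch}")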
39 changes: 39 additions & 0 deletions cp_gan_env.yml
@@ -0,0 +1,39 @@
name: cp_gan_env
channels:
- anaconda
- pytorch
- nvidia
- conda-forge
dependencies:
- conda-forge::python=3.9
- conda-forge::pip
- pytorch::pytorch
- pytorch::torchvision
- pytorch::torchaudio
- pytorch::pytorch-cuda=12.1
- conda-forge::seaborn
- conda-forge::matplotlib
- conda-forge::jupyter
- conda-forge::pre_commit
- conda-forge::pandas
- conda-forge::pillow
- conda-forge::numpy
- conda-forge::pathlib2
- conda-forge::scikit-learn
- conda-forge::opencv
- conda-forge::pyarrow
- conda-forge::ipython
- conda-forge::notebook
- conda-forge::albumentations
- conda-forge::optuna
- conda-forge::mysqlclient
- conda-forge::openjdk
- conda-forge::gtk2
- conda-forge::typing-extensions
- conda-forge::Jinja2
- conda-forge::inflect
- conda-forge::wxpython
- conda-forge::sentry-sdk
- pip:
  - mlflow
  - cellprofiler==4.2.8