Prototype example #1
base: dev-0.1
Changes from 30 commits
@@ -0,0 +1,6 @@
```
# images and anything under mlflow
*.png
examples/example_train*/*

# pycache
*.pyc
```
@@ -0,0 +1,62 @@
```python
from abc import ABC


class AbstractCallback(ABC):
    """
    Abstract class for callbacks in the training process.
    Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc.
    """

    def __init__(self, name: str):
        """
        :param name: Name of the callback.
        """
        self._name = name
        self._trainer = None

    @property
    def name(self):
        """
        Getter for callback name
        """
        return self._name

    @property
    def trainer(self):
        """
        Allows for access of trainer
        """
        return self._trainer

    def _set_trainer(self, trainer):
        """
        Helper function called by trainer class to initialize trainer value field

        :param trainer: trainer object
        :type trainer: AbstractTrainer or subclass
        """
        self._trainer = trainer

    def on_train_start(self):
        """
        Called at the start of training.
        """
        pass

    def on_epoch_start(self):
        """
        Called at the start of each epoch.
        """
        pass

    def on_epoch_end(self):
        """
        Called at the end of each epoch.
        """
        pass

    def on_train_end(self):
        """
        Called at the end of training.
        """
        pass
```
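To illustrate how these hooks compose, here is a minimal, hypothetical sketch of a trainer driving the callback lifecycle. The `ToyTrainer` and `PrintCallback` names are illustrative only and do not exist in the repository; the callback interface is a trimmed copy of the one above so the sketch is self-contained.

```python
from abc import ABC


class AbstractCallback(ABC):
    """Trimmed copy of the callback interface above, for a runnable sketch."""

    def __init__(self, name: str):
        self._name = name
        self._trainer = None

    def _set_trainer(self, trainer):
        self._trainer = trainer

    def on_train_start(self): pass
    def on_epoch_start(self): pass
    def on_epoch_end(self): pass
    def on_train_end(self): pass


class PrintCallback(AbstractCallback):
    """Example subclass that records which hooks fired, in order."""

    def __init__(self, name: str):
        super().__init__(name)
        self.events = []

    def on_train_start(self):
        self.events.append("train_start")

    def on_epoch_end(self):
        self.events.append(f"epoch_{self._trainer.epoch}_end")


class ToyTrainer:
    """Hypothetical trainer skeleton: wires callbacks to itself and fires hooks."""

    def __init__(self, callbacks):
        self.epoch = 0
        self.callbacks = callbacks
        for cb in callbacks:
            cb._set_trainer(self)

    def train(self, n_epochs: int):
        for cb in self.callbacks:
            cb.on_train_start()
        for self.epoch in range(n_epochs):
            for cb in self.callbacks:
                cb.on_epoch_start()
            # ... forward/backward passes would go here ...
            for cb in self.callbacks:
                cb.on_epoch_end()
        for cb in self.callbacks:
            cb.on_train_end()


cb = PrintCallback("printer")
ToyTrainer([cb]).train(2)
print(cb.events)  # ['train_start', 'epoch_0_end', 'epoch_1_end']
```

Subclasses override only the hooks they need; the base class provides no-op defaults for the rest.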
> **Reviewer:** Do you plan to include multiple plotting classes in the same file? If not, consider changing the name of this file to match the class name, since this class is specific to plotting patches. I could also see a scenario where the plotting callback classes are in a separate folder inside the …

> **Author (wli51):** I am working on a major overhaul of the evaluation and plotting suite outside of this PR that would have separate plotting helper functions for ImageDataset vs PatchDataset, living in a single .py file under a single evaluation folder; a single IntermediatePlot callback class would determine which to use depending on the dataset it was given. I can merge the overhaul into this PR once I am done with it, or it could be a separate PR. What are your thoughts?

> **Reviewer:** That seems like a good idea to me, even if many of the helper functions are only for one callback. This is a design decision, which will depend on what is being plotted and how much complexity is involved. I think a separate PR would be a better choice for a change like this.

> **Author (wli51):** Just did with the push last night. The new plot files under visualization_utils detect the dataset type: for a standard image dataset (ImageDataset and CachedDataset) they plot three columns showing the input, target, and predicted image; for PatchDataset they plot an additional left-most column showing the raw image the patches were cropped from. There are also two versions of the plot function: one operates on existing inference and evaluation results (predictions from a previous forward pass and the corresponding computed metrics), while the other computes inference and evaluation internally so that everything is self-contained. For now the self-contained version is used by the IntermediatePlot callback class for speed of implementation, which results in redundant inference and evaluation of n images per epoch, where n is the number of patches/images being plotted. In a future version we can rework how the trainer communicates with the callbacks to remove this redundant computation.
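The dataset-type dispatch described above could look roughly like the following sketch. The class and function names here are stand-ins, not the repository's actual `ImageDataset`/`PatchDataset` implementations; only the column layout follows the description.

```python
# Hypothetical sketch of the dispatch described above: a single plotting
# entry point chooses the column layout based on the dataset type.

class ImageDataset: ...
class CachedDataset(ImageDataset): ...
class PatchDataset: ...


def plot_columns_for(dataset):
    """Return the column titles a plot helper would render for this dataset."""
    if isinstance(dataset, PatchDataset):
        # Extra left-most column: the raw image the patches were cropped from
        return ["raw", "input", "target", "predicted"]
    if isinstance(dataset, ImageDataset):
        # Covers both ImageDataset and CachedDataset
        return ["input", "target", "predicted"]
    raise TypeError(f"Unsupported dataset type: {type(dataset)}")


print(plot_columns_for(PatchDataset()))   # ['raw', 'input', 'target', 'predicted']
print(plot_columns_for(CachedDataset()))  # ['input', 'target', 'predicted']
```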
@@ -0,0 +1,69 @@
```python
from typing import List, Union

import torch
import torch.nn as nn

from .AbstractCallback import AbstractCallback
from ..datasets.PatchDataset import PatchDataset

from ..evaluation.visualization_utils import plot_patches


class IntermediatePatchPlot(AbstractCallback):
    """
    Callback to plot model generated outputs alongside ground
    truth and input at the end of each epoch.
    """

    def __init__(self,
                 name: str,
                 path: str,
                 dataset: PatchDataset,
                 plot_n_patches: int = 5,
                 plot_metrics: List[nn.Module] = None,
                 **kwargs):
        """
        Initialize the IntermediatePlot callback.

        :param name: Name of the callback.
        :type name: str
        :param path: Path to save the model weights.
        :type path: str
        :param dataset: Dataset to be used for plotting intermediate results.
        :type dataset: PatchDataset
        :param plot_n_patches: Number of patches to plot, defaults to 5.
        :type plot_n_patches: int, optional
        :param plot_metrics: List of metrics to compute and display in plot title, defaults to None.
        :type plot_metrics: List[nn.Module], optional
        :param kwargs: Additional keyword arguments to be passed to plot_patches.
        :type kwargs: dict
        :raises TypeError: If the dataset is not an instance of PatchDataset.
        """
        super().__init__(name)
        self._path = path
        if not isinstance(dataset, PatchDataset):
            raise TypeError(f"Expected PatchDataset, got {type(dataset)}")
        self._dataset = dataset

        # Additional kwargs passed to plot_patches
        self.plot_n_patches = plot_n_patches
        self.plot_metrics = plot_metrics
        self.plot_kwargs = kwargs

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Plot dataset with model predictions on n random images from dataset at the end of each epoch.
        """
        original_device = next(self.trainer.model.parameters()).device

        plot_patches(
            _dataset=self._dataset,
            _n_patches=self.plot_n_patches,
            _model=self.trainer.model,
            _metrics=self.plot_metrics,
            save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
            device=original_device,
            **self.plot_kwargs
        )
```

> **Reviewer** (on the dataset type check): This is a good idea. Consider adding more input validation. In particular I think it would be worth including input validation to check that the user enters correct inputs (such as a positive int of epochs and update frequency). You could include this functionality in a place like the …
> **Reviewer:** Maybe I missed this in the code. I think we want to also save the images after each epoch as artifacts with mlflow. I do this with cropped nuclei by reserving a folder for each nucleus, where the images for each epoch are saved in that folder.

> **Reviewer:** There are a few ways you could go about doing this (e.g. including more code in this logger or creating classes for each type of data saved). By types of data I mean: metrics, parameters, artifacts, etc.

> **Reviewer:** Optuna also uses algorithms to understand hyperparameter importances, so I think we will also want to save the Optuna object as an artifact for optimizing models.

> **Author (wli51):** In my current implementation, the mlflow logger, image plotting, and Optuna do not go hand in hand. I am less familiar with mlflow than you are, and what I think you are proposing makes a lot more sense and improves the flow. I do think this amount of extra functionality is perhaps more suitable for a separate PR.

> **Reviewer:** I think so too. If you have questions, just let me know.
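The per-sample artifact layout the reviewer describes could be sketched as below. `mlflow.log_artifact(local_path, artifact_path=...)` is a real MLflow API; the folder layout, the `artifact_destination` helper, and the `sample_` naming are assumptions for illustration, not the repository's code.

```python
# Hedged sketch: log each epoch's plot as an MLflow artifact, grouped under a
# per-sample folder (mirroring the per-nucleus layout the reviewer describes).
import os


def artifact_destination(sample_id: int, epoch: int):
    """Return (local_filename, artifact_path) for one sample's epoch plot."""
    local_file = f"epoch_{epoch}.png"
    # One artifact folder per sample, so successive epochs accumulate together
    artifact_path = os.path.join("plots", f"sample_{sample_id}")
    return local_file, artifact_path


# Inside on_epoch_end one would then call, for each plotted sample i:
#   local_file, artifact_path = artifact_destination(i, self.trainer.epoch)
#   mlflow.log_artifact(local_file, artifact_path=artifact_path)

print(artifact_destination(3, 7))
```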
@@ -0,0 +1,93 @@
```python
import os
import pathlib
import tempfile
from typing import Union

import mlflow
import torch

from .AbstractCallback import AbstractCallback


class MlflowLogger(AbstractCallback):
    """
    Callback to log metrics to MLflow.
    """

    def __init__(self,
                 name: str,
                 artifact_name: str = 'best_model_weights.pth',
                 mlflow_uri: Union[pathlib.Path, str] = 'mlruns',
                 mlflow_experiment_name: str = 'Default',
                 mlflow_start_run_args: dict = {},
                 mlflow_log_params_args: dict = {},
                 ):
        """
        Initialize the MlflowLogger callback.

        :param name: Name of the callback.
        :type name: str
        :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
        :type artifact_name: str, optional
        :param mlflow_uri: URI for the MLflow tracking server, defaults to 'mlruns' under current wd.
        :type mlflow_uri: pathlib.Path or str, optional
        :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'.
        :type mlflow_experiment_name: str, optional
        :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to {}.
        :type mlflow_start_run_args: dict, optional
        :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to {}.
        :type mlflow_log_params_args: dict, optional
        """
        super().__init__(name)

        try:
            mlflow.set_tracking_uri(mlflow_uri)
            mlflow.set_experiment(mlflow_experiment_name)
        except Exception as e:
            print(f"Error setting MLflow tracking URI: {e}")

        self._artifact_name = artifact_name
        self._mlflow_start_run_args = mlflow_start_run_args
        self._mlflow_log_params_args = mlflow_log_params_args

    def on_train_start(self):
        """
        Called at the start of training.

        Calls mlflow start run and logs params if provided.
        """
        mlflow.start_run(**self._mlflow_start_run_args)
        mlflow.log_params(self._mlflow_log_params_args)

    def on_epoch_end(self):
        """
        Called at the end of each epoch.

        Iterates over the most recent log items in the trainer and calls mlflow log metric.
        """
        for key, values in self.trainer.log.items():
            if values is not None and len(values) > 0:
                value = values[-1]
            else:
                value = None
            mlflow.log_metric(key, value, step=self.trainer.epoch)

    def on_train_end(self):
        """
        Called at the end of training.

        Saves the trainer's best model to a temporary directory and calls mlflow log artifact,
        then ends the run.
        """
        # Save weights to a temporary directory and log artifacts
        with tempfile.TemporaryDirectory() as tmpdirname:
            weights_path = os.path.join(tmpdirname, self._artifact_name)
            torch.save(self.trainer.best_model, weights_path)
            mlflow.log_artifact(weights_path, artifact_path="models")

        mlflow.end_run()
```
> **Reviewer:** It may be useful to include more info in these readmes in the future to further guide the user.
@@ -0,0 +1,3 @@
```
Here lives the callback classes that are meant to be fed into trainers to do stuff like saving images every epoch and logging.

The callback classes must inherit the abstract class.
```
@@ -0,0 +1,39 @@
```yaml
name: cp_gan_env
channels:
  - anaconda
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - conda-forge::python=3.9
  - conda-forge::pip
  - pytorch::pytorch
  - pytorch::torchvision
  - pytorch::torchaudio
  - pytorch::pytorch-cuda=12.1
  - conda-forge::seaborn
  - conda-forge::matplotlib
  - conda-forge::jupyter
  - conda-forge::pre_commit
  - conda-forge::pandas
  - conda-forge::pillow
  - conda-forge::numpy
  - conda-forge::pathlib2
  - conda-forge::scikit-learn
  - conda-forge::opencv
  - conda-forge::pyarrow
  - conda-forge::ipython
  - conda-forge::notebook
  - conda-forge::albumentations
  - conda-forge::optuna
  - conda-forge::mysqlclient
  - conda-forge::openjdk
  - conda-forge::gtk2
  - conda-forge::typing-extensions
  - conda-forge::Jinja2
  - conda-forge::inflect
  - conda-forge::wxpython
  - conda-forge::sentry-sdk
  - pip:
      - mlflow
      - cellprofiler==4.2.8
```
> **Reviewer:** Would recommend either consistently using type hints or using the function docstring for types. I think type hints are generally better, though.

> **Author (wli51):** I purposefully did not include type hints for the trainer objects here because the trainer classes already have type hints for the callback classes, and adding trainer type hints here in the callback class would cause circular imports. This is probably a consequence of my sub-optimal class design, which needs to be solved with some refactoring.

> **Reviewer:** I don't remember if this is the way to do it exactly, but you can do something like this: [code snippet did not load]

> **Reviewer:** Although, maybe there is a better alternative here, such as changing the design.
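The reviewer's snippet did not survive the page scrape, but the standard Python pattern for type-hinting across a circular import is a `typing.TYPE_CHECKING` guard with a string annotation: the import runs only under static analysis, never at runtime. This is a general-language sketch, not necessarily the exact snippet the reviewer posted, and the `trainers.AbstractTrainer` module path is hypothetical.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only by static type checkers (mypy, pyright, IDEs);
    # at runtime this block is skipped, so no circular import occurs.
    from trainers.AbstractTrainer import AbstractTrainer  # hypothetical path


class AbstractCallback:
    def _set_trainer(self, trainer: "AbstractTrainer") -> None:
        # The string annotation is resolved lazily, so the name need not
        # exist at runtime; type checkers still verify call sites.
        self._trainer = trainer


cb = AbstractCallback()
cb._set_trainer(object())  # annotations are not enforced at runtime
print(cb._trainer is not None)  # True
```

An alternative, as the reviewer notes, is restructuring so the dependency only points one way, e.g. moving shared interfaces into a module both packages can import.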