diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2201fd5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# images and anything under mlflow +*.png +examples/example_train*/* + +# pycache +*.pyc \ No newline at end of file diff --git a/callbacks/AbstractCallback.py b/callbacks/AbstractCallback.py new file mode 100644 index 0000000..758d53e --- /dev/null +++ b/callbacks/AbstractCallback.py @@ -0,0 +1,62 @@ +from abc import ABC + +class AbstractCallback(ABC): + """ + Abstract class for callbacks in the training process. + Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc. + """ + + def __init__(self, name: str): + """ + :param name: Name of the callback. + """ + self._name = name + self._trainer = None + + @property + def name(self): + """ + Getter for callback name + """ + return self._name + + @property + def trainer(self): + """ + Allows for access of trainer + """ + return self._trainer + + def _set_trainer(self, trainer): + """ + Helper function called by trainer class to initialize trainer value field + + :param trainer: trainer object + :type trainer: AbstractTrainer or subclass + """ + + self._trainer = trainer + + def on_train_start(self): + """ + Called at the start of training. + """ + pass + + def on_epoch_start(self): + """ + Called at the start of each epoch. + """ + pass + + def on_epoch_end(self): + """ + Called at the end of each epoch. + """ + pass + + def on_train_end(self): + """ + Called at the end of training. + """ + pass \ No newline at end of file diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py new file mode 100644 index 0000000..ad2c9ea --- /dev/null +++ b/callbacks/IntermediatePlot.py @@ -0,0 +1,117 @@ +import pathlib +from typing import List, Union +import random + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +from .AbstractCallback import AbstractCallback +from ..datasets.PatchDataset import PatchDataset +from ..evaluation.visualization_utils import plot_predictions_grid_from_model + +class IntermediatePlot(AbstractCallback): + """ + Callback to plot model generated outputs, ground + truth, and input stained image patches at the end of each epoch. + """ + + def __init__(self, + name: str, + path: Union[pathlib.Path, str], + dataset: Union[Dataset, PatchDataset], + plot_n_patches: int=5, + indices: Union[List[int], None]=None, + plot_metrics: List[nn.Module]=None, + every_n_epochs: int=5, + random_seed: int=42, + **kwargs): + """ + Initialize the IntermediatePlot callback. + Allows plots of predictions to be generated during training for monitoring of training progress. + Supports both PatchDataset and Dataset classes for plotting. + This callback, when passed into the trainer, will plot the model predictions on a subset of the provided dataset at the end of each epoch. + + :param name: Name of the callback. + :param path: Path to save the model weights. + :type path: Union[pathlib.Path, str] + :param dataset: Dataset to be used for plotting intermediate results. + :type dataset: Union[Dataset, PatchDataset] + :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5. + The exact patches/images being plotted may vary due to a difference in seed or dataset size. + To ensure best reproducibility and consistency, please use a fixed dataset and indices argument instead. + :type plot_n_patches: int, optional + :param indices: Optional list of specific indices to subset the dataset before inference. 
+            Overrides the plot_n_patches and random_seed arguments and uses this list to subset the dataset.
+        :type indices: Union[List[int], None]
+        :param plot_metrics: List of metrics to compute and display in the plot title, defaults to None.
+        :type plot_metrics: List[nn.Module], optional
+        :param kwargs: Additional keyword arguments to be passed to plot_patches.
+        :type kwargs: dict
+        :param every_n_epochs: How frequently intermediate plots should be generated, defaults to 5.
+        :type every_n_epochs: int
+        :param random_seed: Random seed for reproducible random patch/image selection, defaults to 42.
+        :type random_seed: int
+        :raises TypeError: If the dataset is not a Dataset or PatchDataset instance.
+        """
+        super().__init__(name)
+        self._path = path
+        if not isinstance(dataset, (Dataset, PatchDataset)):
+            raise TypeError(f"Expected a Dataset or PatchDataset, got {type(dataset)}")
+
+        self._dataset = dataset
+
+        # Additional kwargs passed to plot_patches
+        self.plot_metrics = plot_metrics
+        self.every_n_epochs = every_n_epochs
+        self.plot_kwargs = kwargs
+
+        if indices is not None:
+            # Check if indices are within bounds
+            for i in indices:
+                if i >= len(self._dataset):
+                    raise ValueError(f"Index {i} out of bounds for dataset of size {len(self._dataset)}")
+            self._dataset_subset_indices = indices
+        else:
+            # Generate random indices to subset given seed and plot_n_patches
+            plot_n_patches = min(plot_n_patches, len(self._dataset))
+            random.seed(random_seed)
+            self._dataset_subset_indices = random.sample(range(len(self._dataset)), plot_n_patches)
+
+    def on_epoch_end(self):
+        """
+        Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
+        """
+        # A final epoch that is not a multiple of every_n_epochs is handled in on_train_end
+        if (self.trainer.epoch + 1) % self.every_n_epochs == 0:
+            self._plot()
+
+    def on_train_end(self):
+        """
+        Called at the end of training. Plots if not already done in the last epoch.
+        """
+        if (self.trainer.epoch + 1) % self.every_n_epochs != 0:
+            self._plot()
+
+    def _plot(self):
+        """
+        Helper method to generate and save plots.
+        Plots model predictions on the selected subset of the dataset.
+        Called by the on_epoch_end and on_train_end methods.
+        """
+
+        original_device = next(self.trainer.model.parameters()).device
+
+        plot_predictions_grid_from_model(
+            model=self.trainer.model,
+            dataset=self._dataset,
+            indices=self._dataset_subset_indices,
+            metrics=self.plot_metrics,
+            save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
+            device=original_device,
+            show=False,
+            **self.plot_kwargs
+        )
\ No newline at end of file
diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py
new file mode 100644
index 0000000..7b7e1a4
--- /dev/null
+++ b/callbacks/MlflowLogger.py
@@ -0,0 +1,115 @@
+import os
+import pathlib
+import tempfile
+from typing import Union, Dict, Optional
+
+import mlflow
+import torch
+
+from .AbstractCallback import AbstractCallback
+
+class MlflowLogger(AbstractCallback):
+    """
+    Callback to log metrics to MLflow.
+    """
+
+    def __init__(self,
+                 name: str,
+                 artifact_name: str = 'best_model_weights.pth',
+                 mlflow_uri: Optional[Union[pathlib.Path, str]] = None,
+                 mlflow_experiment_name: Optional[str] = None,
+                 mlflow_start_run_args: Optional[dict] = None,
+                 mlflow_log_params_args: Optional[dict] = None,
+                 ):
+        """
+        Initialize the MlflowLogger callback.
+
+        :param name: Name of the callback.
+        :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
+        :param mlflow_uri: URI for the MLflow tracking server, defaults to None.
+            If a path is supplied, the logger calls set_tracking_uri with that path,
+            pointing logging at a new tracking location.
+            If None (default), the logger leaves the MLflow configuration untouched, so runs are
+            logged to whatever server is configured globally outside of this class.
+        :type mlflow_uri: pathlib.Path or str, optional
+        :param mlflow_experiment_name: Name of the MLflow experiment, defaults to None, in which case
+            set_experiment is not called and the globally configured experiment name is used. If a
+            name is provided, the logger calls set_experiment with that name.
+        :type mlflow_experiment_name: str, optional
+        :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None.
+        :type mlflow_start_run_args: dict, optional
+        :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to None.
+        :type mlflow_log_params_args: dict, optional
+        """
+        super().__init__(name)
+
+        if mlflow_uri is not None:
+            try:
+                mlflow.set_tracking_uri(mlflow_uri)
+            except Exception as e:
+                raise RuntimeError(f"Error setting MLflow tracking URI: {e}")
+
+        if mlflow_experiment_name is not None:
+            try:
+                mlflow.set_experiment(mlflow_experiment_name)
+            except Exception as e:
+                raise RuntimeError(f"Error setting MLflow experiment: {e}")
+
+        self._artifact_name = artifact_name
+        self._mlflow_start_run_args = mlflow_start_run_args
+        self._mlflow_log_params_args = mlflow_log_params_args
+
+    def on_train_start(self):
+        """
+        Called at the start of training.
+
+        Starts the MLflow run and logs parameters if provided.
+        """
+
+        if self._mlflow_start_run_args is None:
+            pass
+        elif isinstance(self._mlflow_start_run_args, Dict):
+            mlflow.start_run(
+                **self._mlflow_start_run_args
+            )
+        else:
+            raise TypeError("mlflow_start_run_args must be None or a dictionary.")
+
+        if self._mlflow_log_params_args is None:
+            pass
+        elif isinstance(self._mlflow_log_params_args, Dict):
+            mlflow.log_params(
+                self._mlflow_log_params_args
+            )
+        else:
+            raise TypeError("mlflow_log_params_args must be None or a dictionary.")
+
+    def on_epoch_end(self):
+        """
+        Called at the end of each epoch.
+
+        Logs the most recent value of each metric tracked by the trainer.
+        """
+        for key, values in self.trainer.log.items():
+            # Skip metrics with no recorded values yet; mlflow.log_metric does not accept None
+            if values is not None and len(values) > 0:
+                mlflow.log_metric(key, values[-1], step=self.trainer.epoch)
+
+    def on_train_end(self):
+        """
+        Called at the end of training.
+
+        Saves the trainer's best model to a temporary directory, logs it as an artifact,
+        and then ends the run.
+        """
+        # Save weights to a temporary directory and log artifacts
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            weights_path = os.path.join(tmpdirname, self._artifact_name)
+            torch.save(self.trainer.best_model, weights_path)
+            mlflow.log_artifact(weights_path, artifact_path="models")
+
+        mlflow.end_run()
\ No newline at end of file
diff --git a/callbacks/README.md b/callbacks/README.md
new file mode 100644
index 0000000..d776cf7
--- /dev/null
+++ b/callbacks/README.md
@@ -0,0 +1,3 @@
+This directory contains the callback classes that are meant to be fed into trainers to do things such as saving images every epoch and logging metrics.
+
+Callback classes must inherit from the abstract class, `AbstractCallback`.
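+
+A minimal sketch of a custom callback (a hypothetical `LossPrinter`; it assumes the trainer exposes an `epoch` counter and a `log` dict of metric histories, as the callbacks in this folder do):
+
+```python
+from callbacks.AbstractCallback import AbstractCallback
+
+class LossPrinter(AbstractCallback):
+    """Print the latest value of every metric tracked by the trainer."""
+
+    def on_epoch_end(self):
+        for key, values in self.trainer.log.items():
+            if values:  # skip metrics with no recorded values yet
+                print(f"epoch {self.trainer.epoch}: {key} = {values[-1]}")
+```
+
+Callbacks are instantiated with a name (e.g. `LossPrinter("print_loss")`) and passed to the trainer, which calls their hooks during training.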
\ No newline at end of file diff --git a/cp_gan_env.yml b/cp_gan_env.yml new file mode 100644 index 0000000..4e68d25 --- /dev/null +++ b/cp_gan_env.yml @@ -0,0 +1,39 @@ +name: cp_gan_env +channels: + - anaconda + - pytorch + - nvidia + - conda-forge +dependencies: + - conda-forge::python=3.9 + - conda-forge::pip + - pytorch::pytorch + - pytorch::torchvision + - pytorch::torchaudio + - pytorch::pytorch-cuda=12.1 + - conda-forge::seaborn + - conda-forge::matplotlib + - conda-forge::jupyter + - conda-forge::pre_commit + - conda-forge::pandas + - conda-forge::pillow + - conda-forge::numpy + - conda-forge::pathlib2 + - conda-forge::scikit-learn + - conda-forge::opencv + - conda-forge::pyarrow + - conda-forge::ipython + - conda-forge::notebook + - conda-forge::albumentations + - conda-forge::optuna + - conda-forge::mysqlclient + - conda-forge::openjdk + - conda-forge::gtk2 + - conda-forge::typing-extensions + - conda-forge::Jinja2 + - conda-forge::inflect + - conda-forge::wxpython + - conda-forge::sentry-sdk + - pip: + - mlflow + - cellprofiler==4.2.8 diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py new file mode 100644 index 0000000..04bc237 --- /dev/null +++ b/datasets/CachedDataset.py @@ -0,0 +1,220 @@ +from typing import Optional + +from torch.utils.data import Dataset +from collections import OrderedDict + +class CachedDataset(Dataset): + """ + A patched dataset that caches data from dataset objects that + dynamically loads the data to reduce memory overhead during training + """ + + def __init__( + self, + dataset: Dataset, + cache_size: Optional[int]=None, + prefill_cache: bool=False, + **kwargs + ): + """ + Initialize the CachedDataset from a dataset object + + :param dataset: Dataset object to cache data from + :type dataset: Dataset + :param cache_size: Size of the cache, if None, the cache + size is set to the length of the dataset. + :type cache_size: int + :param prefill_cache: Whether to prefill the cache + :type prefill_cache: bool + """ + + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + self.__dataset = dataset + + self.__cache_size = cache_size if cache_size is not None else len(dataset) + self.__cache = OrderedDict() + + # cache for metadata + self.__cache_input_names = OrderedDict() + self.__cache_target_names = OrderedDict() + + # pointer to the current patch index + self._current_idx = None + + if prefill_cache: + self.populate_cache() + + """Overriden methods for Dataset class""" + def __len__(self): + """ + Return the length of the dataset + """ + return len(self.__dataset) + + def __getitem__(self, _idx: int): + """ + Get the data from the dataset object at the given index + If the data is not in the cache, load it from the dataset object and update the cache + + :param _idx: Index of the data to get + :type _idx: int + """ + self._current_idx = _idx + + if _idx in self.__cache: + # cache hit + return self.__cache[_idx] + else: + # cache miss, load from parent class method dynamically + self._push_cache(_idx) + return self.__cache[_idx] + + """Setters""" + + def set_cache_size(self, cache_size: int): + """ + Set the cache size. 
Does not automatically repopulate the cache but + will pop the cache if the size is exceeded + + :param cache_size: Size of the cache + :type cache_size: int + """ + self.__cache_size = cache_size + # pop the cache if the size is exceeded + while len(self.__cache) > self.__cache_size: + self._pop_cache() + + """Properties to remain accessible""" + @property + def input_names(self): + """ + Get the input names from the dataset object + """ + if self._current_idx is not None: + if self._current_idx in self.__cache_input_names: + return self.__cache_input_names[self._current_idx] + else: + _ = self.__dataset[self._current_idx] + return self.__dataset.input_names + else: + raise ValueError("No current index set") + + @property + def target_names(self): + """ + Get the target names from the dataset object + """ + if self._current_idx is not None: + ## TODO: need to think over if this is at all necessary + if self._current_idx in self.__cache_target_names: + return self.__cache_target_names[self._current_idx] + else: + _ = self.__dataset[self._current_idx] + return self.__dataset.target_names + else: + raise ValueError("No current index set") + + @property + def input_channel_keys(self): + """ + Get the input channel keys from the dataset object + """ + try: + return self.__dataset.input_channel_keys + except AttributeError: + return None + + @property + def target_channel_keys(self): + """ + Get the target channel keys from the dataset object + """ + try: + return self.__dataset.target_channel_keys + except AttributeError: + return None + + @property + def input_transform(self): + """ + Get the input transform from the dataset object + """ + return self.__dataset.input_transform + + @property + def target_transform(self): + """ + Get the target transform from the dataset object + """ + return self.__dataset.target_transform + + @property + def dataset(self): + """ + Get the dataset object + """ + return self.__dataset + + """Cache method""" + def populate_cache(self): + """ + Populates/clears the current cache and re-populate the cache with data from the dataset object + Iteratively calls the _push_cache method on a sequence of indices + """ + self._clear_cache() + for _idx in range(min(self.__cache_size, len(self.__dataset))): + self._push_cache(_idx) + + """Internal helper methods""" + + def _push_cache(self, _idx: int): + """ + Update the cache with a single item retrieved from the dataset object. 
+ Calls the update cache metadata method as well to sync data and metadata + Pops the cache if the cache size is exceeded on a first in, first out basis + + :param _idx: Index of the data to cache + :type _idx: int + """ + self._current_idx = _idx + self.__cache[_idx] = self.__dataset[_idx] + if len(self.__cache) >= self.__cache_size: + self._pop_cache() + self._push_cache_metadata(_idx) + + def _pop_cache(self): + """ + Helper method to pop the cache on a first in, first out basis + """ + self.__cache.popitem(last=False) + + def _push_cache_metadata(self, _idx: int): + """ + Update the cache metadata with data from the dataset object + Meant to be called by _push_cache method + + :param _idx: Index of the data to cache + :type _idx: int + """ + self.__cache_input_names[_idx] = self.__dataset.input_names + self.__cache_target_names[_idx] = self.__dataset.target_names + + if len(self.__cache_input_names) >= self.__cache_size: + self._pop_cache_metadata() + + def _pop_cache_metadata(self): + """ + Helper method to pop the cache metadata on a first in, first out basis + """ + self.__cache_input_names.popitem(last=False) + self.__cache_target_names.popitem(last=False) + + def _clear_cache(self): + """ + Clear the cache and cache metadata + """ + self.__cache.clear() + self.__cache_input_names.clear() + self.__cache_target_names.clear() \ No newline at end of file diff --git a/datasets/GenericImageDataset.py b/datasets/GenericImageDataset.py new file mode 100644 index 0000000..71713f7 --- /dev/null +++ b/datasets/GenericImageDataset.py @@ -0,0 +1,299 @@ +import logging +import pathlib +import re +from collections import defaultdict +from typing import List, Optional, Union, Tuple, Dict + +import numpy as np +import torch +from PIL import Image +from albumentations import ImageOnlyTransform +from albumentations.core.composition import Compose +from torch.utils.data import Dataset + + +class GenericImageDataset(Dataset): + """ + A generic image dataset that automatically associates images under a supplied path + with sites and channels based on two separate regex patterns for site and channel detection. + """ + + def __init__( + self, + image_dir: Union[str, pathlib.Path], + site_pattern: str, + channel_pattern: str, + _input_channel_keys: Optional[Union[str, List[str]]] = None, + _target_channel_keys: Optional[Union[str, List[str]]] = None, + _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _PIL_image_mode: str = 'I;16', + verbose: bool = False, + check_exists: bool = True, + **kwargs + ): + """ + Initialize the dataset. + + :param image_dir: Directory containing the images. + :param site_pattern: Regex pattern to extract site identifiers. + :param channel_pattern: Regex pattern to extract channel identifiers. + :param _input_channel_keys: List of channel names to use as inputs. + :param _target_channel_keys: List of channel names to use as targets. + :param _input_transform: Transformations to apply to input images. + :param _target_transform: Transformations to apply to target images. + :param _PIL_image_mode: Mode for loading images. + :param check_exists: Whether to check if all referenced image files exist. 
+ """ + + self._initialize_logger(verbose) + self.image_dir = pathlib.Path(image_dir).resolve() + self.site_pattern = re.compile(site_pattern) + self.channel_pattern = re.compile(channel_pattern) + self._PIL_image_mode = _PIL_image_mode + + if not self.image_dir.exists(): + raise FileNotFoundError(f"Image directory {self.image_dir} not found") + + # Parse images and organize by site + self._channel_keys = [] + self.__image_paths = self._get_image_paths(check_exists) + + # Set input and target channel keys + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + + self.set_input_transform(_input_transform) + self.set_target_transform(_target_transform) + + # Index patches and images + self.__iter_image_id = list(range(len(self.__image_paths))) + + # Initialize cache + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Initialize the current input and target names + self.__current_input_names = None + self.__current_target_names = None + + """ + Properties + """ + + @property + def image_paths(self): + return self.__image_paths + + @property + def input_transform(self): + return self._input_transform + + @property + def target_transform(self): + return self._target_transform + + @property + def input_channel_keys(self): + return self._input_channel_keys + + @property + def target_channel_keys(self): + return self._target_channel_keys + @property + def input_names(self): + return self.__current_input_names + + @property + def target_names(self): + return self.__current_target_names + + """ + Setters + """ + + def set_input_transform(self, _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None): + """Sets the input image transform.""" + self.logger.debug("Setting input transform ...") + self._input_transform = _input_transform + + def set_target_transform(self, _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None): + """Sets the target image transform.""" + self.logger.debug("Setting target transform ...") + self._target_transform = _target_transform + + def set_input_channel_keys(self, _input_channel_keys: Union[str, List[str]]): + """ + Set the input channel keys + + :param _input_channel_keys: The input channel keys + :type _input_channel_keys: str or list of str + """ + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + def set_target_channel_keys(self, _target_channel_keys: Union[str, List[str]]): + """ + Set the target channel keys + + :param _target_channel_keys: The target channel keys + :type _target_channel_keys: str or list of str + """ + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + """ + Logging and Debugging + """ + + def _initialize_logger(self, verbose: bool): + """Initializes the logger.""" + self.logger = logging.getLogger(f"{__name__}.{id(self)}") + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING) + + """ + Internal helper functions + """ + + def _get_image_paths(self, check_exists: bool): + + # sets 
for all unique sites and channels + sites = set() + channels = set() + image_files = list(self.image_dir.glob("*")) + + site_to_channels = defaultdict(dict) + for file in image_files: + site_match = self.site_pattern.search(file.name) + try: + site = site_match.group(1) + except: + continue + sites.add(site) + + channel_match = self.channel_pattern.search(file.name) + try: + channel = channel_match.group(1) + except: + continue + channels.add(channel) + + site_to_channels[site][channel] = file + + # format as list of dicts + image_paths = [] + for site, channel_to_file in site_to_channels.items(): + ## Keep only sites with all channels + if all([c in site_to_channels[site] for c in channels]): + if check_exists and not all(path.exists() for path in channel_to_file.values()): + continue + image_paths.append(channel_to_file) + + self.logger.debug(f"Channel keys: {channels} detected") + self._channel_keys = list(channels) + + return image_paths + + def __len__(self): + return len(self.__image_paths) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Retrieves the input and target images for a given index. + + :param idx: The index of the image. + :return: Tuple of input and target images as tensors. + """ + if idx >= len(self) or idx < 0: + raise IndexError("Index out of bounds") + + site_id = self.__iter_image_id[idx] + self._cache_image(site_id) + + # Stack input and target images + input_images = np.stack([self.__input_cache[key] for key in self._input_channel_keys], axis=0) + target_images = np.stack([self.__target_cache[key] for key in self._target_channel_keys], axis=0) + + # Apply transformations + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + def _cache_image(self, site_id: str) -> None: + """ + Loads and caches images for a given site ID. + + :param site_id: The site ID. + """ + if self.__cache_image_id != site_id: + self.__cache_image_id = site_id + self.__input_cache = {} + self.__target_cache = {} + + ## Update target and input names (which are just file path(s)) + self.__current_input_names = [self.__image_paths[site_id][key] for key in self._input_channel_keys] + self.__current_target_names = [self.__image_paths[site_id][key] for key in self._target_channel_keys] + + for key in self._input_channel_keys: + self.__input_cache[key] = self._read_convert_image(self.__image_paths[site_id][key]) + for key in self._target_channel_keys: + self.__target_cache[key] = self._read_convert_image(self.__image_paths[site_id][key]) + + def _read_convert_image(self, image_path: pathlib.Path) -> np.ndarray: + """ + Reads and converts an image to a numpy array. + + :param image_path: The image file path. + :return: The image as a numpy array. 
+ """ + return np.array(Image.open(image_path).convert(self._PIL_image_mode)) + + def __check_channel_keys( + self, + channel_keys: Optional[Union[str, List[str]]] + ) -> List[str]: + """ + Checks user supplied channel key against the inferred ones from the file + + :param channel_keys: user supplied list or single object of string channel keys + :type channel_keys: string or list of strings + """ + if channel_keys is None: + self.logger.debug("No channel keys specified, skip") + return None + elif isinstance(channel_keys, str): + channel_keys = [channel_keys] + elif isinstance(channel_keys, list): + if not all([isinstance(key, str) for key in channel_keys]): + raise ValueError('Channel keys must be a string or a list of strings.') + else: + raise ValueError('Channel keys must be a string or a list of strings.') + + ## Check supplied channel keys against inferred ones + filtered_channel_keys = [] + for key in channel_keys: + if not key in self._channel_keys: + self.logger.debug( + f"ignoring channel key {key} as it does not match loaddata csv file" + ) + else: + filtered_channel_keys.append(key) + + if len(filtered_channel_keys) == 0: + raise ValueError(f'None of the supplied channel keys match the loaddata csv file') + + return filtered_channel_keys \ No newline at end of file diff --git a/datasets/ImageDataset.py b/datasets/ImageDataset.py new file mode 100644 index 0000000..bf6033f --- /dev/null +++ b/datasets/ImageDataset.py @@ -0,0 +1,459 @@ +import logging +import pathlib +from random import randint +from typing import List, Optional, Union, Tuple + +import numpy as np +import pandas as pd +from PIL import Image +from albumentations import ImageOnlyTransform +from albumentations.core.composition import Compose +import torch +from torch.utils.data import Dataset + +class ImageDataset(Dataset): + """ + Image Dataset Class from pe2loaddata generated cellprofiler loaddata csv + """ + def __init__( + self, + _loaddata_csv, + _input_channel_keys: Optional[Union[str, List[str]]] = None, + _target_channel_keys: Optional[Union[str, List[str]]] = None, + _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _PIL_image_mode: str = 'I;16', + verbose: bool = False, + file_column_prefix: str = 'FileName_', + path_column_prefix: str = 'PathName_', + check_exists: bool = False, + **kwargs + ): + """ + Initialize the ImageDataset. + + :param _loaddata_csv: The dataframe or path to a csv file containing the image paths and labels. + :type _loaddata_csv: Union[pd.DataFrame, str] + :param _input_channel_keys: Keys for input channels. Can be a single key or a list of keys. + :type _input_channel_keys: Optional[Union[str, List[str]]] + :param _target_channel_keys: Keys for target channels. Can be a single key or a list of keys. + :type _target_channel_keys: Optional[Union[str, List[str]]] + :param _input_transform: Transformations to apply to the input images. + :type _input_transform: Optional[Union[Compose, ImageOnlyTransform]] + :param _target_transform: Transformations to apply to the target images. + :type _target_transform: Optional[Union[Compose, ImageOnlyTransform]] + :param _PIL_image_mode: Mode to use when loading images with PIL. Default is 'I;16'. + :type _PIL_image_mode: str + :param kwargs: Additional keyword arguments. 
+ """ + + self._initialize_logger(verbose) + self._loaddata_df = self._load_loaddata(_loaddata_csv, **kwargs) + self._channel_keys = list(self.__infer_channel_keys(file_column_prefix, path_column_prefix)) + + # Initialize the cache for the input and target images + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Set input/target channels + self.logger.debug("Setting input channel(s) ...") + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug("Setting target channel(s) ...") + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + + self.set_input_transform(_input_transform) + self.set_target_transform(_target_transform) + + self._PIL_image_mode = _PIL_image_mode + + # Obtain image paths + self.__image_paths = self._get_image_paths( + file_column_prefix=file_column_prefix, + path_column_prefix=path_column_prefix, + check_exists=check_exists, + **kwargs + ) + # Index patches and images + self.__iter_image_id = list(range(len(self.__image_paths))) + + # Initialize the current input and target names + self.__current_input_names = None + self.__current_target_names = None + + """ + Overridden Iterator functions + """ + def __len__(self): + return len(self.__image_paths) + + def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]: + """ + Return the input and target images + :param _idx: The index of the image + :type _idx: int + :return: The input and target images, each with dimension [n_channels, height, width] + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + + if _idx >= len(self) or _idx < 0: + raise IndexError("Index out of bounds") + + if self._input_channel_keys is None or self._target_channel_keys is None: + raise ValueError("Input and target channel keys must be set to access data") + + image_id = self.__iter_image_id[_idx] + self._cache_image(image_id) + + ## Retrieve relevant channels as specified by input and target channel keys and stack + input_images = np.stack( + [self.__input_cache[key] for key in self._input_channel_keys], + axis=0) + target_images = np.stack( + [self.__target_cache[key] for key in self._target_channel_keys], + axis=0) + + ## Apply transform + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + ## Cast to torch tensor and return + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + """ + Properties + """ + + @property + def image_paths(self): + return self.__image_paths + + @property + def input_transform(self): + return self._input_transform + + @property + def target_transform(self): + return self._target_transform + + @property + def input_channel_keys(self): + return self._input_channel_keys + + @property + def target_channel_keys(self): + return self._target_channel_keys + @property + def input_names(self): + return self.__current_input_names + + @property + def target_names(self): + return self.__current_target_names + + """ + Setters + """ + + def set_input_transform(self, _input_transform: Optional[Union[Compose, ImageOnlyTransform]]=None): + """ + Set the input transform + + :param _input_transform: The input transform + :type _input_transform: Compose or ImageOnlyTransform + """ + # Check and set input/target transforms + self.logger.debug("Setting input transform ...") + if self.__check_transforms(_input_transform): + self._input_transform = 
_input_transform + + + def set_target_transform(self, _target_transform: Optional[Union[Compose, ImageOnlyTransform]]=None): + """ + Set the target transform + + :param _target_transform: The target transform + :type _target_transform: Compose or ImageOnlyTransform + """ + # Check and set input/target transforms + self.logger.debug("Setting target transform ...") + if self.__check_transforms(_target_transform): + self._target_transform = _target_transform + + def set_input_channel_keys(self, _input_channel_keys: Union[str, List[str]]): + """ + Set the input channel keys + + :param _input_channel_keys: The input channel keys + :type _input_channel_keys: str or list of str + """ + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + def set_target_channel_keys(self, _target_channel_keys: Union[str, List[str]]): + """ + Set the target channel keys + + :param _target_channel_keys: The target channel keys + :type _target_channel_keys: str or list of str + """ + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + """ + Internal Helper functions + """ + def _initialize_logger(self, verbose: bool): + """ + Initialize logger instance + """ + self.logger = logging.getLogger(f"{__name__}.{id(self)}") + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING) + + def _load_loaddata( + self, + _loaddata_csv: Union[pd.DataFrame, pathlib.Path], + **kwargs + ) -> pd.DataFrame: + """ + Read loaddata csv file, also does type checking + + :param _loaddata_csv: The path to the loaddata CSV file or a DataFrame. + :type _loaddata_csv: Union[pd.DataFrame, pathlib.Path] + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :raises ValueError: If no loaddata CSV is supplied or if the file type is not supported. + :raises FileNotFoundError: If the specified file does not exist. + :return: The loaded data as a DataFrame. + :rtype: pd.DataFrame + """ + + if _loaddata_csv is None: + raise ValueError("No loaddata csv supplied") + elif isinstance(_loaddata_csv, pd.DataFrame): + self.logger.debug("Dataframe supplied for loaddata_csv, using as is") + return _loaddata_csv + else: + self.logger.debug("Loading loaddata csv from file") + ## Convert string to pathlib Path + if not isinstance(_loaddata_csv, pathlib.Path): + try: + _loaddata_csv = pathlib.Path(_loaddata_csv) + except e: + raise e + + ## Handle file not exist + if not _loaddata_csv.exists(): + raise FileNotFoundError(f"File {_loaddata_csv} not found") + + ## Determine file extension and load accordingly + if _loaddata_csv.suffix == '.csv': + return pd.read_csv(_loaddata_csv) + elif _loaddata_csv.suffix == '.parquet': + return pd.read_parquet(_loaddata_csv) + else: + raise ValueError(f"File type {_loaddata_csv.suffix} not supported") + + def __infer_channel_keys( + self, + file_column_prefix: str, + path_column_prefix: str + ) -> set[str]: + """ + Infer channel names from the columns of loaddata csv. 
+ This method identifies and returns the set of channel keys by comparing + the columns in the dataframe that start with the specified file and path + column prefixes. The channel keys are the suffixes of these columns after + removing the prefixes. + + :param file_column_prefix: The prefix for columns that indicate filenames. + :type file_column_prefix: str + :param path_column_prefix: The prefix for columns that indicate paths. + :type path_column_prefix: str + :return: A set of channel keys inferred from the loaddata csv. + :rtype: set[str] + :raises ValueError: If no path or file columns are found, or if no matching + channel keys are found between file and path columns. + """ + + self.logger.debug("Inferring channel keys from loaddata csv") + # Retrieve columns that indicate path and filename to image files + file_columns = [col for col in self._loaddata_df.columns if col.startswith(file_column_prefix)] + path_columns = [col for col in self._loaddata_df.columns if col.startswith(path_column_prefix)] + + if len(file_columns) == 0 or len(path_columns) == 0: + raise ValueError('No path or file columns found in loaddata csv.') + + # Anything following the prefix should be the channel names + file_channel_keys = [col.replace(file_column_prefix, '') for col in file_columns] + path_channel_keys = [col.replace(path_column_prefix, '') for col in path_columns] + channel_keys = set(file_channel_keys).intersection(set(path_channel_keys)) + + if len(channel_keys) == 0: + raise ValueError('No matching channel keys found between file and path columns.') + + self.logger.debug(f"Channel keys: {channel_keys} inferred from loaddata csv") + + return channel_keys + + def __check_channel_keys( + self, + channel_keys: Optional[Union[str, List[str]]] + ) -> List[str]: + """ + Checks user supplied channel key against the inferred ones from the file + + :param channel_keys: user supplied list or single object of string channel keys + :type channel_keys: string or list of strings + """ + if channel_keys is None: + self.logger.debug("No channel keys specified, skip") + return None + elif isinstance(channel_keys, str): + channel_keys = [channel_keys] + elif isinstance(channel_keys, list): + if not all([isinstance(key, str) for key in channel_keys]): + raise ValueError('Channel keys must be a string or a list of strings.') + else: + raise ValueError('Channel keys must be a string or a list of strings.') + + ## Check supplied channel keys against inferred ones + filtered_channel_keys = [] + for key in channel_keys: + if not key in self._channel_keys: + self.logger.debug( + f"ignoring channel key {key} as it does not match loaddata csv file" + ) + else: + filtered_channel_keys.append(key) + + if len(filtered_channel_keys) == 0: + raise ValueError(f'None of the supplied channel keys match the loaddata csv file') + + return filtered_channel_keys + + def __check_transforms( + self, + transforms: Optional[Union[Compose, ImageOnlyTransform]] + ) -> bool: + """ + Checks if supplied iamge only transform is of valid type, if so, return True + + :param transforms: Transform + :type transforms: ImageOnlyTransform or Compose of ImageOnlyTransforms + :return: Boolean indicator of success + :rtype: bool + """ + if transforms is None: + pass + elif isinstance(transforms, Compose): + pass + elif isinstance(transforms, ImageOnlyTransform): + pass + else: + raise TypeError('Invalid image transform type') + + return True + + def _get_image_paths(self, + file_column_prefix: str, + path_column_prefix: str, + check_exists: bool = False, 
+ **kwargs, + ) -> List[dict]: + """ + From loaddata csv, extract the paths to all image channels cooresponding to each view/site + + :param check_exists: check if every individual image file exist and remove those that do not + :type check_exists: bool + :return: A list of dictionaries containing the paths to the image channels + :rtype: List[dict] + """ + + # Define helper function to get the image file paths from all channels + # in a single row of loaddata csv (single view/site), organized into a dict + def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]: + + missing = False + + multi_channel_paths = {} + for channel_key in self._channel_keys: + file_column = f"{file_column_prefix}{channel_key}" + path_column = f"{path_column_prefix}{channel_key}" + + if file_column in row and path_column in row: + file = pathlib.Path( + row[path_column] + ) / row[file_column] + if (not check_exists) or file.exists(): + multi_channel_paths[channel_key] = file + else: + missing = True + + return multi_channel_paths, missing + + image_paths = [] + self.logger.debug( + "Extracting image channel paths of site/view and associated"\ + "cell coordinates (if applicable) from loaddata csv") + + for _, row in self._loaddata_df.iterrows(): + multi_channel_paths, missing = get_channel_paths(row) + if not missing: + image_paths.append(multi_channel_paths) + + self.logger.debug(f"Extracted images of all input and target channels for {len(image_paths)} unique sites/view") + return image_paths + + def _read_convert_image(self, _image_path: pathlib.Path)->np.ndarray: + """ + Read and convert the image to a numpy array + + :param _image_path: The path to the image + :type _image_path: pathlib.Path + :return: The image as a numpy array + :rtype: np.ndarray + """ + return np.array(Image.open(_image_path).convert(self._PIL_image_mode)) + + def _cache_image(self, _id: int)->None: + """ + Determines if cached images need to be updated and updates the self.__input_cache and self.__target_cache + Meant to be called by __getitem__ method in dynamic patch cropping + + :param _id: The index of the image + :type _id: int + :return: None + :rtype: None + """ + + if self.__cache_image_id is None or self.__cache_image_id != _id: + self.__cache_image_id = _id + self.__input_cache = {} + self.__target_cache = {} + + ## Update target and input names (which are just file path(s)) + self.__current_input_names = [self.__image_paths[_id][key] for key in self._input_channel_keys] + self.__current_target_names = [self.__image_paths[_id][key] for key in self._target_channel_keys] + + for key in self._input_channel_keys: + self.__input_cache[key] = self._read_convert_image(self.__image_paths[_id][key]) + for key in self._target_channel_keys: + self.__target_cache[key] = self._read_convert_image(self.__image_paths[_id][key]) + else: + # No need to update the cache + pass + + return None \ No newline at end of file diff --git a/datasets/PatchDataset.py b/datasets/PatchDataset.py new file mode 100644 index 0000000..08cecae --- /dev/null +++ b/datasets/PatchDataset.py @@ -0,0 +1,677 @@ +import math +import pathlib +import random +from random import randint +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from pyarrow import parquet as pq +import torch + +from .ImageDataset import ImageDataset + +class PatchDataset(ImageDataset): + """ + Patch Dataset Class from pe2loaddata generated cellprofiler loaddata csv and sc features + """ + def __init__( + self, + _sc_feature: Optional[Union[pd.DataFrame, 
pathlib.Path]] = None, + patch_size: int = 64, + patch_generation_method: str = 'random', + patch_generation_random_seed: Optional[int] = None, + patch_generation_max_attempts: int = 1_000, + n_expected_patches_per_img: int = 5, + candidate_x: str = 'Metadata_Cells_Location_Center_X', + candidate_y: str = 'Metadata_Cells_Location_Center_Y', + **kwargs + ): + """ + Initialize the PatchDataset. + This method initializes the PatchDataset by setting up patch size, coordinates, and other parameters. + It also generates patches and initializes caches for input and target images. + + :param _sc_feature: Single-cell feature data or path to the data, by default None. + :type _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]], optional + :param patch_size: Size of the patches to generate, by default 64. + :type patch_size: int, optional + :param patch_generation_method: Method to generate patches ('random' or other methods), by default 'random'. + :type patch_generation_method: str, optional + :param patch_generation_random_seed: Random seed for patch generation, by default None. + :type patch_generation_random_seed: Optional[int], optional + :param patch_generation_max_attempts: Maximum number of attempts to generate patches, by default 1,000. + :type patch_generation_max_attempts: int, optional + :param n_expected_patches_per_img: Number of expected patches per image, by default 5. + :type n_expected_patches_per_img: int, optional + :param candidate_x: Column name for x-coordinates of candidate cells, by default 'Metadata_Cells_Location_Center_X'. + :type candidate_x: str, optional + :param candidate_y: Column name for y-coordinates of candidate cells, by default 'Metadata_Cells_Location_Center_Y'. + :type candidate_y: str, optional + :param kwargs: Additional keyword arguments. 
Namely those required by ImageDataset + :type kwargs: dict + """ + + self._patch_size = patch_size + self._merge_fields = None + self._x_col = None + self._y_col = None + self.__cell_coords = [] + + # This intializes the channels keys, loaddata loading, image mode and + # the overriden methods further merge the loaddata with sc features + super().__init__(_sc_feature=_sc_feature, + candidate_x=candidate_x, + candidate_y=candidate_y, + **kwargs) + + ## Generates patches with the specified arguments + self.__patch_coords = self._generate_patches( + _patch_size=self._patch_size, + patch_generation_method=patch_generation_method, + patch_generation_random_seed=patch_generation_random_seed, + n_expected_patches_per_img=n_expected_patches_per_img, + max_attempts=patch_generation_max_attempts, + consistent_img_size=kwargs.get('consistent_img_size', True), + ) + + # Index patches and images + self.__iter_image_id = [] + self.__iter_patch_id = [] + for i, _patch_coords in enumerate(self.__patch_coords): + for j, _ in enumerate(_patch_coords): + self.__iter_image_id.append(i) + self.__iter_patch_id.append(j) + + # Initialize the cache for the input and target images + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Initialize the current input and target names and patch coordinates + self.__current_input_names = None + self.__current_target_names = None + self.__current_patch_coords = None + + """ + Overridden Iterator functions + """ + def __len__(self): + return len(self.__iter_patch_id) + + def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]: + """ + Return the input and target images + + :param _idx: The index of the image + :type _idx: int + :return: The input and target images, each with dimension [n_channels, height, width] + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + + if _idx >= len(self) or _idx < 0: + raise IndexError("Index out of bounds") + + if self._input_channel_keys is None or self._target_channel_keys is None: + raise ValueError("Input and target channel keys must be set to access data") + + image_id = self.__iter_image_id[_idx] + patch_id = self.__iter_patch_id[_idx] + self.__current_patch_coords = self.__patch_coords[image_id][patch_id] + + self._cache_image(image_id) + + ## Retrieve relevant channels as specified by input and target channel keys and stack + ## And further crop the patches with __current_patch_coords + input_images = np.stack( + [self._ImageDataset__input_cache[key][ + self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size, + self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size + ] for key in self._input_channel_keys], + axis=0) + target_images = np.stack( + [self._ImageDataset__target_cache[key][ + self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size, + self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size + ] for key in self._target_channel_keys], + axis=0) + + ## Apply transform + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + ## Cast to torch tensor and return + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + """ + Properties + """ + + @property + def patch_size(self): + return self._patch_size + + @property + def cell_coords(self): + return self.__cell_coords + + @property + def 
all_patch_coords(self): + return self.__patch_coords + + @property + def patch_coords(self): + return self.__current_patch_coords + + @property + def raw_input(self): + """ + Returns a tuple of input, target raw images where the current patch is cropped + from. Relies on the parent class _ImageDataset__input_cache and _ImageDataset__target_cache. + Returns None when the cache is empty. Raises an error if the input channel keys are not set. + + :return: Tuple of input, target raw images + :rtype: Tuple[np.ndarray, np + """ + + if self._input_channel_keys is None: + raise ValueError("Input channel keys not set") + + return np.stack( + [self._ImageDataset__input_cache[key] for key in self._input_channel_keys], + axis=0) if self._input_channel_keys is not None else None + + @property + def raw_target(self): + """ + Returns a tuple of input, target raw images where the current patch is cropped + from. Relies on the parent class _ImageDataset__input_cache and _ImageDataset__target_cache. + Returns None when the cache is empty. Raises an error if the target channel keys are not set. + + :return: Tuple of input, target raw images + :rtype: Tuple[np.ndarray, np.ndarray] + """ + + if self._target_channel_keys is None: + raise ValueError("Target channel keys not set") + + return np.stack( + [self._ImageDataset__target_cache[key] for key in self._target_channel_keys], + axis=0) if self._target_channel_keys is not None else None + + """ + Internal Helper functions + """ + + def __preload_sc_feature(self, + _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]]=None) -> List[str]: + """ + Preload the sc feature dataframe/parquet file limiting only to the column headers + If a dataframe is supplied, use as is and return the column names + If a path to a csv file is supplied, load only the header row + If a path to a parquet file is supplied, load only the parquet schema name + + :param _sc_feature: The path to a csv file containing the cell profiler sc features + :type _sc_feature: str or pathlib.Path + :return: List of column names of dataframe/csv/parquet file + :rtype: List of strings + """ + + if _sc_feature is None: + # No sc feature supplied, cell coordinates not available, patch generation will fixed random + self.logger.debug("No sc feature supplied, patch generation will be random") + self._patch_generation_method = 'random' + return None + + elif isinstance(_sc_feature, pd.DataFrame): + self.logger.debug("Dataframe supplied for sc_feature, using as is") + return _sc_feature.columns.tolist() + + else: + self.logger.debug("Preloading sc feature from file") + if not isinstance(_sc_feature, pathlib.Path): + try: + _sc_feature = pathlib.Path(_sc_feature) + except e: + raise e + + if not _sc_feature.exists(): + raise FileNotFoundError(f"File {_sc_feature} not found") + + if _sc_feature.suffix == '.csv': + self.logger.debug("Preloading sc feature from csv file") + return pd.read_csv(_sc_feature, nrows=0).columns.tolist() + elif _sc_feature.suffix == '.parquet': + pq_file = pq.ParquetFile(_sc_feature) + return pq_file.schema.names + else: + raise ValueError(f"File type {_sc_feature.suffix} not supported") + + def __infer_merge_fields(self, + _loaddata_df, + _sc_col_names: List[str] + ) -> Union[List[str], None]: + """ + Find the columns that are common to both dataframes to use in an inner join + Mean to be used to associate loaddata_csv with sc features + + :param loaddata_csv: The first dataframe + :type loaddata_csv: pd.DataFrame + :param sc_feature: The second dataframe + :type sc_feature: 
pd.DataFrame + :return: The columns that are common to both dataframes + :rtype: List[str] + """ + if _sc_col_names is None: + return None + + self.logger.debug("Both loaddata_csv and sc_feature supplied, " \ + "inferring merge fields to associate the two dataframes") + merge_fields = list(set(_loaddata_df.columns).intersection(set(_sc_col_names))) + if len(merge_fields) == 0: + raise ValueError("No common columns found between loaddata_csv and sc_feature") + self.logger.debug(f"Merge fields inferred: {merge_fields}") + + return merge_fields + + def __infer_x_y_columns(self, + _loaddata_df, + _sc_col_names: List[str], + candidate_x: str, + candidate_y: str) -> Tuple[str, str]: + """ + Infer the columns that contain the x and y coordinates of the patches. + Will look for user specified patterns first but when no patterns + match this function returns the first columns that ends with _x and _y + + :param candidate_x: The candidate column name for the x coordinates + :type candidate_x: str + :param candidate_y: The candidate column name for the y coordinates + :type candidate_y: str + :return: The columns that contain the x and y coordinates of the patches + :rtype: Tuple[str, str] + """ + + if _loaddata_df is None: + return None, None + + if candidate_x not in _sc_col_names or candidate_y not in _sc_col_names: + self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} not detected in sc_features, attempting to infer from sc_feature dataframe") + + # infer the columns that contain the x and y coordinates + x_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_x')] + y_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_y')] + + if len(x_col_candidates) == 0 or len(y_col_candidates) == 0: + raise ValueError("No columns found containing the x and y coordinates") + else: + # sort x_col and y_col candidates + x_col_candidates.sort() + y_col_candidates.sort() + x_col_detected = x_col_candidates[0] + y_col_detected = y_col_candidates[0] + self.logger.debug(f"X and Y columns {x_col_detected}, {y_col_detected} detected in sc_feature dataframe, using as the coordinates for cell centers") + return x_col_detected, y_col_detected + else: + self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} detected in sc_feature dataframe, using as the coordinates for cell centers") + return candidate_x, candidate_y + + def __load_sc_feature(self, + _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]], + _merge_fields: List[str], + _x_col: str, + _y_col: str + ) -> Union[pd.DataFrame, None]: + """ + Load the actual sc feature as a dataframe, limiting the columns + to the merge fields and the x and y coordinates + + :param _sc_feature: The path to a csv file containing the cell profiler sc features + :type _sc_feature: str or pathlib.Path + :return: The dataframe containing the cell profiler sc features + :rtype: pd.DataFrame + """ + + if _sc_feature is None: + return None + elif isinstance(_sc_feature, pd.DataFrame): + self.logger.debug("Dataframe supplied for sc_feature, using as is") + return _sc_feature + else: + self.logger.debug("Loading sc feature from file") + if not isinstance(_sc_feature, pathlib.Path): + try: + _sc_feature = pathlib.Path(_sc_feature) + except e: + raise e + + if not _sc_feature.exists(): + raise FileNotFoundError(f"File {_sc_feature} not found") + + if _sc_feature.suffix == '.csv': + return pd.read_csv(_sc_feature, + usecols=_merge_fields + [_x_col, _y_col]) + elif _sc_feature.suffix == '.parquet': + return 
pq.read_table(_sc_feature, columns=_merge_fields + [_x_col, _y_col]).to_pandas() + else: + raise ValueError(f"File type {_sc_feature.suffix} not supported") + + """ + Overriden parent class helper functions + """ + + def _load_loaddata(self, + _loaddata_csv: Union[pd.DataFrame, pathlib.Path], + _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]], + candidate_x: str, + candidate_y: str, + ): + """ + Overridden function from parent class + Calls the parent class to get the loaddata df and then merges it with sc_feature + + :param _loaddata_csv: The path to the loaddata CSV file or a DataFrame. + :type _loaddata_csv: Union[pd.DataFrame, pathlib.Path] + :param _sc_feature: The path to the single cell feature parquet file, csv file or Datafarme + :type _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]] + :param candidate_x: User specified column to access cell x coords + :type candidate_x: str + :param candidate_y: User specified column to access cell y coords + :type candidate_y: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :raises ValueError: If no loaddata CSV is supplied or if the file type is not supported. + :raises FileNotFoundError: If the specified file does not exist. + :return: The loaded data as a DataFrame. + :rtype: pd.DataFrame + """ + + ## First calls the parent class to get the full loaddata df + loaddata_df = super()._load_loaddata(_loaddata_csv) + + ## Obtain column names of sc features first to avoid needing to read in the whole + ## Parquet file as only a very small number of columns are needed + sc_feature_col_names = self.__preload_sc_feature(_sc_feature) + + ## Infer columns corresponding to x and y coordinates to cells + self._x_col, self._y_col = self.__infer_x_y_columns( + loaddata_df, sc_feature_col_names, candidate_x, candidate_y) + + ## Infer merge fields between the sc features and loaddata + self._merge_fields = self.__infer_merge_fields(loaddata_df, sc_feature_col_names) + + ## Load sc features + sc_feature_df = self.__load_sc_feature( + _sc_feature, self._merge_fields, self._x_col, self._y_col) + + ## Perform the merge and return the merged dataframe (which is loaddata plus columns for x and y coordinates) + return loaddata_df.merge(sc_feature_df, on=self._merge_fields, how='inner') + + def _get_image_paths(self, + file_column_prefix: str, + path_column_prefix: str, + check_exists: bool = False, + **kwargs, + ) -> List[dict]: + """ + Overridden function + From loaddata csv, extract the paths to all image channels cooresponding to each view/site + + :param check_exists: check if every individual image file exist and remove those that do not + :type check_exists: bool + :return: A list of dictionaries containing the paths to the image channels + :rtype: List[dict] + """ + + # Define helper function to get the image file paths from all channels + # in a single row of loaddata csv (single view/site), organized into a dict + def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]: + + missing = False + + multi_channel_paths = {} + for channel_key in self._channel_keys: + file_column = f"{file_column_prefix}{channel_key}" + path_column = f"{path_column_prefix}{channel_key}" + + if file_column in row and path_column in row: + file = pathlib.Path( + row[path_column] + ) / row[file_column] + if (not check_exists) or file.exists(): + multi_channel_paths[channel_key] = file + else: + missing = True + + return multi_channel_paths, missing + + # Define helper function to get the coordinates associated with a condition + def 
get_associated_coords(group): + + try: + return group.loc[:, [self._x_col, self._y_col]].values + except Exception: + return None + + image_paths = [] + cell_coords = [] + self.logger.debug( + "Extracting image channel paths of site/view and associated " \ + "cell coordinates (if applicable) from loaddata csv") + + n_cells = 0 + ## Group by identifier of site/view + grouped = self._loaddata_df.groupby(self._merge_fields) + for _, group in grouped: + ## Retrieve image file paths for each channel from the first row + ## (any row within a group should have the same filenames) + _, row = next(group.iterrows()) + multi_channel_paths, missing = get_channel_paths(row) + if not missing: + image_paths.append(multi_channel_paths) + # Get cell coords associated with each site/view + coords = get_associated_coords(group) + n_cells += len(coords) + cell_coords.append(coords) + + self.logger.debug("Extracted images of all input and target channels for " \ + f"{len(image_paths)} unique sites/views and {n_cells} cells") + + ## Save cell coords (a list of 2D numpy arrays) as an attribute + self.__cell_coords = cell_coords + + ## Image paths will be saved as an attribute by the parent class init + return image_paths
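To make the grouping step above concrete, the sketch below assembles one {channel: path} dict per site/view from a toy pe2loaddata-style table. The dataframe contents, channel list and file name are made up for illustration; only the FileName_<channel>/PathName_<channel> convention and the Metadata_* merge fields come from this PR.

```python
# Illustrative sketch (not part of the dataset class): per-site channel path
# extraction from a pe2loaddata-style dataframe. All values below are made up.
import pathlib
import pandas as pd

loaddata = pd.DataFrame({
    "Metadata_Plate": ["BR00143976"] * 2,
    "Metadata_Well": ["F12"] * 2,
    "Metadata_Site": [2] * 2,
    "FileName_OrigDNA": ["r06c12f02p01-ch5sk1fk1fl1.tiff"] * 2,
    "PathName_OrigDNA": ["/data/Images"] * 2,
})
merge_fields = ["Metadata_Plate", "Metadata_Well", "Metadata_Site"]
channel_keys = ["OrigDNA"]

image_paths = []
for _, group in loaddata.groupby(merge_fields):
    # any row of the group carries the same filenames, so take the first one
    _, row = next(group.iterrows())
    image_paths.append({
        ch: pathlib.Path(row[f"PathName_{ch}"]) / row[f"FileName_{ch}"]
        for ch in channel_keys
    })

print(image_paths)  # one {channel: Path} dict per unique site/view
```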
+ """ + Patch generation helper functions + """ + + def _generate_patches(self, + _patch_size: int, + patch_generation_method: str, + patch_generation_random_seed: int, + n_expected_patches_per_img=5, + max_attempts=1_000, + consistent_img_size=True, + ) -> List[List[Tuple[int, int]]]: + """ + Generate patches for each image in the dataset + + :param patch_generation_method: The method to use for generating patches + :type patch_generation_method: str + :param patch_generation_random_seed: The random seed to use for patch generation + :type patch_generation_random_seed: int + :param consistent_img_size: Whether the images are consistent in size. + If True, the patch generation will be based on the size of the first input channel of the first image + If False, the patch generation will be based on the size of each image + :type consistent_img_size: bool + :param n_expected_patches_per_img: The number of patches to generate per image + :type n_expected_patches_per_img: int + :param max_attempts: The maximum number of attempts to generate a patch + :type max_attempts: int + :return: The coordinates of the patches + :rtype: List[List[Tuple[int, int]]] + """ + if patch_generation_method == 'random_cell': + if self.__cell_coords is None: + raise ValueError("Cell coordinates not available for generating cell containing patches") + else: + self.logger.debug("Generating patches that contain cells") + def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts): + return self.__generate_cell_containing_patches_unit( + image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts) + pass + elif patch_generation_method == 'random': + self.logger.debug("Generating random patches") + def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts): + # cell_coords is not used in this case + return self.__generate_random_patches_unit(image_size, patch_size, n_expected_patches_per_img, max_attempts) + pass + else: + raise ValueError("Patch generation method not supported") + + # Generate patches for each image + image_size = None + patch_count = 0 + patch_coords = [] + + # set random seed + if patch_generation_random_seed is not None: + random.seed(patch_generation_random_seed) + for channel_paths, cell_coords in zip(self._ImageDataset__image_paths, self.__cell_coords): + if consistent_img_size: + if image_size is not None: + pass + else: + try: + image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0] + self.logger.debug( + f"Image size inferred: {image_size} for all images; " + "to re-detect the image size for each view/site set consistent_img_size=False" + ) + except Exception as e: + raise ValueError("Error reading image size") from e + pass + else: + try: + image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0] + except Exception as e: + raise ValueError("Error reading image size") from e + + coords = patch_fn( + image_size=image_size, + patch_size=_patch_size, + cell_coords=cell_coords, + n_expected_patches_per_img=n_expected_patches_per_img, + max_attempts=max_attempts + ) + patch_coords.append(coords) + patch_count += len(coords) + + self.logger.debug(f"Generated {patch_count} patches for {len(self._ImageDataset__image_paths)} sites/views") + return patch_coords + + @staticmethod + def __generate_cell_containing_patches_unit( + image_size, + patch_size, + cell_coords, + expected_n_patches=5, + max_attempts=1_000): + """ + Static helper function to generate patches that contain the cell coordinates + + :param image_size: The size of the image (square) + :type image_size: int + :param patch_size: The size of the square patches to generate + :type patch_size: int + :param cell_coords: The coordinates of the cells + :type cell_coords: List[Tuple[int, int]] + :param expected_n_patches: The number of patches to generate + :type expected_n_patches: int + :return: The coordinates of the patches + """ + + unit_size = math.gcd(image_size, patch_size) + tile_size_units = patch_size // unit_size + grid_size_units = image_size // unit_size + + cell_containing_units = {(x // unit_size, y // unit_size) for x, y in cell_coords} + placed_tiles = set() + retained_tiles = [] + + 
attempts = 0 + n_tiles = 0 + while attempts < max_attempts: + top_left_x = randint(0, grid_size_units - tile_size_units) + top_left_y = randint(0, grid_size_units - tile_size_units) + + tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units) + for y in range(top_left_y, top_left_y + tile_size_units)} + + if any(tile_units & placed_tile for placed_tile in placed_tiles): + attempts += 1 + continue + + if tile_units & cell_containing_units: + retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size)) + placed_tiles.add(frozenset(tile_units)) + n_tiles += 1 + + attempts += 1 + if n_tiles >= expected_n_patches: + break + + return retained_tiles + + @staticmethod + def __generate_random_patches_unit( + image_size, + patch_size, + expected_n_patches=5, + max_attempts=1_000): + """ + Static helper function to generate random patches + + :param image_size: The size of the image (square) + :type image_size: int + :param patch_size: The size of the square patches to generate + :type patch_size: int + :param expected_n_patches: The number of patches to generate + :type expected_n_patches: int + :return: The coordinates of the patches + """ + unit_size = math.gcd(image_size, patch_size) + tile_size_units = patch_size // unit_size + grid_size_units = image_size // unit_size + + placed_tiles = set() + retained_tiles = [] + + attempts = 0 + n_tiles = 0 + while attempts < max_attempts: + top_left_x = randint(0, grid_size_units - tile_size_units) + top_left_y = randint(0, grid_size_units - tile_size_units) + + # check for overlap with already placed tiles + tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units) + for y in range(top_left_y, top_left_y + tile_size_units)} + + if any(tile_units & placed_tile for placed_tile in placed_tiles): + attempts += 1 + continue + + # no overlap, add the tile to the list of retained tiles + retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size)) + placed_tiles.add(frozenset(tile_units)) + n_tiles += 1 + + attempts += 1 + if n_tiles >= expected_n_patches: + break + + return retained_tiles \ No newline at end of file diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000..6dd6d9e --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,2 @@ +Here lives the dataset classes for interacting with cell painting images. +The datasets are currently completely dependent on the pe2loaddata generated csv files. \ No newline at end of file diff --git a/evaluation/README.md b/evaluation/README.md new file mode 100644 index 0000000..a698a2a --- /dev/null +++ b/evaluation/README.md @@ -0,0 +1,3 @@ +Here lives some collection of functions useful for evaluating model performance and plotting + +These code are pretty messy and will need major revisions. \ No newline at end of file diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py new file mode 100644 index 0000000..e9d5c48 --- /dev/null +++ b/evaluation/evaluation_utils.py @@ -0,0 +1,42 @@ +from typing import List, Optional + +import pandas as pd +import torch +from torch.nn import Module +from torch.utils.data import DataLoader + +def evaluate_per_image_metric( + predictions: torch.Tensor, + targets: torch.Tensor, + metrics: List[Module], + indices: Optional[List[int]] = None +) -> pd.DataFrame: + """ + Computes a set of metrics on a per-image basis and returns the results as a pandas DataFrame. + + :param predictions: Predicted images, shape (N, C, H, W). 
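Both placement helpers above follow the same scheme: snap the image onto a grid whose cell edge is gcd(image_size, patch_size), sample tile corners in grid units, reject tiles that overlap already accepted ones, and convert accepted corners back to pixel coordinates (the cell-containing variant additionally requires the tile to intersect at least one cell-occupied grid unit). A simplified, standalone restatement of that loop, for illustration only:

```python
# Simplified restatement of the unit-grid placement loop used by the private
# __generate_random_patches_unit / __generate_cell_containing_patches_unit
# helpers; illustration only, not the class implementation itself.
import math
from random import randint, seed

def sample_non_overlapping_patches(image_size, patch_size, n_patches=5, max_attempts=1_000):
    """Return up to n_patches non-overlapping (x, y) top-left corners in pixels."""
    unit = math.gcd(image_size, patch_size)   # edge length of one grid unit
    tile_units = patch_size // unit           # patch edge length in grid units
    grid_units = image_size // unit           # image edge length in grid units

    placed, retained = set(), []
    for _ in range(max_attempts):
        if len(retained) >= n_patches:
            break
        tx = randint(0, grid_units - tile_units)
        ty = randint(0, grid_units - tile_units)
        tile = frozenset((x, y)
                         for x in range(tx, tx + tile_units)
                         for y in range(ty, ty + tile_units))
        if any(tile & other for other in placed):
            continue                          # overlaps an accepted tile, resample
        placed.add(tile)
        retained.append((tx * unit, ty * unit))  # back to pixel coordinates
    return retained

seed(42)
# 1080 px images and 256 px patches, matching the example notebook further below
print(sample_non_overlapping_patches(1080, 256))
```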
+ :type predictions: torch.Tensor + :param targets: Target images, shape (N, C, H, W). + :type targets: torch.Tensor + :param metrics: List of metric functions to evaluate. + :type metrics: List[torch.nn.Module] + :param indices: Optional list of indices to subset the dataset before inference. If None, all images are evaluated. + :type indices: Optional[List[int]], optional + + :return: A DataFrame where each row corresponds to an image and each column corresponds to a metric. + :rtype: pd.DataFrame + """ + if predictions.shape != targets.shape: + raise ValueError(f"Shape mismatch: predictions {predictions.shape} vs targets {targets.shape}") + + results = [] + + if indices is None: + indices = range(predictions.shape[0]) + + for i in indices: # Iterate over images/subset + pred, target = predictions[i].unsqueeze(0), targets[i].unsqueeze(0) # Keep batch dimension + metric_scores = {metric.__class__.__name__: metric.forward(target, pred).item() for metric in metrics} + results.append(metric_scores) + + return pd.DataFrame(results) \ No newline at end of file diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py new file mode 100644 index 0000000..bba623b --- /dev/null +++ b/evaluation/predict_utils.py @@ -0,0 +1,103 @@ +from typing import Optional, List, Tuple, Callable + +import torch +import numpy as np +from torch.utils.data import DataLoader, Dataset, Subset +from albumentations import ImageOnlyTransform, Compose + +def predict_image( + dataset: Dataset, + model: torch.nn.Module, + batch_size: int = 1, + device: str = "cpu", + num_workers: int = 0, + indices: Optional[List[int]] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Runs a model on a dataset, performing a forward pass on all (or a subset of) input images + in evaluation mode and returning a stacked tensor of predictions. + DOES NOT check if the dataset dimensions are compatible with the model. + + :param dataset: A dataset that returns (input_tensor, target_tensor) tuples, + where input_tensor has shape (C, H, W). + :type dataset: torch.utils.data.Dataset + :param model: A PyTorch model that is compatible with the dataset inputs. + :type model: torch.nn.Module + :param batch_size: The number of samples per batch (default is 1). + :type batch_size: int, optional + :param device: The device to run inference on, e.g., "cpu" or "cuda". + :type device: str, optional + :param num_workers: Number of workers for the DataLoader (default is 0). + :type num_workers: int, optional + :param indices: Optional list of dataset indices to subset the dataset before inference. + :type indices: Optional[List[int]], optional + + :return: Tuple of stacked target and prediction tensors. 
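A minimal usage sketch for `evaluate_per_image_metric`, using stock torch losses as per-image metrics on random tensors; shapes are arbitrary, and the import path assumes the repository root is on `sys.path` so that the package is importable as `virtual_stain_flow`, as in the example notebook further below.

```python
# Usage sketch only: random tensors stand in for real predictions/targets.
import torch
import torch.nn as nn

from virtual_stain_flow.evaluation.evaluation_utils import evaluate_per_image_metric

preds = torch.rand(4, 1, 256, 256)    # (N, C, H, W)
targets = torch.rand(4, 1, 256, 256)

df = evaluate_per_image_metric(preds, targets, metrics=[nn.MSELoss(), nn.L1Loss()])
print(df)  # one row per image, columns "MSELoss" and "L1Loss"
```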
+ :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + # Subset the dataset if indices are provided + if indices is not None: + dataset = Subset(dataset, indices) + + # Create DataLoader for efficient batch processing + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + model.to(device) + model.eval() + + predictions, targets = [], [] + + with torch.no_grad(): + for inputs, target in dataloader: # Unpacking (input_tensor, target_tensor) + inputs = inputs.to(device) # Move input data to the specified device + + # Forward pass + outputs = model(inputs) + + # output both target and prediction tensors for metric + targets.append(target.cpu()) + predictions.append(outputs.cpu()) # Move to CPU for stacking + + return torch.cat(targets, dim=0), torch.cat(predictions, dim=0) + +def process_tensor_image( + img_tensor: torch.Tensor, + dtype: Optional[np.dtype] = None, + dataset: Optional[Dataset] = None, + invert_function: Optional[Callable] = None +) -> np.ndarray: + """ + Processes model output/other image tensor by casting to numpy, applying an optional dtype casting, + and inverting target transformations if a dataset with `target_transform` is provided. + + :param img_tensor: Tensor stack of model-predicted images with shape (N, C, H, W). + :type img_tensor: torch.Tensor + :param dtype: Optional numpy dtype to cast the output array (default: None). + :type dtype: Optional[np.dtype], optional + :param dataset: Optional dataset object with `target_transform` to invert transformations. + :type dataset: Optional[torch.utils.data.Dataset], optional + :param invert_function: Optional function to invert transformations applied to the images. + If provided, overrides the invert function call from dataset transform. + :type invert_function: Optional[Callable], optional + + :return: Processed numpy array of images with shape (N, C, H, W). 
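A minimal usage sketch for `predict_image`: a `TensorDataset` stands in for a real dataset of (input, target) pairs and a single convolution stands in for a trained model. Names, shapes and the index subset are illustrative; the import path again assumes the package is importable as `virtual_stain_flow`.

```python
# Usage sketch only: dummy data and a dummy model.
import torch
from torch import nn
from torch.utils.data import TensorDataset

from virtual_stain_flow.evaluation.predict_utils import predict_image

dataset = TensorDataset(torch.rand(8, 1, 64, 64), torch.rand(8, 1, 64, 64))
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in for a trained network

targets_out, preds_out = predict_image(dataset, model, batch_size=4, device="cpu", indices=[0, 2, 4])
print(targets_out.shape, preds_out.shape)           # torch.Size([3, 1, 64, 64]) twice
```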
+ :rtype: np.ndarray + """ + # Convert img_tensor to CPU and NumPy + output_images = img_tensor.cpu().numpy() + + # Optionally cast to specified dtype + if dtype is not None: + output_images = output_images.astype(dtype) + + # Apply invert function when supplied or transformation if invert function is supplied + if invert_function is not None and isinstance(invert_function, Callable): + output_images = np.array([invert_function(img) for img in output_images]) + elif dataset is not None and hasattr(dataset, "target_transform"): + # Apply inverted target transformation if available + target_transform = dataset.target_transform + if isinstance(target_transform, (ImageOnlyTransform, Compose)): + # Apply the transformation on each image + output_images = np.array([target_transform.invert(img) for img in output_images]) + + return output_images \ No newline at end of file diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py new file mode 100644 index 0000000..1b27699 --- /dev/null +++ b/evaluation/visualization_utils.py @@ -0,0 +1,174 @@ +from typing import List, Union, Optional + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle + +from ..datasets.PatchDataset import PatchDataset +from ..evaluation.predict_utils import predict_image, process_tensor_image +from ..evaluation.evaluation_utils import evaluate_per_image_metric + +def _plot_predictions_grid( + inputs: Union[np.ndarray, torch.Tensor], + targets: Union[np.ndarray, torch.Tensor], + predictions: Union[np.ndarray, torch.Tensor], + raw_images: Optional[Union[np.ndarray, torch.Tensor]] = None, + patch_coords: Optional[List[tuple]] = None, + metrics_df: Optional[pd.DataFrame] = None, + save_path: Optional[str] = None, + **kwargs +): + """ + Generalized function to plot a grid of images with predictions and optional raw images. + The Batch dimensions of (raw_image), input, target, and prediction should match and so should the length of metrics_df. + + :param inputs: Input images (N, C, H, W) or (N, H, W). + :param targets: Target images (N, C, H, W) or (N, H, W). + :param predictions: Model predictions (N, C, H, W) or (N, H, W). + :param raw_images: Optional raw images for PatchDataset (N, H, W). + :param patch_coords: Optional list of (x, y) coordinates for patches. + Only used if raw_images is provided. Length match the first dimension of inputs/targets/predictions. + :param metrics_df: Optional DataFrame with per-image metrics. + :param save_path: If provided, saves figure. + :param kwargs: Additional keyword arguments to pass to plt.subplots. 
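A small sketch of the `invert_function` hook in `process_tensor_image`: a hypothetical `undo_16bit_scaling` reverses a fixed [0, 1] normalization back to a 16-bit intensity range. The scaling itself is an assumption chosen for illustration, not the package's actual transform.

```python
# Usage sketch only: the scaling factor is a made-up example.
import numpy as np
import torch

from virtual_stain_flow.evaluation.predict_utils import process_tensor_image

preds = torch.rand(2, 1, 64, 64)      # model output assumed to be in [0, 1]

def undo_16bit_scaling(img: np.ndarray) -> np.ndarray:
    # hypothetical inverse of a fixed 16-bit min-max normalization
    return img * 65535.0

restored = process_tensor_image(preds, dtype=np.float32, invert_function=undo_16bit_scaling)
print(restored.shape, restored.dtype)  # (2, 1, 64, 64) float32
```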
+ """ + + cmap = kwargs.get("cmap", "gray") + panel_width = kwargs.get("panel_width", 5) + show_plot = kwargs.get("show_plot", True) + fig_size = kwargs.get("fig_size", None) + + num_samples = len(inputs) + is_patch_dataset = raw_images is not None + num_cols = 4 if is_patch_dataset else 3 # (Raw | Input | Target | Prediction) vs (Input | Target | Prediction) + + fig_size = (panel_width * num_cols, panel_width * num_samples) if fig_size is None else fig_size + fig, axes = plt.subplots(num_samples, num_cols, figsize=fig_size) + column_titles = ["Raw Image", "Input", "Target", "Prediction"] if is_patch_dataset else ["Input", "Target", "Prediction"] + + for row_idx in range(num_samples): + img_set = [raw_images[row_idx]] if is_patch_dataset else [] + img_set.extend([inputs[row_idx], targets[row_idx], predictions[row_idx]]) + + for col_idx, img in enumerate(img_set): + ax = axes[row_idx, col_idx] + ax.imshow(img.squeeze(), cmap=cmap) + ax.set_title(column_titles[col_idx]) + ax.axis("off") + + # Draw rectangle on raw image if PatchDataset + if is_patch_dataset and col_idx == 0 and patch_coords is not None: + patch_x, patch_y = patch_coords[row_idx] # (x, y) coordinates + patch_size = targets.shape[-1] # Assume square patches from target size + rect = Rectangle((patch_x, patch_y), patch_size, patch_size, linewidth=2, edgecolor="r", facecolor="none") + ax.add_patch(rect) + + # Display metrics if provided + if metrics_df is not None: + metric_values = metrics_df.iloc[row_idx] + metric_text = "\n".join([f"{key}: {value:.3f}" for key, value in metric_values.items()]) + axes[row_idx, -1].set_title( + axes[row_idx, -1].get_title() + "\n" + metric_text, fontsize=10, pad=10) + + # Save and/or show the plot + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=300) + if show_plot: + plt.show() + else: + plt.close() + +def plot_predictions_grid_from_eval( + dataset: Dataset, + predictions: Union[torch.Tensor, np.ndarray], + indices: List[int], + metrics_df: Optional[pd.DataFrame] = None, + save_path: Optional[str] = None, + **kwargs +): + """ + Wrapper function to extract dataset samples and call `_plot_predictions_grid`. + This function operates on the outputs downstream of `evaluate_per_image_metric` + and `predict_image` to avoid unecessary forward pass. + + :param dataset: Dataset (either normal or PatchDataset). + :param predictions: Subsetted tensor/NumPy array of predictions. + :param indices: Indices corresponding to the subset. + :param metrics_df: DataFrame with per-image metrics for the subset. + :param save_path: If provided, saves figure. + :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`. 
+ """ + + is_patch_dataset = isinstance(dataset, PatchDataset) + + # Extract input, target, and (optional) raw images & patch coordinates + raw_images, inputs, targets, patch_coords = [], [], [], [] + for i in indices: + inputs.append(dataset[i][0]) + targets.append(dataset[i][1]) + if is_patch_dataset: + raw_images.append(dataset.raw_input) + patch_coords.append(dataset.patch_coords) # Get patch location + + inputs_numpy = process_tensor_image(torch.stack(inputs), invert_function=dataset.input_transform.invert) + targets_numpy = process_tensor_image(torch.stack(targets), invert_function=dataset.target_transform.invert) + + # Pass everything to the core grid function + _plot_predictions_grid( + inputs_numpy, targets_numpy, predictions[indices], + raw_images if is_patch_dataset else None, + patch_coords if is_patch_dataset else None, + metrics_df, save_path, **kwargs + ) + +def plot_predictions_grid_from_model( + model: torch.nn.Module, + dataset: Dataset, + indices: List[int], + metrics: List[torch.nn.Module], + device: str = "cuda", + save_path: Optional[str] = None, + **kwargs +): + """ + Wrapper plot function that internally performs inference and evaluation with the following steps: + 1. Perform inference on a subset of the dataset given the model. + 2. Compute per-image metrics on that subset. + 3. Plot the results with core `_plot_predictions_grid` function. + + :param model: PyTorch model for inference. + :param dataset: The dataset to use for evaluation and plotting. + :param indices: List of dataset indices to evaluate and visualize. + :param metrics: List of metric functions to evaluate. + :param device: Device to run inference on ("cpu" or "cuda"). + :param save_path: Optional path to save the plot. + :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`. 
+ """ + # Step 1: Run inference on the selected subset + targets, predictions = predict_image(dataset, model, indices=indices, device=device) + + # Step 2: Compute per-image metrics for the subset + metrics_df = evaluate_per_image_metric(predictions, targets, metrics) + + # Step 3: Extract subset of inputs & targets and plot + is_patch_dataset = isinstance(dataset, PatchDataset) + raw_images, inputs, targets, patch_coords = [], [], [], [] + for i in indices: + inputs.append(dataset[i][0]) + targets.append(dataset[i][1]) + if is_patch_dataset: + raw_images.append(dataset.raw_input) + patch_coords.append(dataset.patch_coords) # Get patch location + + _plot_predictions_grid( + torch.stack(inputs), + torch.stack(targets), + predictions, + raw_images=raw_images if is_patch_dataset else None, + patch_coords=patch_coords if is_patch_dataset else None, + metrics_df=metrics_df, + save_path=save_path, + **kwargs) \ No newline at end of file diff --git a/examples/generate_sample_patch_dataset.ipynb b/examples/generate_sample_patch_dataset.ipynb new file mode 100644 index 0000000..bef2c71 --- /dev/null +++ b/examples/generate_sample_patch_dataset.ipynb @@ -0,0 +1,2262 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/weishanli/Waylab\n" + ] + } + ], + "source": [ + "import sys\n", + "import pathlib\n", + "\n", + "import imageio\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n", + "print(str(pathlib.Path('.').absolute().parent.parent))\n", + "\n", + "from virtual_stain_flow.datasets.PatchDataset import PatchDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-20 00:15:46,794 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-02-20 00:15:46,794 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-20 00:15:46,795 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", + "2025-02-20 00:15:46,795 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-02-20 00:15:46,795 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Plate', 'Metadata_Well']\n", + "2025-02-20 00:15:46,795 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-20 00:15:46,816 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-02-20 00:15:46,817 - DEBUG - Channel keys: {'OrigBrightfield', 'OrigAGP', 'OrigDNA', 'OrigRNA', 'OrigMito', 'OrigER'} inferred from loaddata csv\n", + "2025-02-20 00:15:46,817 - DEBUG - Setting input channel(s) ...\n", + "2025-02-20 00:15:46,817 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 00:15:46,817 - DEBUG - Setting target channel(s) ...\n", + "2025-02-20 00:15:46,817 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 00:15:46,817 - DEBUG - Setting input transform ...\n", + "2025-02-20 00:15:46,817 - DEBUG - Setting target transform ...\n", + "2025-02-20 00:15:46,817 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", + "2025-02-20 00:15:46,834 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 
cells\n", + "2025-02-20 00:15:46,835 - DEBUG - Generating patches that contain cells\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-20 00:15:46,848 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-02-20 00:15:47,226 - DEBUG - Generated 461 patches for 93 site/view\n" + ] + } + ], + "source": [ + "## REPLACE WITH YOUR OWN PATHS\n", + "analysis_home_path = pathlib.Path('/home/weishanli/Waylab/ALSF_pilot/ALSF_img2img_prototyping')\n", + "sc_features_parquet_path = pathlib.Path(\n", + " '/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/preprocessed_profiles_SN0313537/single_cell_profiles'\n", + ")\n", + "\n", + "loaddata_csv_path = analysis_home_path \\\n", + " / '0.data_analysis_and_preprocessing' / 'loaddata_csvs'\n", + "\n", + "if loaddata_csv_path.exists():\n", + " try:\n", + " loaddata_csv = next(loaddata_csv_path.glob('*.csv'))\n", + " except:\n", + " raise FileNotFoundError(\"No loaddata csv found\")\n", + "else:\n", + " raise ValueError(\"Incorrect loaddata csv path\")\n", + "\n", + "loaddata_df = pd.read_csv(loaddata_csv)\n", + "# subsample to reduce runtime\n", + "loaddata_df = loaddata_df.sample(n=100, random_state=42)\n", + "\n", + "sc_features = pd.DataFrame()\n", + "for plate in loaddata_df['Metadata_Plate'].unique():\n", + " sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'\n", + " if not sc_features_parquet.exists():\n", + " print(f'{sc_features_parquet} does not exist, skipping...')\n", + " continue \n", + " else:\n", + " sc_features = pd.concat([\n", + " sc_features, \n", + " pd.read_parquet(\n", + " sc_features_parquet,\n", + " columns=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site', 'Metadata_Cells_Location_Center_X', 'Metadata_Cells_Location_Center_Y']\n", + " )\n", + " ])\n", + "\n", + "PATCH_SIZE = 256\n", + "\n", + "channel_names = [\n", + " \"OrigBrightfield\",\n", + " \"OrigDNA\",\n", + " \"OrigER\",\n", + " \"OrigMito\",\n", + " \"OrigRNA\",\n", + " \"OrigAGP\",\n", + "]\n", + "input_channel_name = \"OrigBrightfield\"\n", + "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]\n", + "\n", + "pds = PatchDataset(\n", + " _loaddata_csv=loaddata_df,\n", + " _sc_feature=sc_features,\n", + " _input_channel_keys=None,\n", + " _target_channel_keys=None,\n", + " _input_transform=None,\n", + " _target_transform=None,\n", + " patch_size=PATCH_SIZE,\n", + " verbose=True,\n", + " patch_generation_method=\"random_cell\",\n", + " patch_generation_random_seed=42\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "OrigBrightfield/r06c12f02p01-ch1sk1fk1fl1_8_680.tiff" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "FileName_OrigBrightfield", + "rawType": "object", + "type": "string" + }, + { + "name": "PathName_OrigBrightfield", + "rawType": "object", + "type": "string" + }, + { + "name": "FileName_OrigER", + "rawType": "object", + "type": "string" + }, + { + "name": "PathName_OrigER", + "rawType": "object", + "type": "string" + }, + { + "name": "FileName_OrigAGP", + "rawType": "object", + "type": "string" + }, + { + "name": "PathName_OrigAGP", + "rawType": "object", + "type": "string" + }, + { + "name": "FileName_OrigMito", + "rawType": 
"object", + "type": "string" + }, + { + "name": "PathName_OrigMito", + "rawType": "object", + "type": "string" + }, + { + "name": "FileName_OrigDNA", + "rawType": "object", + "type": "string" + }, + { + "name": "PathName_OrigDNA", + "rawType": "object", + "type": "string" + }, + { + "name": "FileName_OrigRNA", + "rawType": "object", + "type": "string" + }, + { + "name": "PathName_OrigRNA", + "rawType": "object", + "type": "string" + }, + { + "name": "Metadata_Plate", + "rawType": "object", + "type": "string" + }, + { + "name": "Metadata_Well", + "rawType": "object", + "type": "string" + }, + { + "name": "Metadata_Site", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_AbsPositionZ", + "rawType": "float64", + "type": "float" + }, + { + "name": "Metadata_ChannelID", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_Col", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_FieldID", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_PlaneID", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_PositionX", + "rawType": "float64", + "type": "float" + }, + { + "name": "Metadata_PositionY", + "rawType": "float64", + "type": "float" + }, + { + "name": "Metadata_PositionZ", + "rawType": "float64", + "type": "float" + }, + { + "name": "Metadata_Row", + "rawType": "int64", + "type": "integer" + }, + { + "name": "Metadata_Reimaged", + "rawType": "bool", + "type": "boolean" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "9503fb45-c7d0-43d4-8b8c-183f24a0d822", + "rows": [ + [ + "2079", + "r06c22f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c22f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c22f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c22f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c22f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c22f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F22", + "1", + "0.134357795", + "6", + "22", + "1", + "1", + "0.0", + "0.0", + "-6e-06", + "6", + "False" + ], + [ + "668", + "r05c09f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f03p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f03p01-ch6sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E09", + "3", + "0.1344053", + "6", + "9", + "3", + "1", + "0.0", + "0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "2073", + "r05c22f04p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c22f04p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c22f04p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c22f04p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c22f04p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c22f04p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E22", + "4", + "0.134365693", + "6", + "22", + "4", + "1", + "0.000645814", + "0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "1113", + "r06c13f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c13f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c13f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c13f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c13f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c13f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F13", + "7", + "0.134346902", + "6", + "13", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "6", + "False" + ], + [ + "788", + "r06c10f06p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c10f06p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c10f06p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c10f06p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c10f06p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c10f06p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F10", + "6", + "0.134381399", + "6", + "10", + "6", + "1", + "-0.000645814", + 
"0.0", + "-6e-06", + "6", + "False" + ], + [ + "1780", + "r08c19f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c19f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c19f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c19f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c19f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c19f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "H19", + "8", + "0.134326905", + "6", + "19", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "8", + "False" + ], + [ + "1672", + "r08c18f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "H18", + "8", + "0.134322807", + "6", + "18", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "8", + "False" + ], + [ + "1717", + "r13c18f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c18f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c18f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c18f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c18f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c18f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "M18", + "8", + "0.134328201", + "6", + "18", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "13", + "False" + ], + [ + "926", + "r09c11f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 
4/Images", + "r09c11f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images", + "r09c11f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images", + "r09c11f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images", + "r09c11f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images", + "r09c11f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images", + "BR00143976", + "I11", + "9", + "0.134355694", + "6", + "11", + "9", + "1", + "0.000645814", + "-0.000645814", + "-3e-06", + "9", + "True" + ], + [ + "2157", + "r14c22f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c22f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c22f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c22f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c22f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c22f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "N22", + "7", + "0.134358302", + "6", + "22", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "14", + "False" + ], + [ + "674", + "r05c09f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c09f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E09", + "9", + "0.134399906", + "6", + "9", + "9", + "1", + "0.000645814", + "-0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "2011", + "r10c21f05p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c21f05p01-ch2sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c21f05p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c21f05p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c21f05p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c21f05p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "J21", + "5", + "0.134337202", + "6", + "21", + "5", + "1", + "0.000645814", + "0.0", + "-6e-06", + "10", + "False" + ], + [ + "96", + "r13c03f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c03f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c03f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c03f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c03f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c03f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "M03", + "7", + "0.134531394", + "6", + "3", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "13", + "False" + ], + [ + "1164", + "r12c13f04p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r12c13f04p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r12c13f04p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r12c13f04p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r12c13f04p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r12c13f04p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "L13", + "4", + "0.134333", + "6", + "13", + "4", + "1", + "0.000645814", + "0.000645814", + "-6e-06", + "12", + "False" + ], + [ + "567", + "r06c08f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c08f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c08f01p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c08f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c08f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c08f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F08", + "1", + "0.134413004", + "6", + "8", + "1", + "1", + "0.0", + "0.0", + "-6e-06", + "6", + "False" + ], + [ + "135", + "r06c04f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c04f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c04f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c04f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c04f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c04f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F04", + "1", + "0.134507298", + "6", + "4", + "1", + "1", + "0.0", + "0.0", + "-6e-06", + "6", + "False" + ], + [ + "29", + "r06c03f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c03f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c03f03p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c03f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c03f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c03f03p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F03", + "3", + "0.134534299", + "6", + "3", + "3", + "1", + "0.0", + "0.000645814", + "-6e-06", + "6", + "False" + ], + [ + "1641", + "r05c18f04p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c18f04p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c18f04p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c18f04p01-ch4sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c18f04p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c18f04p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E18", + "4", + "0.134341702", + "6", + "18", + "4", + "1", + "0.000645814", + "0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "1845", + "r04c20f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c20f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c20f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c20f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c20f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c20f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "BR00143976", + "D20", + "1", + "0.134361506", + "6", + "20", + "1", + "1", + "0.0", + "0.0", + "-4e-06", + "4", + "True" + ], + [ + "1940", + "r14c20f06p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f06p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f06p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f06p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f06p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f06p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "N20", + "6", + "0.134341404", + "6", + "20", + "6", + "1", + "-0.000645814", + "0.0", + "-6e-06", + "14", + "False" + ], + [ + "621", + "r12c08f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c08f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c08f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c08f01p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c08f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c08f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "BR00143976", + "L08", + "1", + "0.134399801", + "6", + "8", + "1", + "1", + "0.0", + "0.0", + "-4e-06", + "12", + "True" + ], + [ + "1204", + "r04c14f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c14f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c14f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c14f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c14f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c14f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "BR00143976", + "D14", + "8", + "0.134359106", + "6", + "14", + "8", + "1", + "0.0", + "-0.000645814", + "-4e-06", + "4", + "True" + ], + [ + "1455", + "r08c16f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c16f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c16f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c16f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c16f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c16f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "H16", + "7", + "0.134320304", + "6", + "16", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "8", + "False" + ], + [ + "1057", + "r12c12f05p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c12f05p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c12f05p01-ch4sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c12f05p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c12f05p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c12f05p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "BR00143976", + "L12", + "5", + "0.134343997", + "6", + "12", + "5", + "1", + "0.000645814", + "0.0", + "-4e-06", + "12", + "True" + ], + [ + "350", + "r05c06f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c06f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c06f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c06f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c06f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c06f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E06", + "9", + "0.134458706", + "6", + "6", + "9", + "1", + "0.000645814", + "-0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "1586", + "r11c17f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c17f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c17f03p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c17f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c17f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c17f03p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "K17", + "3", + "0.134314403", + "6", + "17", + "3", + "1", + "0.0", + "0.000645814", + "-6e-06", + "11", + "False" + ], + [ + "1000", + "r06c12f02p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c12f02p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c12f02p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", 
+ "r06c12f02p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c12f02p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c12f02p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F12", + "2", + "0.134358004", + "6", + "12", + "2", + "1", + "-0.000645814", + "0.000645814", + "-6e-06", + "6", + "False" + ], + [ + "251", + "r06c05f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c05f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c05f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c05f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c05f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c05f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F05", + "9", + "0.134474307", + "6", + "5", + "9", + "1", + "0.000645814", + "-0.000645814", + "-6e-06", + "6", + "False" + ], + [ + "1420", + "r04c16f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c16f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c16f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c16f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c16f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r04c16f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "BR00143976", + "D16", + "8", + "0.134352103", + "6", + "16", + "8", + "1", + "0.0", + "-0.000645814", + "-4e-06", + "4", + "True" + ], + [ + "965", + "r14c11f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r14c11f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r14c11f03p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 
6/Images", + "r14c11f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r14c11f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r14c11f03p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "BR00143976", + "N11", + "3", + "0.134372801", + "6", + "11", + "3", + "1", + "0.0", + "0.000645814", + "-2e-06", + "14", + "True" + ], + [ + "1189", + "r03c14f02p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r03c14f02p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r03c14f02p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r03c14f02p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r03c14f02p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "r03c14f02p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images", + "BR00143976", + "C14", + "2", + "0.134374201", + "6", + "14", + "2", + "1", + "-0.000645814", + "0.000645814", + "-4e-06", + "3", + "True" + ], + [ + "2002", + "r09c21f05p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r09c21f05p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r09c21f05p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r09c21f05p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r09c21f05p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r09c21f05p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "I21", + "5", + "0.134339705", + "6", + "21", + "5", + "1", + "0.000645814", + "0.0", + "-6e-06", + "9", + "False" + ], + [ + "111", + "r03c04f04p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c04f04p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c04f04p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c04f04p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c04f04p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c04f04p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "C04", + "4", + "0.134528205", + "6", + "4", + "4", + "1", + "0.000645814", + "0.000645814", + "-6e-06", + "3", + "False" + ], + [ + "184", + "r11c04f05p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c04f05p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c04f05p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c04f05p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c04f05p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c04f05p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "BR00143976", + "K04", + "5", + "0.134485096", + "6", + "4", + "5", + "1", + "0.000645814", + "0.0", + "-4e-06", + "11", + "True" + ], + [ + "1726", + "r14c18f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c18f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c18f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c18f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c18f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c18f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "N18", + "8", + "0.134337395", + "6", + "18", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "14", + "False" + ], + [ + "1747", + "r05c19f02p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c19f02p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c19f02p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c19f02p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c19f02p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r05c19f02p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "E19", + "2", + "0.134342298", + "6", + "19", + "2", + "1", + "-0.000645814", + "0.000645814", + "-6e-06", + "5", + "False" + ], + [ + "1667", + "r08c18f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f03p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r08c18f03p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "H18", + "3", + "0.134323403", + "6", + "18", + "3", + "1", + "0.0", + "0.000645814", + "-6e-06", + "8", + "False" + ], + [ + "2022", + "r11c21f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c21f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c21f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c21f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c21f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r11c21f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "K21", + "7", + "0.134334505", + "6", + "21", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "11", + "False" + ], + [ + "422", + "r13c06f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c06f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c06f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + 
"r13c06f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c06f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c06f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "BR00143976", + "M06", + "9", + "0.1344482", + "6", + "6", + "9", + "1", + "0.000645814", + "-0.000645814", + "-2e-06", + "13", + "True" + ], + [ + "508", + "r11c07f05p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c07f05p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c07f05p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c07f05p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c07f05p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r11c07f05p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "BR00143976", + "K07", + "5", + "0.134413093", + "6", + "7", + "5", + "1", + "0.000645814", + "0.0", + "-4e-06", + "11", + "True" + ], + [ + "1331", + "r06c15f09p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c15f09p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c15f09p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c15f09p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c15f09p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c15f09p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F15", + "9", + "0.134334505", + "6", + "15", + "9", + "1", + "0.000645814", + "-0.000645814", + "-6e-06", + "6", + "False" + ], + [ + "1469", + "r10c16f03p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c16f03p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c16f03p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c16f03p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c16f03p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r10c16f03p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "J16", + "3", + "0.134315595", + "6", + "16", + "3", + "1", + "0.0", + "0.000645814", + "-6e-06", + "10", + "False" + ], + [ + "1935", + "r14c20f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c20f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "N20", + "1", + "0.134343296", + "6", + "20", + "1", + "1", + "0.0", + "0.0", + "-6e-06", + "14", + "False" + ], + [ + "630", + "r13c08f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c08f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c08f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c08f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c08f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "r13c08f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images", + "BR00143976", + "M08", + "1", + "0.134411901", + "6", + "8", + "1", + "1", + "0.0", + "0.0", + "-2e-06", + "13", + "True" + ], + [ + "1393", + "r13c15f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c15f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c15f08p01-ch3sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c15f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c15f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c15f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "M15", + "8", + "0.134330496", + "6", + "15", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "13", + "False" + ], + [ + "1933", + "r13c20f08p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c20f08p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c20f08p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c20f08p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c20f08p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r13c20f08p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "M20", + "8", + "0.134337306", + "6", + "20", + "8", + "1", + "0.0", + "-0.000645814", + "-6e-06", + "13", + "False" + ], + [ + "1293", + "r14c14f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c14f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c14f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c14f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c14f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r14c14f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "N14", + "7", + "0.134347007", + "6", + "14", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "14", + "False" + ], + [ + "973", + "r03c12f02p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c12f02p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c12f02p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c12f02p01-ch4sk1fk1fl1.tiff", + 
"/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c12f02p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r03c12f02p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "C12", + "2", + "0.134390101", + "6", + "12", + "2", + "1", + "-0.000645814", + "0.000645814", + "-6e-06", + "3", + "False" + ], + [ + "297", + "r12c05f01p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c05f01p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c05f01p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c05f01p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c05f01p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "r12c05f01p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images", + "BR00143976", + "L05", + "1", + "0.134464994", + "6", + "5", + "1", + "1", + "0.0", + "0.0", + "-4e-06", + "12", + "True" + ], + [ + "1761", + "r06c19f07p01-ch1sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c19f07p01-ch2sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c19f07p01-ch3sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c19f07p01-ch4sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c19f07p01-ch5sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "r06c19f07p01-ch6sk1fk1fl1.tiff", + "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images", + "BR00143976", + "F19", + "7", + "0.134334296", + "6", + "19", + "7", + "1", + "-0.000645814", + "-0.000645814", + "-6e-06", + "6", + "False" + ] + ], + "shape": { + "columns": 25, + "rows": 100 + } + }, + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FileName_OrigBrightfieldPathName_OrigBrightfieldFileName_OrigERPathName_OrigERFileName_OrigAGPPathName_OrigAGPFileName_OrigMitoPathName_OrigMitoFileName_OrigDNAPathName_OrigDNA...Metadata_AbsPositionZMetadata_ChannelIDMetadata_ColMetadata_FieldIDMetadata_PlaneIDMetadata_PositionXMetadata_PositionYMetadata_PositionZMetadata_RowMetadata_Reimaged
2079r06c22f01p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134358622110.0000000.000000-0.0000066False
668r05c09f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13440569310.0000000.000646-0.0000065False
2073r05c22f04p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134366622410.0006460.000646-0.0000065False
1113r06c13f07p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13434761371-0.000646-0.000646-0.0000066False
788r06c10f06p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13438161061-0.0006460.000000-0.0000066False
..................................................................
1730r03c19f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134366619310.0000000.000646-0.0000043True
196r12c04f08p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13449164810.000000-0.000646-0.00000412True
367r07c06f08p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13444766810.000000-0.000646-0.0000067False
650r03c09f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13442869310.0000000.000646-0.0000063False
2064r04c22f04p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134379622410.0006460.000646-0.0000044True
\n", + "

100 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " FileName_OrigBrightfield \\\n", + "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch1sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch1sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch1sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch1sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch1sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch1sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigBrightfield \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigER \\\n", + "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch2sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch2sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch2sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch2sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch2sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigER \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigAGP \\\n", + "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch4sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch4sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch3sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch3sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch4sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigAGP \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... 
\n", + "\n", + " FileName_OrigMito \\\n", + "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch3sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch3sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch4sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch4sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch3sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigMito \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigDNA \\\n", + "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch6sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch6sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch5sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch5sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch6sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigDNA ... \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "... ... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "\n", + " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n", + "2079 0.134358 6 22 1 \n", + "668 0.134405 6 9 3 \n", + "2073 0.134366 6 22 4 \n", + "1113 0.134347 6 13 7 \n", + "788 0.134381 6 10 6 \n", + "... ... ... ... ... \n", + "1730 0.134366 6 19 3 \n", + "196 0.134491 6 4 8 \n", + "367 0.134447 6 6 8 \n", + "650 0.134428 6 9 3 \n", + "2064 0.134379 6 22 4 \n", + "\n", + " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n", + "2079 1 0.000000 0.000000 \n", + "668 1 0.000000 0.000646 \n", + "2073 1 0.000646 0.000646 \n", + "1113 1 -0.000646 -0.000646 \n", + "788 1 -0.000646 0.000000 \n", + "... ... ... ... \n", + "1730 1 0.000000 0.000646 \n", + "196 1 0.000000 -0.000646 \n", + "367 1 0.000000 -0.000646 \n", + "650 1 0.000000 0.000646 \n", + "2064 1 0.000646 0.000646 \n", + "\n", + " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n", + "2079 -0.000006 6 False \n", + "668 -0.000006 5 False \n", + "2073 -0.000006 5 False \n", + "1113 -0.000006 6 False \n", + "788 -0.000006 6 False \n", + "... ... ... ... 
\n", + "1730 -0.000004 3 True \n", + "196 -0.000004 12 True \n", + "367 -0.000006 7 False \n", + "650 -0.000006 3 False \n", + "2064 -0.000004 4 True \n", + "\n", + "[100 rows x 25 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaddata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLE_PATCH_DATA_EXPORT_PATH = pathlib.Path('.').absolute().parent.parent / 'example_patch_data'\n", + "EXAMPLE_PATCH_DATA_EXPORT_PATH.mkdir(exist_ok=True)\n", + "INPUT_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / input_channel_name\n", + "INPUT_EXPORT_PATH.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-20 00:15:47,252 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,252 - DEBUG - Set target channel(s) as ['OrigDNA']\n", + "2025-02-20 00:15:47,475 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,475 - DEBUG - Set target channel(s) as ['OrigER']\n", + "2025-02-20 00:15:47,676 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,677 - DEBUG - Set target channel(s) as ['OrigMito']\n", + "2025-02-20 00:15:47,850 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,850 - DEBUG - Set target channel(s) as ['OrigRNA']\n", + "2025-02-20 00:15:48,034 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:48,035 - DEBUG - Set target channel(s) as ['OrigAGP']\n" + ] + } + ], + "source": [ + "for j, channel_name in enumerate(target_channel_names):\n", + "\n", + " pds.set_input_channel_keys([input_channel_name])\n", + " pds.set_target_channel_keys([channel_name])\n", + "\n", + " CHANNEL_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / channel_name\n", + " CHANNEL_EXPORT_PATH.mkdir(exist_ok=True)\n", + "\n", + " for i in range(len(pds)):\n", + " input, target = pds[i]\n", + " input_name = pds.input_names\n", + " target_name = pds.target_names\n", + " patch_coord = pds.patch_coords\n", + "\n", + " if j == 0:\n", + " imageio.imwrite(\n", + " INPUT_EXPORT_PATH / f'{input_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n", + " input[0].numpy().astype(np.uint16))\n", + "\n", + " imageio.imwrite(\n", + " CHANNEL_EXPORT_PATH / f'{target_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n", + " target[0].numpy().astype(np.uint16))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speckle_analysis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb new file mode 100644 index 0000000..d3842c5 --- /dev/null +++ b/examples/minimal_example.ipynb @@ -0,0 +1,1019 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A minimal example to demonstrate how the trainer for FNet and wGAN GP plus the callbacks works along with patched dataset\n", + "\n", + "Is dependent on the files produced by 1.illumination_correction/0.create_loaddata_csvs ALSF pilot data repo 
https://github.com/WayScience/pediatric_cancer_atlas_profiling" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/weishanli/Waylab\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/weishanli/anaconda3/envs/cp_gan_env/lib/python3.9/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.5' (you have '2.0.4'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", + " check_for_updates()\n" + ] + } + ], + "source": [ + "import sys\n", + "import pathlib\n", + "\n", + "import pandas as pd\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n", + "print(str(pathlib.Path('.').absolute().parent.parent))\n", + "\n", + "## Dataset\n", + "from virtual_stain_flow.datasets.PatchDataset import PatchDataset\n", + "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n", + "\n", + "## FNet training\n", + "from virtual_stain_flow.models.fnet import FNet\n", + "from virtual_stain_flow.trainers.Trainer import Trainer\n", + "\n", + "## wGaN training\n", + "from virtual_stain_flow.models.unet import UNet\n", + "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n", + "from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer\n", + "\n", + "## wGaN losses\n", + "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n", + "from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss\n", + "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n", + "\n", + "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n", + "\n", + "## Metrics\n", + "from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper\n", + "from virtual_stain_flow.metrics.PSNR import PSNR\n", + "from virtual_stain_flow.metrics.SSIM import SSIM\n", + "\n", + "## callback\n", + "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n", + "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify train output paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train'\n", + "EXAMPLE_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf example_train/*" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "PLOT_DIR = EXAMPLE_DIR / 'plot'\n", + "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "MLFLOW_DIR =EXAMPLE_DIR / 'mlflow'\n", + "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify paths to loaddata and read single cell features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " FileName_OrigBrightfield \\\n", + "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n", + "1113 
r06c13f07p01-ch1sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigBrightfield \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigER \\\n", + "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigER \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigAGP \\\n", + "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigAGP \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigMito \\\n", + "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigMito \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigDNA \\\n", + "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigDNA ... \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... 
\n", + "\n", + " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n", + "2079 0.134358 6 22 1 \n", + "668 0.134405 6 9 3 \n", + "2073 0.134366 6 22 4 \n", + "1113 0.134347 6 13 7 \n", + "788 0.134381 6 10 6 \n", + "\n", + " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n", + "2079 1 0.000000 0.000000 \n", + "668 1 0.000000 0.000646 \n", + "2073 1 0.000646 0.000646 \n", + "1113 1 -0.000646 -0.000646 \n", + "788 1 -0.000646 0.000000 \n", + "\n", + " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n", + "2079 -0.000006 6 False \n", + "668 -0.000006 5 False \n", + "2073 -0.000006 5 False \n", + "1113 -0.000006 6 False \n", + "788 -0.000006 6 False \n", + "\n", + "[5 rows x 25 columns]\n", + " Metadata_Plate Metadata_Well Metadata_Site \\\n", + "0 BR00143976 C03 2 \n", + "1 BR00143976 C03 6 \n", + "2 BR00143976 C03 9 \n", + "3 BR00143976 C03 5 \n", + "4 BR00143976 C03 7 \n", + "\n", + " Metadata_Cells_Location_Center_X Metadata_Cells_Location_Center_Y \n", + "0 629.552987 62.017799 \n", + "1 279.951864 56.588228 \n", + "2 876.508878 205.794360 \n", + "3 479.254866 45.496581 \n", + "4 866.557068 205.908787 \n" + ] + } + ], + "source": [ + "## REPLACE WITH YOUR OWN PATHS\n", + "loaddata_csv_path = pathlib.Path(\n", + " '/REPLACE/WITH/YOUR/PATH'\n", + " )\n", + "sc_features_parquet_path = pathlib.Path(\n", + " '/REPLACE/WITH/YOUR/PATH'\n", + " )\n", + "\n", + "if loaddata_csv_path.exists():\n", + " try:\n", + " loaddata_csv = next(loaddata_csv_path.glob('*.csv'))\n", + " except:\n", + " raise FileNotFoundError(\"No loaddata csv found\")\n", + "else:\n", + " raise ValueError(\"Incorrect loaddata csv path\")\n", + "\n", + "loaddata_df = pd.read_csv(loaddata_csv)\n", + "# subsample to reduce runtime\n", + "loaddata_df = loaddata_df.sample(n=100, random_state=42)\n", + "\n", + "sc_features = pd.DataFrame()\n", + "for plate in loaddata_df['Metadata_Plate'].unique():\n", + " sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'\n", + " if not sc_features_parquet.exists():\n", + " print(f'{sc_features_parquet} does not exist, skipping...')\n", + " continue \n", + " else:\n", + " sc_features = pd.concat([\n", + " sc_features, \n", + " pd.read_parquet(\n", + " sc_features_parquet,\n", + " columns=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site', 'Metadata_Cells_Location_Center_X', 'Metadata_Cells_Location_Center_Y']\n", + " )\n", + " ])\n", + "\n", + "print(loaddata_df.head())\n", + "print(sc_features.head())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Patch size and channels" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "PATCH_SIZE = 256\n", + "\n", + "channel_names = [\n", + " \"OrigBrightfield\",\n", + " \"OrigDNA\",\n", + " \"OrigER\",\n", + " \"OrigMito\",\n", + " \"OrigRNA\",\n", + " \"OrigAGP\",\n", + "]\n", + "input_channel_name = \"OrigBrightfield\"\n", + "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep Patch dataset and Cache" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-03-03 20:31:52,602 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-03-03 20:31:52,603 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-03-03 20:31:52,604 - DEBUG - X and Y 
columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", + "2025-03-03 20:31:52,605 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-03-03 20:31:52,605 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Well', 'Metadata_Site']\n", + "2025-03-03 20:31:52,605 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-03-03 20:31:52,673 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-03-03 20:31:52,674 - DEBUG - Channel keys: {'OrigBrightfield', 'OrigMito', 'OrigAGP', 'OrigER', 'OrigRNA', 'OrigDNA'} inferred from loaddata csv\n", + "2025-03-03 20:31:52,674 - DEBUG - Setting input channel(s) ...\n", + "2025-03-03 20:31:52,674 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 20:31:52,675 - DEBUG - Setting target channel(s) ...\n", + "2025-03-03 20:31:52,675 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 20:31:52,675 - DEBUG - Setting input transform ...\n", + "2025-03-03 20:31:52,676 - DEBUG - Setting target transform ...\n", + "2025-03-03 20:31:52,676 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-03-03 20:31:53,018 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", + "2025-03-03 20:31:53,018 - DEBUG - Generating patches that contain cells\n", + "2025-03-03 20:31:53,052 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-03-03 20:31:53,681 - DEBUG - Generated 461 patches for 93 site/view\n", + "2025-03-03 20:31:53,682 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-03-03 20:31:53,682 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + ] + } + ], + "source": [ + "pds = PatchDataset(\n", + " _loaddata_csv=loaddata_df,\n", + " _sc_feature=sc_features,\n", + " _input_channel_keys=None,\n", + " _target_channel_keys=None,\n", + " _input_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n", + " _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n", + " patch_size=PATCH_SIZE,\n", + " verbose=True,\n", + " patch_generation_method=\"random_cell\",\n", + " patch_generation_random_seed=42\n", + ")\n", + "\n", + "## Set input and target channels\n", + "pds.set_input_channel_keys([input_channel_name])\n", + "pds.set_target_channel_keys('OrigDNA')\n", + "\n", + "## Cache for faster training \n", + "cds = CachedDataset(\n", + " pds,\n", + " prefill_cache=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FNet trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model without callback and check logs" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " early_termination_metric = 'L1Loss',\n", + " 
metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "epoch", + "rawType": "int64", + "type": "integer" + }, + { + "name": "L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "ssim", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_ssim", + "rawType": "float64", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "0e376d8b-d69e-4ca5-a37b-ad7cc30c4534", + "rows": [ + [ + "0", + "1", + "0.30758210165160044", + "0.367111998796463", + "10.07480525970459", + "0.026603518053889275", + "8.649313926696777", + "0.03540125489234924" + ], + [ + "1", + "2", + "0.1586258773292814", + "0.15944983959197997", + "14.711766242980957", + "0.04422195255756378", + "15.713263511657715", + "0.06773694604635239" + ], + [ + "2", + "3", + "0.09979256739219029", + "0.09274622052907944", + "17.329845428466797", + "0.07210950553417206", + "20.035673141479492", + "0.11819823086261749" + ], + [ + "3", + "4", + "0.06892516552692368", + "0.050482964515686034", + "19.144039154052734", + "0.1051420047879219", + "24.395761489868164", + "0.2566310465335846" + ], + [ + "4", + "5", + "0.04841376912026178", + "0.04219272881746292", + "24.528972625732422", + "0.30562901496887207", + "25.807828903198242", + "0.3405899703502655" + ], + [ + "5", + "6", + "0.03341594719815822", + "0.027642089501023294", + "27.669620513916016", + "0.45510202646255493", + "28.5550594329834", + "0.5035992860794067" + ], + [ + "6", + "7", + "0.026435295385973796", + "0.025301840528845786", + "28.809947967529297", + "0.5208737254142761", + "29.219507217407227", + "0.563963770866394" + ], + [ + "7", + "8", + "0.02121852764061519", + "0.017991484329104423", + "29.607341766357422", + "0.5826537609100342", + "30.165979385375977", + "0.6275455355644226" + ], + [ + "8", + "9", + "0.018385388267536957", + "0.017969632521271706", + "29.834793090820312", + "0.6006285548210144", + "29.763931274414062", + "0.582464873790741" + ], + [ + "9", + "10", + "0.015868896308044594", + "0.018145012110471724", + "30.2650203704834", + "0.6225405931472778", + "29.08074951171875", + "0.48871317505836487" + ] + ], + "shape": { + "columns": 7, + "rows": 10 + } + }, + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochL1Lossval_L1Losspsnrssimval_psnrval_ssim
010.3075820.36711210.0748050.0266048.6493140.035401
120.1586260.15945014.7117660.04422215.7132640.067737
230.0997930.09274617.3298450.07211020.0356730.118198
340.0689250.05048319.1440390.10514224.3957610.256631
450.0484140.04219324.5289730.30562925.8078290.340590
560.0334160.02764227.6696210.45510228.5550590.503599
670.0264350.02530228.8099480.52087429.2195070.563964
780.0212190.01799129.6073420.58265430.1659790.627546
890.0183850.01797029.8347930.60062929.7639310.582465
9100.0158690.01814530.2650200.62254129.0807500.488713
\n", + "
" + ], + "text/plain": [ + " epoch L1Loss val_L1Loss psnr ssim val_psnr val_ssim\n", + "0 1 0.307582 0.367112 10.074805 0.026604 8.649314 0.035401\n", + "1 2 0.158626 0.159450 14.711766 0.044222 15.713264 0.067737\n", + "2 3 0.099793 0.092746 17.329845 0.072110 20.035673 0.118198\n", + "3 4 0.068925 0.050483 19.144039 0.105142 24.395761 0.256631\n", + "4 5 0.048414 0.042193 24.528973 0.305629 25.807829 0.340590\n", + "5 6 0.033416 0.027642 27.669621 0.455102 28.555059 0.503599\n", + "6 7 0.026435 0.025302 28.809948 0.520874 29.219507 0.563964\n", + "7 8 0.021219 0.017991 29.607342 0.582654 30.165979 0.627546\n", + "8 9 0.018385 0.017970 29.834793 0.600629 29.763931 0.582465\n", + "9 10 0.015869 0.018145 30.265020 0.622541 29.080750 0.488713" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(trainer.log)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model with alternative early termination metric" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Early termination at epoch 6 with best validation metric 9.887849807739258\n" + ] + } + ], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda',\n", + " early_termination_metric = 'psnr' # set early termination metric as psnr for the sake of demonstration\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train with mlflow logger callbacks" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'lr': 3e-4\n", + " },\n", + " )\n", + "\n", + "del trainer\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda',\n", + " early_termination_metric = 'L1Loss'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# wGaN GP example with mlflow logger callback and plot callback" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True,\n", + " batch_norm=True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + 
" betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = WassersteinLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "plot_callback = IntermediatePlot(\n", + " name='plotter',\n", + " path=PLOT_DIR,\n", + " dataset=pds, # give it the patch dataset as opposed to the cached dataset\n", + " indices=[1,3,5,7,9], # plot 5 selected patches images from the dataset\n", + " plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],\n", + " figsize=(20, 25),\n", + " show_plot=False,\n", + ")\n", + "\n", + "wgan_trainer = WGANTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=20, # setting this to prevent unwanted early termination here\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback, plot_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr')\n", + " },\n", + " early_termination_metric = 'GeneratorLoss'\n", + ")\n", + "\n", + "wgan_trainer.train()\n", + "\n", + "del generator\n", + "del wgan_trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## # wGaN GP example with mlflow logger callback and alternative early termination loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True,\n", + " batch_norm=True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = WassersteinLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 
'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "wgan_trainer = WGANTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=5, # lower patience here\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr'),\n", + " 'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()) # use a wrapper for torch nn L1Loss\n", + " },\n", + " early_termination_metric = 'mae' # update early temrination loss with the supplied L1Loss/mae metric instead of the default GaN generator loss\n", + ")\n", + "\n", + "wgan_trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cp_gan_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/minimal_example_generic_dataset.ipynb b/examples/minimal_example_generic_dataset.ipynb new file mode 100644 index 0000000..6eb2b1c --- /dev/null +++ b/examples/minimal_example_generic_dataset.ipynb @@ -0,0 +1,840 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A minimal example to demonstrate how the trainer for FNet and wGAN GP plus the callbacks works along with patched dataset\n", + "\n", + "Is will not be dependent on the pe2loaddata generated index file from the ALSF pilot data repo unlike the other example notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/weishanli/Waylab\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/weishanli/anaconda3/envs/cp_gan_env/lib/python3.9/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.5' (you have '2.0.4'). Upgrade using: pip install -U albumentations. 
To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", + " check_for_updates()\n" + ] + } + ], + "source": [ + "import sys\n", + "import pathlib\n", + "\n", + "import pandas as pd\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n", + "print(str(pathlib.Path('.').absolute().parent.parent))\n", + "\n", + "## Dataset\n", + "from virtual_stain_flow.datasets.GenericImageDataset import GenericImageDataset\n", + "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n", + "\n", + "## FNet training\n", + "from virtual_stain_flow.models.fnet import FNet\n", + "from virtual_stain_flow.trainers.Trainer import Trainer\n", + "\n", + "## wGaN training\n", + "from virtual_stain_flow.models.unet import UNet\n", + "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n", + "from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer\n", + "\n", + "## wGaN losses\n", + "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n", + "from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss\n", + "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n", + "\n", + "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n", + "\n", + "## Metrics\n", + "from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper\n", + "from virtual_stain_flow.metrics.PSNR import PSNR\n", + "from virtual_stain_flow.metrics.SSIM import SSIM\n", + "\n", + "## callback\n", + "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n", + "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify train data and output paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "EXAMPLE_PATCH_DATA_EXPORT_PATH = '/REPLACE/WITH/PATH/TO/DATA'\n", + "\n", + "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train_generic_dataset'\n", + "EXAMPLE_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf example_train_generic_dataset/*" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "PLOT_DIR = EXAMPLE_DIR / 'plot'\n", + "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "MLFLOW_DIR =EXAMPLE_DIR / 'mlflow'\n", + "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure channels" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "channel_names = [\n", + " \"OrigBrightfield\",\n", + " \"OrigDNA\",\n", + " \"OrigER\",\n", + " \"OrigMito\",\n", + " \"OrigRNA\",\n", + " \"OrigAGP\",\n", + "]\n", + "input_channel_name = \"OrigBrightfield\"\n", + "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep Patch dataset and Cache" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-03-03 20:30:03,223 - DEBUG - Channel keys: {'OrigMito', 'OrigBrightfield', 'OrigDNA', 'OrigRNA', 
'OrigER', 'OrigAGP'} detected\n", + "2025-03-03 20:30:03,225 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 20:30:03,226 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 20:30:03,226 - DEBUG - Setting input transform ...\n", + "2025-03-03 20:30:03,226 - DEBUG - Setting target transform ...\n", + "2025-03-03 20:30:03,226 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-03-03 20:30:03,227 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + ] + } + ], + "source": [ + "ds = GenericImageDataset(\n", + " image_dir=EXAMPLE_PATCH_DATA_EXPORT_PATH,\n", + " site_pattern=r\"^([^_]+_[^_]+_[^_]+)\",\n", + " channel_pattern=r\"_([^_]+)\\.tiff$\",\n", + " verbose=True\n", + ")\n", + "\n", + "## Set input and target channels\n", + "ds.set_input_channel_keys([input_channel_name])\n", + "ds.set_target_channel_keys('OrigDNA')\n", + "\n", + "## Cache for faster training \n", + "cds = CachedDataset(\n", + " ds,\n", + " prefill_cache=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FNet trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model without callback and check logs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda',\n", + " early_termination_metric = None\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "epoch", + "rawType": "int64", + "type": "integer" + }, + { + "name": "L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "ssim", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_ssim", + "rawType": "float64", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "57db911c-095d-4f8b-8f57-ab73118f2b71", + "rows": [ + [ + "0", + "1", + "1556.9512803819443", + "1610.1253051757812", + "-70.28494262695312", + "9.063276795728825e-10", + "-70.56584167480469", + "-8.826770980796539e-10" + ], + [ + "1", + "2", + "1702.884250217014", + "1610.0806274414062", + "-70.88795471191406", + "-3.853541929998983e-09", + "-70.56578063964844", + "-1.0126930405363055e-09" + ], + [ + "2", + "3", + "1515.5343560112847", + "1609.9841918945312", + "-70.11585998535156", + "-4.895769567525576e-09", + "-70.56565856933594", + "-1.7035615140770233e-09" + ], + [ + "3", + "4", + "1559.3525390625", + "1609.861083984375", + "-70.39810943603516", + "-3.283770810824649e-09", + "-70.56550598144531", + "-1.8262883427766496e-09" + ], + [ + "4", + "5", + "1545.5415174696182", + "1609.8245849609375", + "-70.3438491821289", + "-3.253234126532334e-09", + "-70.56546020507812", 
+ "-1.8821437741678437e-09" + ], + [ + "5", + "6", + "1542.607638888889", + "1609.786376953125", + "-70.49755859375", + "-8.243815630137874e-10", + "-70.56541442871094", + "-1.5619402438105112e-09" + ], + [ + "6", + "7", + "1501.347873263889", + "1609.7578735351562", + "-69.79637908935547", + "3.219659261421981e-10", + "-70.56538391113281", + "-1.5038892353658184e-09" + ], + [ + "7", + "8", + "1526.6503228081597", + "1609.74462890625", + "-70.21217346191406", + "-1.515618464065227e-10", + "-70.56536102294922", + "-9.222221319937773e-10" + ], + [ + "8", + "9", + "1527.799533420139", + "1609.7286987304688", + "-70.20182800292969", + "-8.36740537968339e-11", + "-70.56534576416016", + "-1.539568472708197e-09" + ], + [ + "9", + "10", + "1627.6031358506943", + "1609.71826171875", + "-70.63406372070312", + "-1.5872374595216066e-11", + "-70.5653305053711", + "-7.939728874362117e-10" + ] + ], + "shape": { + "columns": 7, + "rows": 10 + } + }, + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochL1Lossval_L1Losspsnrssimval_psnrval_ssim
011556.9512801610.125305-70.2849439.063277e-10-70.565842-8.826771e-10
121702.8842501610.080627-70.887955-3.853542e-09-70.565781-1.012693e-09
231515.5343561609.984192-70.115860-4.895770e-09-70.565659-1.703562e-09
341559.3525391609.861084-70.398109-3.283771e-09-70.565506-1.826288e-09
451545.5415171609.824585-70.343849-3.253234e-09-70.565460-1.882144e-09
561542.6076391609.786377-70.497559-8.243816e-10-70.565414-1.561940e-09
671501.3478731609.757874-69.7963793.219659e-10-70.565384-1.503889e-09
781526.6503231609.744629-70.212173-1.515618e-10-70.565361-9.222221e-10
891527.7995331609.728699-70.201828-8.367405e-11-70.565346-1.539568e-09
9101627.6031361609.718262-70.634064-1.587237e-11-70.565331-7.939729e-10
\n", + "
" + ], + "text/plain": [ + " epoch L1Loss val_L1Loss psnr ssim val_psnr \\\n", + "0 1 1556.951280 1610.125305 -70.284943 9.063277e-10 -70.565842 \n", + "1 2 1702.884250 1610.080627 -70.887955 -3.853542e-09 -70.565781 \n", + "2 3 1515.534356 1609.984192 -70.115860 -4.895770e-09 -70.565659 \n", + "3 4 1559.352539 1609.861084 -70.398109 -3.283771e-09 -70.565506 \n", + "4 5 1545.541517 1609.824585 -70.343849 -3.253234e-09 -70.565460 \n", + "5 6 1542.607639 1609.786377 -70.497559 -8.243816e-10 -70.565414 \n", + "6 7 1501.347873 1609.757874 -69.796379 3.219659e-10 -70.565384 \n", + "7 8 1526.650323 1609.744629 -70.212173 -1.515618e-10 -70.565361 \n", + "8 9 1527.799533 1609.728699 -70.201828 -8.367405e-11 -70.565346 \n", + "9 10 1627.603136 1609.718262 -70.634064 -1.587237e-11 -70.565331 \n", + "\n", + " val_ssim \n", + "0 -8.826771e-10 \n", + "1 -1.012693e-09 \n", + "2 -1.703562e-09 \n", + "3 -1.826288e-09 \n", + "4 -1.882144e-09 \n", + "5 -1.561940e-09 \n", + "6 -1.503889e-09 \n", + "7 -9.222221e-10 \n", + "8 -1.539568e-09 \n", + "9 -7.939729e-10 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(trainer.log)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model with alternative early termination metric" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Early termination at epoch 6 with best validation metric -69.82839965820312\n" + ] + } + ], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda',\n", + " early_termination_metric = 'psnr' # set early termination metric as psnr for the sake of demonstration\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train with mlflow logger callbacks" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'lr': 3e-4\n", + " },\n", + " )\n", + "\n", + "del trainer\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# wGaN GP example with mlflow logger callback and plot callback" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = 
GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = WassersteinLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "plot_callback = IntermediatePlot(\n", + " name='plotter',\n", + " path=PLOT_DIR,\n", + " dataset=ds, # give it the patch dataset as opposed to the cached dataset\n", + " indices=[1,3,5,7,9], # plot 5 selected patches images from the dataset\n", + " plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],\n", + " figsize=(20, 25),\n", + " show_plot=False,\n", + ")\n", + "\n", + "wgan_trainer = WGANTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=20, # setting this to prevent unwanted early termination here\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback, plot_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr')}\n", + ")\n", + "\n", + "wgan_trainer.train()\n", + "\n", + "del generator\n", + "del wgan_trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## # wGaN GP example with mlflow logger callback and alternative early termination loss" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = WassersteinLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " 
mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "wgan_trainer = WGANTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=5, # lower patience here\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr'),\n", + " 'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()) # use a wrapper for torch nn L1Loss\n", + " },\n", + " early_termination_metric = 'mae' # update early temrination loss with the supplied L1Loss/mae metric instead of the default GaN generator loss\n", + ")\n", + "\n", + "wgan_trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cp_gan_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/losses/AbstractLoss.py b/losses/AbstractLoss.py new file mode 100644 index 0000000..c96c70b --- /dev/null +++ b/losses/AbstractLoss.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class AbstractLoss(nn.Module, ABC): + """Abstract class for metrics""" + + def __init__(self, _metric_name: str): + + super(AbstractLoss, self).__init__() + + self._metric_name = _metric_name + self._trainer = None + + @property + def trainer(self): + return self._trainer + + @trainer.setter + def trainer(self, value): + """ + Setter of trainer meant to be called by the trainer class during initialization + """ + self._trainer = value + + @property + def metric_name(self, _metric_name: str): + """Defines the mertic name returned by the class.""" + return self._metric_name + + @abstractmethod + def forward(self, truth: torch.Tensor, generated: torch.Tensor + ) -> float: + """ + Computes the metric given information about the data + + :param truth: The tensor containing the ground truth image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type truth: torch.Tensor + :param generated: The tensor containing model generated image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type generated: torch.Tensor + :return: The computed metric as a float value. + :rtype: float + """ + pass \ No newline at end of file diff --git a/losses/DiscriminatorLoss.py b/losses/DiscriminatorLoss.py new file mode 100644 index 0000000..84ee7ca --- /dev/null +++ b/losses/DiscriminatorLoss.py @@ -0,0 +1,38 @@ +import torch + +from .AbstractLoss import AbstractLoss + +class WassersteinLoss(AbstractLoss): + """ + This class implements the loss function for the discriminator in a Wasserstein Generative Adversarial Network (wGAN). 
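For intuition about the sign convention used here, a minimal, hypothetical usage sketch with made-up score tensors (the import path follows the example notebooks):

import torch
from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss

disc_loss = WassersteinLoss(_metric_name='disc_loss')
real_scores = torch.tensor([[0.9], [0.8]])   # discriminator scores for (input, real target) pairs
fake_scores = torch.tensor([[0.2], [0.1]])   # discriminator scores for (input, generated) pairs
loss = disc_loss(real_scores, fake_scores)   # (fake - real).mean() = -0.7
# Minimizing this value pushes the discriminator to score real pairs above generated ones.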
+ The discriminator loss measures how well the discriminator is able to distinguish between real (ground expected_truth) + images and fake (expected_generated) images produced by the generator. + """ + def __init__(self, _metric_name): + super().__init__(_metric_name) + + def forward(self, expected_truth, expected_generated): + """ + Computes the Wasserstein Discriminator Loss given probability scores expected_truth and expected_generated from the discriminator + + :param expected_truth: The tensor containing the ground expected_truth + probability score predicted by the discriminator over a batch of real images (input target pair), + should be of shape [batch_size, 1]. + :type expected_truth: torch.Tensor + :param expected_generated: The tensor containing model expected_generated + probability score predicted by the discriminator over a batch of generated images (input generated pair), + should be of shape [batch_size, 1]. + :type expected_generated: torch.Tensor + :return: The computed metric as a float value. + :rtype: float + """ + + # If the probability output is more than Scalar, take the mean of the output + # For compatibility with both a Discriminator class that would output a scalar probability (currently implemented) + # and a Discriminator class that would output a 2d matrix of probabilities (currently not implemented) + if expected_truth.dim() >= 3: + expected_truth = torch.mean(expected_truth, tuple(range(2, expected_truth.dim()))) + if expected_generated.dim() >= 3: + expected_generated = torch.mean(expected_generated, tuple(range(2, expected_generated.dim()))) + + return (expected_generated - expected_truth).mean() \ No newline at end of file diff --git a/losses/GeneratorLoss.py b/losses/GeneratorLoss.py new file mode 100644 index 0000000..9ec2bd5 --- /dev/null +++ b/losses/GeneratorLoss.py @@ -0,0 +1,65 @@ +from typing import Optional + +import torch +from torch.nn import L1Loss + +from .AbstractLoss import AbstractLoss + +class GeneratorLoss(AbstractLoss): + """ + Computes the loss for the GaN generator. + Combines an adversarial loss component with an image reconstruction loss. + """ + def __init__(self, + _metric_name: str, + reconstruction_loss: Optional[torch.tensor] = L1Loss(), + reconstruction_weight: float = 1.0 + ): + """ + :param reconstruction_loss: The image reconstruction loss, + defaults to L1Loss(reduce=False) + :type reconstruction_loss: torch.tensor + :param reconstruction_weight: The weight for the image reconstruction loss, defaults to 1.0 + :type reconstruction_weight: float + """ + + super().__init__(_metric_name) + + self._reconstruction_loss = reconstruction_loss + if isinstance(reconstruction_weight, float): + self._reconstruction_weight = reconstruction_weight + else: + raise ValueError("reconstruction_weight must be a float value") + + def forward(self, + discriminator_probs: torch.tensor, + truth: torch.tensor, + generated: torch.tensor, + epoch: int = 0 + ): + """ + Computes the loss for the GaN generator. + + :param discriminator_probs: The probabilities of the discriminator for the fake images being real. + :type discriminator_probs: torch.tensor + :param truth: The tensor containing the ground truth image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type truth: torch.Tensor + :param generated: The tensor containing model generated image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type generated: torch.Tensor + :param epoch: The current epoch number. 
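A hypothetical call showing how the adversarial and reconstruction terms combine (dummy tensors; import path as in the example notebooks):

import torch
from torch.nn import L1Loss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

gen_loss = GeneratorLoss(_metric_name='gen_loss')      # L1 reconstruction by default, weight 1.0
disc_probs = torch.rand(4, 1)                          # discriminator scores for generated images
truth = torch.rand(4, 1, 64, 64)                       # ground-truth target channel
generated = torch.rand(4, 1, 64, 64)                   # generator output
loss = gen_loss(disc_probs, truth, generated, epoch=0)
# equals 0.01 * (-disc_probs.mean()) / (0 + 1) + 1.0 * L1Loss()(generated, truth)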
+ Used for a smoothing weight for the adversarial loss component + Defaults to 0. + :type epoch: int + :return: The computed metric as a float value. + :rtype: float + """ + + # Adversarial loss + adversarial_loss = -torch.mean(discriminator_probs) + adversarial_loss = 0.01 * adversarial_loss/(epoch + 1) + + image_loss = self._reconstruction_loss(generated, truth) + + return adversarial_loss + self._reconstruction_weight * image_loss.mean() \ No newline at end of file diff --git a/losses/GradientPenaltyLoss.py b/losses/GradientPenaltyLoss.py new file mode 100644 index 0000000..fa96a00 --- /dev/null +++ b/losses/GradientPenaltyLoss.py @@ -0,0 +1,44 @@ +import torch +import torch.autograd as autograd + +from .AbstractLoss import AbstractLoss + +class GradientPenaltyLoss(AbstractLoss): + def __init__(self, _metric_name, discriminator, weight=10.0): + super().__init__(_metric_name) + + self._discriminator = discriminator + self._weight = weight + + def forward(self, truth, generated): + """ + Computes Gradient Penalty Loss for wGaN GP + + :param truth: The tensor containing the ground truth image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type truth: torch.Tensor + :param generated: The tensor containing model generated image, + should be of shape [batch_size, channel_number, img_height, img_width]. + :type generated: torch.Tensor + :return: The computed metric as a float value. + :rtype: float + """ + + device = self.trainer.device + + batch_size = truth.size(0) + eta = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(truth) + interpolated = (eta * truth + (1 - eta) * generated).requires_grad_(True) + prob_interpolated = self._discriminator(interpolated) + + gradients = autograd.grad( + outputs=prob_interpolated, + inputs=interpolated, + grad_outputs=torch.ones_like(prob_interpolated), + create_graph=True, + retain_graph=True, + )[0] + + gradients = gradients.view(batch_size, -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return self._weight * gradient_penalty \ No newline at end of file diff --git a/losses/README.md b/losses/README.md new file mode 100644 index 0000000..2fe0a00 --- /dev/null +++ b/losses/README.md @@ -0,0 +1 @@ +Here lives the loss functions used by wGaN GP \ No newline at end of file diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py new file mode 100644 index 0000000..24fc571 --- /dev/null +++ b/metrics/AbstractMetrics.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch +import torch.nn as nn + +class AbstractMetrics(nn.Module, ABC): + """Abstract class for metrics""" + + def __init__(self, _metric_name: str): + + super(AbstractMetrics, self).__init__() + + self.__metric_name = _metric_name + + self.__train_metric_values = [] + self.__val_metric_values = [] + + @property + def metric_name(self): + """Defines the mertic name returned by the class.""" + return self.__metric_name + + @property + def train_metric_values(self): + """Returns the training metric values.""" + return self.__train_metric_values + + @property + def val_metric_values(self): + """Returns the validation metric values.""" + return self.__val_metric_values + + @abstractmethod + def forward(self, + _generated_outputs: torch.tensor, + _targets: torch.tensor + ) -> torch.tensor: + """Computes the metric given information about the data.""" + pass + + def update(self, + _generated_outputs: torch.tensor, + _targets: torch.tensor, + validation: bool=False + ) -> None: + 
"""Updates the metric with the new data.""" + if validation: + self.__val_metric_values.append(self.forward(_generated_outputs, _targets)) + else: + self.__train_metric_values.append(self.forward(_generated_outputs, _targets)) + + def reset(self): + """Resets the metric.""" + self.__train_metric_values = [] + self.__val_metric_values = [] + + def compute(self, **kwargs): + """ + Calls the aggregate_metrics method to compute the metric value for now + In future may be used for more complex computations + """ + return self.aggregate_metrics(**kwargs) + + def aggregate_metrics(self, aggregation: Optional[str] = 'mean'): + """ + Aggregates the metric value over batches + + :param aggregation: The aggregation method to use, by default 'mean' + :type aggregation: Optional[str] + :return: The aggregated metric value for training and validation + :rtype: Tuple[torch.tensor, torch.tensor] + """ + + if aggregation == 'mean': + return \ + torch.mean(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \ + torch.mean(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None + + elif aggregation == 'sum': + return \ + torch.sum(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \ + torch.sum(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None + + elif aggregation is None: + return \ + torch.stack(self.__train_metric_values) if len(self.__train_metric_values) > 0 else None , \ + torch.stack(self.__val_metric_values) if len(self.__val_metric_values) > 0 else None + + else: + raise ValueError(f"Aggregation method {aggregation} is not supported.") \ No newline at end of file diff --git a/metrics/MetricsWrapper.py b/metrics/MetricsWrapper.py new file mode 100644 index 0000000..d261ca2 --- /dev/null +++ b/metrics/MetricsWrapper.py @@ -0,0 +1,24 @@ +import torch + +from .AbstractMetrics import AbstractMetrics + +class MetricsWrapper(AbstractMetrics): + """Metrics wrapper class that wraps a pytorch module + and calls it forward pass function to accumulate the metric + values across batches + """ + + def __init__(self, _metric_name: str, module: torch.nn.Module): + """ + Initialize the MetricsWrapper class with the metric name and the module. + + :param _metric_name: The name of the metric. + :param module: The module to be wrapped. Needs to have a forward function. + :type module: torch.nn.Module + """ + + super(MetricsWrapper, self).__init__(_metric_name) + self._module = module + + def forward(self,_generated_outputs: torch.Tensor, _targets: torch.Tensor): + return self._module(_generated_outputs, _targets).mean() \ No newline at end of file diff --git a/metrics/PSNR.py b/metrics/PSNR.py new file mode 100644 index 0000000..d07bc4c --- /dev/null +++ b/metrics/PSNR.py @@ -0,0 +1,43 @@ +import torch + +from .AbstractMetrics import AbstractMetrics + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class PSNR(AbstractMetrics): + """Computes and tracks the Peak Signal-to-Noise Ratio (PSNR).""" + + def __init__(self, _metric_name: str, _max_pixel_value: int = 1): + """ + Initializes the PSNR metric. + + :param _metric_name: The name of the metric. + :param _max_pixel_value: The maximum possible pixel value of the images, by default 1. 
+ :type _max_pixel_value: int, optional + """ + + super(PSNR, self).__init__(_metric_name) + + self.__max_pixel_value = _max_pixel_value + + def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor): + """ + Computes the Peak Signal-to-Noise Ratio (PSNR) between the generated outputs and the target images. + + :param _generated_outputs: The tensor containing the generated output images. + :type _generated_outputs: torch.Tensor + :param _targets: The tensor containing the target images. + :type _targets: torch.Tensor + :return: The computed PSNR value. + :rtype: torch.Tensor + """ + + mse = torch.mean((_generated_outputs - _targets) ** 2, dim=[2, 3]) + psnr = torch.where( + mse == 0, + torch.tensor(0.0), + 10 * torch.log10((self.__max_pixel_value**2) / mse), + ) + + return psnr.mean() \ No newline at end of file diff --git a/metrics/README.md b/metrics/README.md new file mode 100644 index 0000000..ac8b458 --- /dev/null +++ b/metrics/README.md @@ -0,0 +1,2 @@ +Here lives the metric classes which is dependent on a abstract metric class +Each metric needs to have a foward function implemented over target and predict while the abstract class functions inhertied handles accumulation \ No newline at end of file diff --git a/metrics/SSIM.py b/metrics/SSIM.py new file mode 100644 index 0000000..a06deee --- /dev/null +++ b/metrics/SSIM.py @@ -0,0 +1,52 @@ +import torch + +from .AbstractMetrics import AbstractMetrics + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class SSIM(AbstractMetrics): + """Computes and tracks the Structural Similarity Index Measure (SSIM).""" + + def __init__(self, _metric_name: str, _max_pixel_value: int = 1): + """ + Initializes the SSIM metric. + + :param _metric_name: The name of the metric. + :param _max_pixel_value: The maximum possible pixel value of the images, by default 1. + :type _max_pixel_value: int, optional + """ + + super(SSIM, self).__init__(_metric_name) + + self.__max_pixel_value = _max_pixel_value + + def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor): + """ + Computes the Structural Similarity Index Measure (SSIM) between the generated outputs and the target images. + + :param _generated_outputs: The tensor containing the generated output images. + :type _generated_outputs: torch.Tensor + :param _targets: The tensor containing the target images. + :type _targets: torch.Tensor + :return: The computed SSIM value. + :rtype: torch.Tensor + """ + + mu1 = _generated_outputs.mean(dim=[2, 3], keepdim=True) + mu2 = _targets.mean(dim=[2, 3], keepdim=True) + + sigma1_sq = ((_generated_outputs - mu1) ** 2).mean(dim=[2, 3], keepdim=True) + sigma2_sq = ((_targets - mu2) ** 2).mean(dim=[2, 3], keepdim=True) + sigma12 = ((_generated_outputs - mu1) * (_targets - mu2)).mean( + dim=[2, 3], keepdim=True + ) + + c1 = (self.__max_pixel_value * 0.01) ** 2 + c2 = (self.__max_pixel_value * 0.03) ** 2 + + ssim_value = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2) + ) + + return ssim_value.mean() \ No newline at end of file diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000..2c0cd8b --- /dev/null +++ b/models/README.md @@ -0,0 +1,3 @@ +Here lives the torch model and parts for FNet, UNet and wGaN GP + +Quite unclean in its current state. 
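As a quick orientation, a minimal sketch of how these models are constructed; the argument values mirror the example notebooks above:

from virtual_stain_flow.models.fnet import FNet
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator

fnet = FNet(depth=4)                          # trained directly with the plain Trainer
generator = UNet(n_channels=1, n_classes=1)   # single input channel, single predicted channel
critic = GlobalDiscriminator(
    n_in_channels=2,      # input image stacked with a real or generated target channel
    n_in_filters=64,
    _conv_depth=4,
    _pool_before_fc=True,
)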
\ No newline at end of file diff --git a/models/discriminator.py b/models/discriminator.py new file mode 100644 index 0000000..fa93b22 --- /dev/null +++ b/models/discriminator.py @@ -0,0 +1,154 @@ +import torch +from torch import nn +import torch.nn.functional as F + +""" +Implementation of GaN discriminators to use along with UNet or FNet generator. +""" + +class PatchBasedDiscriminator(nn.Module): + + def __init__( + self, + n_in_channels: int, + n_in_filters: int, + _conv_depth: int=4, + _leaky_relu_alpha: float=0.2, + _batch_norm: bool=False + ): + """ + A patch-based discriminator for pix2pix GANs that outputs a feature map + of probabilities + + :param n_in_channels: (int) number of input channels + :type n_in_channels: int + :param n_in_filters: (int) number of filters in the first convolutional layer. + Every subsequent layer will double the number of filters + :type n_in_filters: int + :param _conv_depth: (int) depth of the convolutional network + :type _conv_depth: int + :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. + Must be between 0 and 1 + :type _leaky_relu_alpha: float + :param _batch_norm: (bool) whether to use batch normalization, defaults to False + :type _batch_norm: bool + """ + + super().__init__() + + conv_layers = [] + + n_channels = n_in_filters + conv_layers.append( + nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + + # Sequentially add convolutional layers + for _ in range(_conv_depth - 2): + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + + # Another layer of conv without downscaling + ## TODO: figure out if this is needed + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=1, padding=1) + ) + + if _batch_norm: + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + self._conv_layers = nn.Sequential(*conv_layers) + + # Output layer to get the probability map + self.out = nn.Sequential( + *[nn.Conv2d(n_channels, 1, kernel_size=4, stride=1, padding=1), + nn.Sigmoid()] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._conv_layers(x) + x = self.out(x) + + return x + +class GlobalDiscriminator(nn.Module): + + def __init__( + self, + n_in_channels: int, + n_in_filters: int, + _conv_depth: int=4, + _leaky_relu_alpha: float=0.2, + _batch_norm: bool=False, + _pool_before_fc: bool=False + ): + """ + A global discriminator for pix2pix GANs that outputs a single scalar value as the global probability + + Parameters: + :param n_in_channels: (int) number of input channels + :type n_in_channels: int + :param n_in_filters: (int) number of filters in the first convolutional layer. + Every subsequent layer will double the number of filters + :type n_in_filters: int + :param _conv_depth: (int) depth of the convolutional network + :type _conv_depth: int + :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. 
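To contrast the outputs of the two discriminators, a hypothetical shape check (the patch-based arguments are assumed to follow the same convention as the global ones used in the notebooks):

import torch
from virtual_stain_flow.models.discriminator import PatchBasedDiscriminator, GlobalDiscriminator

pair = torch.rand(8, 2, 256, 256)   # input channel stacked with a target channel
patch_d = PatchBasedDiscriminator(n_in_channels=2, n_in_filters=64, _conv_depth=4)
global_d = GlobalDiscriminator(n_in_channels=2, n_in_filters=64, _conv_depth=4, _pool_before_fc=True)
print(patch_d(pair).shape)          # [8, 1, H', W'] - a spatial map of patch probabilities
print(global_d(pair).shape)         # torch.Size([8, 1]) - one probability per sample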
+ Must be between 0 and 1 + :type _leaky_relu_alpha: float + :param _batch_norm: (bool) whether to use batch normalization, defaults to False + :type _batch_norm: bool + :param _pool_before_fc: (bool) whether to pool before the fully connected network + Pooling before the fully connected network can reduce the number of parameters + :type _pool_before_fc: bool + """ + + super().__init__() + + conv_layers = [] + + n_channels = n_in_filters + conv_layers.append( + nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + + # Sequentially add convolutional layers + for _ in range(_conv_depth - 1): + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1) + ) + + if _batch_norm: + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + + # Flattening + if _pool_before_fc: + conv_layers.append(nn.AdaptiveAvgPool2d((1, 1))) + conv_layers.append(nn.Flatten()) + self._conv_layers = nn.Sequential(*conv_layers) + + + # Fully connected network to output probability + self.fc = nn.Sequential( + nn.LazyLinear(512), + nn.LeakyReLU(_leaky_relu_alpha, inplace=True), + nn.Linear(512, 1), + nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._conv_layers(x) + x = self.fc(x) + + return x \ No newline at end of file diff --git a/models/fnet.py b/models/fnet.py new file mode 100644 index 0000000..cb2e344 --- /dev/null +++ b/models/fnet.py @@ -0,0 +1,98 @@ +import torch + +class FNet(torch.nn.Module): + def __init__(self, depth=4, mult_chan=32, output_activation='sigmoid'): + """ + Initialize the FNet model with a customizable depth. + + :param depth: Depth of the F-Net model. + :type depth: int + :param mult_chan: Factor to determine number of output channels. + :type mult_chan: int + :param output_activation: Activation function for the output layer. + :type output_activation: str + """ + super().__init__() + self._depth = depth + self._multi_chan = mult_chan + self.net_recurse = _Net_recurse( + n_in_channels=1, + mult_chan=self._multi_chan, + depth=self._depth) + self.conv_out = torch.nn.Conv2d( + self._multi_chan, 1, kernel_size=3, padding=1) + + if output_activation == 'sigmoid': + self.output_activation = torch.nn.Sigmoid() + elif output_activation == 'relu': + self.output_activation = torch.nn.ReLU() + elif output_activation == 'linear': + self.output_activation = torch.nn.Identity() + else: + raise ValueError('Invalid output_activation') + + def forward(self, x): + x_rec = self.net_recurse(x) + x_act = self.conv_out(x_rec) + + return self.output_activation(x_act) + +class _Net_recurse(torch.nn.Module): + def __init__(self, n_in_channels, mult_chan=2, depth=0): + """Class for recursive definition of U-network.p + + Parameters: + in_channels - (int) number of channels for input. + mult_chan - (int) factor to determine number of output channels + depth - (int) if 0, this subnet will only be convolutions that double the channel count. 
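A minimal sanity check of the assembled FNet (illustrative only; 256x256 matches the patch size used in the example notebooks):

import torch
from virtual_stain_flow.models.fnet import FNet

model = FNet(depth=4, mult_chan=32, output_activation='sigmoid')
x = torch.rand(2, 1, 256, 256)   # a batch of normalized single-channel brightfield patches
y = model(x)
print(y.shape)                   # torch.Size([2, 1, 256, 256]), squashed into [0, 1] by the sigmoid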
+ """ + super().__init__() + self.depth = depth + n_out_channels = n_in_channels * mult_chan + self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels) + + if depth > 0: + self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels) + self.conv_down = torch.nn.Conv2d(n_out_channels, n_out_channels, 2, stride=2) + self.bn0 = torch.nn.BatchNorm2d(n_out_channels) + self.relu0 = torch.nn.ReLU() + + self.convt = torch.nn.ConvTranspose2d(2 * n_out_channels, n_out_channels, kernel_size=2, stride=2) + self.bn1 = torch.nn.BatchNorm2d(n_out_channels) + self.relu1 = torch.nn.ReLU() + self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1)) + + def forward(self, x): + if self.depth == 0: + return self.sub_2conv_more(x) + else: # depth > 0 + x_2conv_more = self.sub_2conv_more(x) + x_conv_down = self.conv_down(x_2conv_more) + x_bn0 = self.bn0(x_conv_down) + x_relu0 = self.relu0(x_bn0) + x_sub_u = self.sub_u(x_relu0) + x_convt = self.convt(x_sub_u) + x_bn1 = self.bn1(x_convt) + x_relu1 = self.relu1(x_bn1) + x_cat = torch.cat((x_2conv_more, x_relu1), 1) # concatenate + x_2conv_less = self.sub_2conv_less(x_cat) + return x_2conv_less + +class SubNet2Conv(torch.nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv1 = torch.nn.Conv2d(n_in, n_out, kernel_size=3, padding=1) + self.bn1 = torch.nn.BatchNorm2d(n_out) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1) + self.bn2 = torch.nn.BatchNorm2d(n_out) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + return x diff --git a/models/unet.py b/models/unet.py new file mode 100644 index 0000000..2e04f12 --- /dev/null +++ b/models/unet.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle + +from .unet_utils import * + +class UNet(nn.Module): + def __init__(self, + n_channels, + n_classes, + base_channels=64, + depth=4, + bilinear=False, + output_activation='sigmoid'): + """ + Initialize the U-Net model with a customizable depth. + + :param n_channels: Number of input channels. + :type n_channels: int + :param n_classes: Number of output classes. + :type n_classes: int + :param base_channels: Number of base channels to start with. + :type base_channels: int + :param depth: Depth of the U-Net model. + :type depth: int + :param bilinear: Use bilinear interpolation for ups + :type bilinear: bool + :param output_activation: Activation function for the output layer. 
+ :type output_activation: str + """ + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.depth = depth + self.bilinear = bilinear + + in_channels = n_channels # Input channel to the first upsampling layer is the number of input channels + out_channels = base_channels # Output channel of the first upsampling layer is the base number of channels + + # Initial upsampling layer + self.inc = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2) + + # Contracting path + contracting_path = [] + for _ in range(self.depth): + # set the number of input channels to the output channels of the previous layer + in_channels = out_channels + # double the number of output channels for the next layer + out_channels *= 2 + contracting_path.append( + Contract( + in_channels=in_channels, + out_channels=out_channels + ) + ) + self.down = nn.ModuleList(contracting_path) + + # Bottleneck + factor = 2 if bilinear else 1 + in_channels = out_channels # Input channel to the bottleneck layer is the output channel of the last downsampling layer + out_channels = in_channels // factor + self.bottleneck = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2 + ) + + # Expanding path + expanding_path = [] + for _ in range(self.depth): + # input to expanding path has the same dimension as the output of the bottleneck layer + in_channels = out_channels + # half the number of output channels for the next layer + out_channels = in_channels // 2 + expanding_path.append( + Up( + in_channels=in_channels, + out_channels=out_channels, + bilinear=bilinear + ) + ) + self.up = nn.ModuleList(expanding_path) + + # Output layer + self.outc = OutConv( + base_channels, + n_classes, + output_activation=output_activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the U-Net model. + + :param x: Input tensor of shape (batch_size, n_channels, height, width). + :type x: torch.Tensor + :return: Output tensor of shape (batch_size, n_classes, height, width). + :rtype: torch.Tensor + """ + # Contracting path + x_contracted = [] + x = self.inc(x) + for down in self.down: + x_contracted.append(x) + x = down(x) + + # Bottleneck + x = self.bottleneck(x) + + # Expanding path + for i, up in enumerate(self.up): + x = up(x, x_contracted[-(i + 1)]) + + # Final output + logits = self.outc(x) + return logits \ No newline at end of file diff --git a/models/unet_utils.py b/models/unet_utils.py new file mode 100644 index 0000000..e8db2e0 --- /dev/null +++ b/models/unet_utils.py @@ -0,0 +1,195 @@ +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" +Components of the U-Net model +""" + +class ConvBnRelu(nn.Module): + """ + A customizable convolutional block: (Convolution => [BN] => ReLU) * N. + + Allows specifying the number of layers and intermediate channels. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: Optional[List[int]] = None, + num_layers: int = 2): + """ + Initialize the customizable DoubleConv module for upsampling/downsampling the channels. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param mid_channels: List of intermediate channel numbers for each convolutional layer. + If unspecified, defaults to [out_channels] * (num_layers - 1). 
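+            For example (illustrative values), ConvBnRelu(64, 128, mid_channels=[96], num_layers=2)
+            applies convolutions with the channel progression 64 -> 96 -> 128.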
+ Order matters: mid_channels[0] corresponds to the first intermediate layer, etc. + :type mid_channels: Optional[List[int]] + :param num_layers: Number of convolutional layers in the block. + :type num_layers: int + """ + super().__init__() + + # Default intermediate channels if not specified + if mid_channels is None: + mid_channels = [out_channels] * (num_layers - 1) + + if len(mid_channels) != num_layers - 1: + raise ValueError("Length of mid_channels must be equal to num_layers - 1.") + + layers = [] + + # Add the first convolution layer + layers.append( + nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(mid_channels[0])) + layers.append(nn.ReLU(inplace=True)) + + # Add intermediate convolutional layers + for i in range(1, num_layers - 1): + layers.append( + nn.Conv2d(mid_channels[i - 1], mid_channels[i], kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(mid_channels[i])) + layers.append(nn.ReLU(inplace=True)) + + # Add the final convolution layer + layers.append( + nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.ReLU(inplace=True)) + + # Combine layers into a sequential block + self.conv_block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ConvBnRelu module. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Processed output tensor. + :rtype: torch.Tensor + """ + return self.conv_block(x) + +class Contract(nn.Module): + """Downscaling with maxpool then 2 * ConvBnRelu""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), # Halves spatial dimensions + ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2) # Refines features with 2 sequential convolutions + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + """Upscaling then 2 * ConvBnRelu""" + def __init__(self, + in_channels: int, + out_channels: int, + bilinear: bool=True): + """ + Up sampling module that combines the upsampled feature map with the skip connection. + Upsampling is done via bilinear interpolation or transposed convolution. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param bilinear: If True, use bilinear upsampling + :type bilinear: bool + """ + super().__init__() + + # If bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=[in_channels // 2], + num_layers=2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2 + ) + + def forward(self, + x1: torch.Tensor, + x2: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the Up module. + :param x1: Input tensor to be upsampled. + :type x1: torch.Tensor + :param x2: Skip connection tensor. + :type x2: torch.Tensor + :return: Processed output tensor. 
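+        :rtype: torch.Tensor
+
+        Note: x1 is zero-padded to the spatial size of x2 before concatenation, so inputs whose
+        sizes differ slightly after pooling are still concatenated cleanly.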
+ """ + x1 = self.up(x1) # Upsample x1 + + # Handle potential mismatches in spatial dimensions + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # Concatenate x1 (upsampled) with x2 (skip connection) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class OutConv(nn.Module): + """ + Final output layer that applies a 1x1 convolution followed by a sigmoid activation. + """ + def __init__(self, in_channels, out_channels, output_activation:str='sigmoid'): + """ + Initialize the OutConv module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param output_activation: Activation function to apply to the output. + :type output_activation: str + """ + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + if output_activation == 'sigmoid': + self.output_activation = torch.nn.Sigmoid() + elif output_activation == 'relu': + self.output_activation = torch.nn.ReLU() + elif output_activation == 'linear': + self.output_activation = torch.nn.Identity() + else: + raise ValueError('Invalid output_activation') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the OutConv module. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Processed output tensor. + :rtype: torch.Tensor + """ + return self.output_activation(self.conv(x)) \ No newline at end of file diff --git a/trainers/AbstractTrainer.py b/trainers/AbstractTrainer.py new file mode 100644 index 0000000..4fdb57d --- /dev/null +++ b/trainers/AbstractTrainer.py @@ -0,0 +1,460 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import List, Callable, Dict, Optional + +import torch +from torch.utils.data import DataLoader, random_split + +from ..metrics.AbstractMetrics import AbstractMetrics +from ..callbacks.AbstractCallback import AbstractCallback + + +class AbstractTrainer(ABC): + """ + Abstract trainer class for img2img translation models. + Provides shared dataset handling and modular callbacks for logging and evaluation. + """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + batch_size: int = 16, + epochs: int = 10, + patience: int = 5, + callbacks: List[AbstractCallback] = None, + metrics: Dict[str, AbstractMetrics] = None, + device: Optional[torch.device] = None, + early_termination_metric: str = None, + **kwargs, + ): + """ + :param dataset: The dataset to be used for training. + :type dataset: torch.utils.data.Dataset + :param batch_size: The batch size for training. + :type batch_size: int + :param epochs: The number of epochs for training. + :type epochs: int + :param patience: The number of epochs with no improvement after which training will be stopped. + :type patience: int + :param callbacks: List of callback functions to be executed + at the end of each epoch. + :type callbacks: list of callable + :param metrics: Dictionary of metrics to be logged. + :type metrics: dict + :param device: (optional) The device to be used for training. + :type device: torch.device + :param early_termination_metric: (optional) The metric to be tracked and used to update early + termination count on the validation dataset. If None, early termination is disabled and the + training will run for the specified number of epochs. 
+ :type early_termination_metric: str + """ + + self._batch_size = batch_size + self._epochs = epochs + self._patience = patience + self.initialize_callbacks(callbacks) + self._metrics = metrics if metrics else {} + + if isinstance(device, torch.device): + self._device = device + else: + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self._best_model = None + self._best_loss = float("inf") + self._early_stop_counter = 0 + self._early_termination_metric = early_termination_metric + self._early_termination = True if early_termination_metric else False + + # Customize data splits + self._train_ratio = kwargs.get("train", 0.7) + self._val_ratio = kwargs.get("val", 0.15) + self._test_ratio = kwargs.get("test", 1.0 - self._train_ratio - self._val_ratio) + + if not (0 < self._train_ratio + self._val_ratio + self._test_ratio <= 1.0): + raise ValueError("Data split ratios must sum to 1.0 or less.") + + train_size = int(self._train_ratio * len(dataset)) + val_size = int(self._val_ratio * len(dataset)) + test_size = len(dataset) - train_size - val_size + self._train_dataset, self._val_dataset, self._test_dataset = random_split( + dataset, [train_size, val_size, test_size] + ) + + # Create DataLoaders + self._train_loader = DataLoader( + self._train_dataset, batch_size=self._batch_size, shuffle=True + ) + self._val_loader = DataLoader( + self._val_dataset, batch_size=self._batch_size, shuffle=False + ) + + # Epoch counter + self._epoch = 0 + + # Loss and metrics storage + self._train_losses = defaultdict(list) + self._val_losses = defaultdict(list) + self._train_metrics = defaultdict(list) + self._val_metrics = defaultdict(list) + + @abstractmethod + def train_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + """ + Abstract method for training the model on one batch + Must be implemented by subclasses. + This should be where the losses and metrics are calculated. + Should return a dictionary with loss name as key and torch tensor loss as value. + + :param inputs: The input data. + :type inputs: torch.Tensor + :param targets: The target data. + :type targets: torch.Tensor + :return: A dictionary containing the loss values for the batch. + :rtype: dict[str, torch.Tensor] + """ + pass + + @abstractmethod + def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + """ + Abstract method for evaluating the model on one batch + Must be implemented by subclasses. + This should be where the losses and metrics are calculated. + Should return a dictionary with loss name as key and torch tensor loss as value. + + :param inputs: The input data. + :type inputs: torch.Tensor + :param targets: The target data. + :type targets: torch.Tensor + :return: A dictionary containing the loss values for the batch. + :rtype: dict[str, torch.Tensor] + """ + pass + + @abstractmethod + def train_epoch(self)->dict[str, torch.Tensor]: + """ + Can be overridden by subclasses to implement custom training logic. + Make calls to the train_step method for each batch + in the training DataLoader. + + Return a dictionary with loss name as key and + torch tensor loss as value. Multiple losses can be returned. + + :return: A dictionary containing the loss values for the epoch. + :rtype: dict[str, torch.Tensor] + """ + + pass + + @abstractmethod + def evaluate_epoch(self)->dict[str, torch.Tensor]: + """ + Can be overridden by subclasses to implement custom evaluation logic. 
+ Should make calls to the evaluate_step method for each batch + in the validation DataLoader. + + Should return a dictionary with loss name as key and + torch tensor loss as value. Multiple losses can be returned. + + :return: A dictionary containing the loss values for the epoch. + :rtype: dict[str, torch.Tensor] + """ + + pass + + def train(self): + """ + Train the model for the specified number of epochs. + Make calls to the train epoch and evaluate epoch methods. + """ + + self.model.to(self.device) + + # callbacks + for callback in self.callbacks: + callback.on_train_start() + + for epoch in range(self.epochs): + + # Increment the epoch counter + self.epoch += 1 + + # callbacks + for callback in self.callbacks: + callback.on_epoch_start() + + # Access all the metrics and reset them + for _, metric in self.metrics.items(): + metric.reset() + + # Train the model for one epoch + train_loss = self.train_epoch() + for loss_name, loss in train_loss.items(): + self._train_losses[loss_name].append(loss) + + # Evaluate the model for one epoch + val_loss = self.evaluate_epoch() + for loss_name, loss in val_loss.items(): + self._val_losses[loss_name].append(loss) + + # Access all the metrics and compute the final epoch metric value + for metric_name, metric in self.metrics.items(): + train_metric, val_metric = metric.compute() + self._train_metrics[metric_name].append(train_metric.item()) + self._val_metrics[metric_name].append(val_metric.item()) + + # Invoke callback on epoch_end + for callback in self.callbacks: + callback.on_epoch_end() + + # Update early stopping + if self._early_termination_metric is None: + # Do not perform early stopping when no termination metric is specified + early_term_metric = None + else: + # First look for the metric in validation loss + if self._early_termination_metric in list(val_loss.keys()): + early_term_metric = val_loss[self._early_termination_metric] + # Then look for the metric in validation metrics + elif self._early_termination_metric in list(self._val_metrics.keys()): + early_term_metric = self._val_metrics[self._early_termination_metric][-1] + else: + raise ValueError("Invalid early termination metric") + + self.update_early_stop(early_term_metric) + + # Check if early stopping is needed + if self._early_termination and self.early_stop_counter >= self.patience: + print(f"Early termination at epoch {epoch + 1} with best validation metric {self._best_loss}") + break + + for callback in self.callbacks: + callback.on_train_end() + + def update_early_stop(self, val_loss: Optional[torch.Tensor]): + """ + Method to update the early stopping criterion + + :param val_loss: The loss value on the validation set + :type val_loss: torch.Tensor + """ + + # When early termination is disabled, the best model is updated with the current model + if not self._early_termination and val_loss is None: + self.best_model = self.model.state_dict().copy() + return + + if val_loss < self.best_loss: + self.best_loss = val_loss + self.early_stop_counter = 0 + self.best_model = self.model.state_dict().copy() + else: + self.early_stop_counter += 1 + + def initialize_callbacks(self, callbacks): + """ + Helper to iterate over all callbacks and set trainer property + + :param callbacks: List of callback objects that can be invoked + at epcoh start, epoch end, train start and train end + :type callbacks: Callback class or subclass or list of Callback class + """ + + if callbacks is None: + self._callbacks = [] + return + + if not isinstance(callbacks, List): + callbacks = [callbacks] + 
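+        # Validate each callback and hand it a reference to this trainer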
for callback in callbacks: + if not isinstance(callback, AbstractCallback): + raise TypeError("Invalid callback object type") + callback._set_trainer(self) + + self._callbacks = callbacks + + """ + Log property + """ + @property + def log(self): + """ + Returns the training and validation losses and metrics. + """ + log ={ + **{'epoch': list(range(1, self.epoch + 1))}, + **self._train_losses, + **{f'val_{key}': val for key, val in self._val_losses.items()}, + **self._train_metrics, + **{f'val_{key}': val for key, val in self._val_metrics.items()} + } + + return log + + """ + Properties for accessing various attributes of the trainer. + """ + @property + def train_ratio(self): + return self._train_ratio + + @property + def val_ratio(self): + return self._val_ratio + + @property + def test_ratio(self): + return self._test_ratio + + @property + def model(self): + return self._model + + @property + def optimizer(self): + return self._optimizer + + @property + def device(self): + return self._device + + @property + def batch_size(self): + return self._batch_size + + @property + def epochs(self): + return self._epochs + + @property + def patience(self): + return self._patience + + @property + def callbacks(self): + return self._callbacks + + @property + def best_model(self): + return self._best_model + + @property + def best_loss(self): + return self._best_loss + + @property + def early_stop_counter(self): + return self._early_stop_counter + + @property + def metrics(self): + return self._metrics + + @property + def epoch(self): + return self._epoch + + @property + def train_losses(self): + return self._train_losses + + @property + def val_losses(self): + return self._val_losses + + @property + def train_metrics(self): + return self._train_metrics + + @property + def val_metrics(self): + return self._val_metrics + + """ + Setters for best model and best loss and early stop counter + Meant to be used by the subclasses to update the best model and loss + """ + + @best_model.setter + def best_model(self, value: torch.nn.Module): + self._best_model = value + + @best_loss.setter + def best_loss(self, value): + self._best_loss = value + + @early_stop_counter.setter + def early_stop_counter(self, value: int): + self._early_stop_counter = value + + @epoch.setter + def epoch(self, value: int): + self._epoch = value + + """ + Update loss and metrics + """ + + def update_loss(self, + loss: torch.Tensor, + loss_name: str, + validation: bool = False): + if validation: + self._val_losses[loss_name].append(loss) + else: + self._train_losses[loss_name].append(loss) + + def update_metrics(self, + metric: torch.tensor, + metric_name: str, + validation: bool = False): + if validation: + self._val_metrics[metric_name].append(metric) + else: + self._train_metrics[metric_name].append(metric) + + """ + Properties for accessing the split datasets. 
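+    The train/val DataLoaders are built during initialization; the test DataLoader is created on demand.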
+ """ + @property + def train_dataset(self, loader=False): + """ + Returns the training dataset or DataLoader if loader=True + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return self._train_loader + else: + return self._train_dataset + + @property + def val_dataset(self, loader=False): + """ + Returns the validation dataset or DataLoader if loader=True + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return self._val_loader + else: + return self._val_dataset + + @property + def test_dataset(self, loader=False): + """ + Returns the test dataset or DataLoader if loader=True + Generates the DataLoader on the fly as the test data loader is not + pre-defined during object initialization + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return DataLoader(self._test_dataset, batch_size=self._batch_size, shuffle=False) + else: + return self._test_dataset diff --git a/trainers/README.md b/trainers/README.md new file mode 100644 index 0000000..06173b4 --- /dev/null +++ b/trainers/README.md @@ -0,0 +1 @@ +Here lives the trainer class for FNet/UNet and wGaN GP. Shared components between trainers for different models are isolated into the asbstract trainer class \ No newline at end of file diff --git a/trainers/Trainer.py b/trainers/Trainer.py new file mode 100644 index 0000000..5c6411c --- /dev/null +++ b/trainers/Trainer.py @@ -0,0 +1,144 @@ +from collections import defaultdict +from typing import Optional, List, Union + +import torch +from torch.utils.data import DataLoader, random_split + +from .AbstractTrainer import AbstractTrainer + +class Trainer(AbstractTrainer): + """ + Trainer class for generator while backpropagating on single or multiple loss functions. + """ + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + backprop_loss: Union[torch.nn.Module, List[torch.nn.Module]], + **kwargs + ): + """ + Initialize the trainer with the model, optimizer and loss function. + + :param model: The model to be trained. + :type model: torch.nn.Module + :param optimizer: The optimizer to be used for training. + :type optimizer: torch.optim.Optimizer + :param backprop_loss: The loss function to be used for training or a list of loss functions. + :type backprop_loss: torch.nn.Module + """ + + super().__init__(**kwargs) + + self._model = model + self._optimizer = optimizer + self._backprop_loss = backprop_loss \ + if isinstance(backprop_loss, list) else [backprop_loss] + + """ + Overidden methods from the parent abstract class + """ + def train_step(self, inputs: torch.tensor, targets: torch.tensor): + """ + Perform a single training step on batch. 
+ + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + # move the data to the device + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # set the model to train + self.model.train() + # set the optimizer gradients to zero + self.optimizer.zero_grad() + + # Forward pass + outputs = self.model(inputs) + + # Back propagate the loss + losses = {} + total_loss = torch.tensor(0.0, device=self.device) + for loss in self._backprop_loss: + losses[type(loss).__name__] = loss(outputs, targets) + total_loss += losses[type(loss).__name__] + + total_loss.backward() + self.optimizer.step() + + # Calculate the metrics outputs and update the metrics + for _, metric in self.metrics.items(): + metric.update(outputs, targets, validation=False) + + return { + key: value.item() for key, value in losses.items() + } + + def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor): + """ + Perform a single evaluation step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + # move the data to the device + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # set the model to evaluation + self.model.eval() + + with torch.no_grad(): + # Forward pass + outputs = self.model(inputs) + + # calculate the loss + losses = {} + for loss in self._backprop_loss: + losses[type(loss).__name__] = loss(outputs, targets) + + # Calculate the metrics outputs and update the metrics + for _, metric in self.metrics.items(): + metric.update(outputs, targets, validation=True) + + return { + key: value.item() for key, value in losses.items() + } + + def train_epoch(self): + """ + Train the model for one epoch. + """ + self._model.train() + losses = defaultdict(list) + # Iterate over the train_loader + for inputs, targets in self._train_loader: + batch_loss = self.train_step(inputs, targets) + for key, value in batch_loss.items(): + losses[key].append(value) + + # reduce loss + return { + key: sum(value) / len(value) for key, value in losses.items() + } + + def evaluate_epoch(self): + """ + Evaluate the model for one epoch. 
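+        Each returned loss is averaged over all validation batches.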
+ """ + + self._model.eval() + losses = defaultdict(list) + # Iterate over the val_loader + for inputs, targets in self._val_loader: + batch_loss = self.evaluate_step(inputs, targets) + for key, value in batch_loss.items(): + losses[key].append(value) + + # reduce loss + return { + key: sum(value) / len(value) for key, value in losses.items() + } \ No newline at end of file diff --git a/trainers/WGANTrainer.py b/trainers/WGANTrainer.py new file mode 100644 index 0000000..909617a --- /dev/null +++ b/trainers/WGANTrainer.py @@ -0,0 +1,268 @@ +from typing import Optional +from collections import defaultdict + +import torch +import torch.autograd as autograd +from torch.utils.data import DataLoader + +from .AbstractTrainer import AbstractTrainer + +class WGANTrainer(AbstractTrainer): + def __init__(self, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + gen_optimizer: torch.optim.Optimizer, + disc_optimizer: torch.optim.Optimizer, + generator_loss_fn: torch.nn.Module, + discriminator_loss_fn: torch.nn.Module, + gradient_penalty_fn: Optional[torch.nn.Module]=None, + discriminator_update_freq: int=1, + generator_update_freq: int=5, + # rest of the arguments are passed to and handled by the parent class + # - dataset + # - batch_size + # - epochs + # - patience + # - callbacks + # - metrics + **kwargs): + """ + Initializes the WGaN Trainer class. + + :param generator: The image2image generator model (e.g., UNet) + :type generator: torch.nn.Module + :param discriminator: The discriminator model + :type discriminator: torch.nn.Module + :param gen_optimizer: Generator optimizer + :type gen_optimizer: torch.optim.Optimizer + :param disc_optimizer: Discriminator optimizer + :type disc_optimizer: torch.optim.Optimizer + :param generator_loss_fn: Generator loss function + :type generator_loss_fn: torch.nn.Module + :param discriminator_loss_fn: Adverserial loss function + :type discriminator_loss_fn: torch.nn.Module + :param gradient_penalty_fn: (optional) Gradient penalty loss function + :type gradient_penalty_fn: torch.nn.Module + :param discriminator_update_freq: How frequently to update the discriminator + :type discriminator_update_freq: int + :param generator_update_freq: How frequently to update the generator + :type generator_update_freq: int + :param kwargs: Additional arguments passed to the AbstractTrainer + :type kwargs: dict + """ + super().__init__(**kwargs) + + # Validate update frequencies + if discriminator_update_freq > 1 and generator_update_freq > 1: + raise ValueError( + "Both discriminator_update_freq and generator_update_freq cannot be greater than 1. " + "At least one network must update every epoch." 
+ ) + + self._generator = generator + self._discriminator = discriminator + self._gen_optimizer = gen_optimizer + self._disc_optimizer = disc_optimizer + self._generator_loss_fn = generator_loss_fn + self._generator_loss_fn.trainer = self + self._discriminator_loss_fn = discriminator_loss_fn + self._discriminator_loss_fn.trainer = self + self._gradient_penalty_fn = gradient_penalty_fn + if self._gradient_penalty_fn is not None: + self._gradient_penalty_fn.trainer = self + + # Global step counter and update frequencies + self._discriminator_update_freq = discriminator_update_freq + self._generator_update_freq = generator_update_freq + + # Initialize the last losses memory to zero + self._last_discriminator_loss = torch.tensor(0.0, device=self.device).detach() + self._last_gradient_penalty_loss = torch.tensor(0.0, device=self.device).detach() + self._last_generator_loss = torch.tensor(0.0, device=self.device).detach() + + def train_step(self, + inputs: torch.tensor, + targets: torch.tensor + ): + """ + Perform a single training step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + inputs, targets = inputs.to(self.device), targets.to(self.device) + + gp_loss = torch.tensor(0.0, device=self.device) + + # foward pass to generate image (shared by both updates) + generated_images = self._generator(inputs) + + # Train Discriminator + if self.epoch % self._discriminator_update_freq == 0: + self._disc_optimizer.zero_grad() + + real_images = targets + + # Concatenate input channel and real/generated image channels along the + # channel dimension to feed full stacked multi-channel images to the discriminator + real_input_pair = torch.cat((real_images, inputs), 1) + generated_input_pair = torch.cat((generated_images.detach(), inputs), 1) + + discriminator_real_score = self._discriminator(real_input_pair).mean() + discriminator_fake_score = self._discriminator(generated_input_pair).mean() + + # Adverserial loss + discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score) + + # Compute Gradient penalty loss if fn is supplied + if self._gradient_penalty_fn is not None: + gp_loss = self._gradient_penalty_fn(real_input_pair, generated_input_pair) + + total_discriminator_loss = discriminator_loss + gp_loss + total_discriminator_loss.backward() + self._disc_optimizer.step() + + # memorize current discriminator loss until next discriminator update + self._last_discriminator_loss = discriminator_loss.detach() + self._last_gradient_penalty_loss = gp_loss + else: + # when not being updated, use the loss from previus update + discriminator_loss = self._last_discriminator_loss + gp_loss = self._last_gradient_penalty_loss + + # Train Generator + if self.epoch % self._generator_update_freq == 0: + self._gen_optimizer.zero_grad() + + discriminator_fake_score = self._discriminator(torch.cat((generated_images, inputs), 1)).mean() + generator_loss = self._generator_loss_fn(discriminator_fake_score, real_images, generated_images, self.epoch) + generator_loss.backward() + self._gen_optimizer.step() + + # memorize current generator loss until next generator update + self._last_generator_loss = generator_loss.detach() + else: + # when not being updated, set the loss to zero + generator_loss = self._last_generator_loss + + for _, metric in self.metrics.items(): + ## TODO: centralize the update of metrics + # compute the generated fake targets regardless for use with 
metrics + generated_images = self._generator(inputs).detach() + metric.update(generated_images, targets, validation=False) + ## After each batch -> after each epoch + + loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(), + type(self._generator_loss_fn).__name__: generator_loss.item()} + if self._gradient_penalty_fn is not None: + loss = { + **loss, + **{type(self._gradient_penalty_fn).__name__: gp_loss.item()} + } + + return loss + + def evaluate_step(self, + inputs: torch.tensor, + targets: torch.tensor + ): + """ + Perform a single evaluation step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + inputs, targets = inputs.to(self.device), targets.to(self.device) + + self._generator.eval() + self._discriminator.eval() + with torch.no_grad(): + + real_images = targets + generated_images = self._generator(inputs) + + # Concatenate input channel and real/generated image channels along the + # channel dimension to feed full stacked multi-channel images to the discriminator + real_input_pair = torch.cat((real_images, inputs), 1) + generated_input_pair = torch.cat((generated_images, inputs), 1) + + discriminator_real_score = self._discriminator(real_input_pair).mean() + discriminator_fake_score = self._discriminator(generated_input_pair).mean() + + # Compute losses + discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score) + + ## Declare an empty tensor for the gradient penalty loss as + # it is not useful during evaluation + gp_loss = torch.tensor(0.0, device=self.device) + + generator_loss = self._generator_loss_fn(discriminator_fake_score, generated_images, real_images, self.epoch) + + for _, metric in self.metrics.items(): + metric.update(generated_images, targets, validation=True) + + loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(), + type(self._generator_loss_fn).__name__: generator_loss.item()} + if self._gradient_penalty_fn is not None: + loss = { + **loss, + **{type(self._gradient_penalty_fn).__name__: gp_loss.item()} + } + + return loss + + def train_epoch(self): + + self._generator.train() + self._discriminator.train() + + epoch_losses = defaultdict(list) + for inputs, targets in self._train_loader: + losses = self.train_step(inputs, targets) + for key, value in losses.items(): + epoch_losses[key].append(value) + + for key, _ in epoch_losses.items(): + epoch_losses[key] = sum(epoch_losses[key])/len(self._train_loader) + + return epoch_losses + + def evaluate_epoch(self): + + self._generator.eval() + self._discriminator.eval() + + epoch_losses = defaultdict(list) + for inputs, targets in self._val_loader: + losses = self.evaluate_step(inputs, targets) + for key, value in losses.items(): + epoch_losses[key].append(value) + + for key, _ in epoch_losses.items(): + epoch_losses[key] = sum(epoch_losses[key])/len(self._val_loader) + + return epoch_losses + + def train(self): + + self._discriminator.to(self.device) + + super().train() + + @property + def model(self) -> torch.nn.Module: + """ + return the generator + """ + return self._generator + + @property + def discriminator(self) -> torch.nn.Module: + """ + returns the discriminator + """ + return self._discriminator \ No newline at end of file diff --git a/transforms/MinMaxNormalize.py b/transforms/MinMaxNormalize.py new file mode 100644 index 0000000..254f9a2 --- /dev/null +++ b/transforms/MinMaxNormalize.py @@ -0,0 
+1,60 @@ +from albumentations import ImageOnlyTransform +import numpy as np + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class MinMaxNormalize(ImageOnlyTransform): + """Min-Max normalize each image""" + + def __init__(self, + _normalization_factor: float, + _always_apply: bool=False, + _p: float=0.5): + """ + Initializes the MinMaxNormalize transform. + + :param _normalization_factor: The factor by which to normalize the image. + :type _normalization_factor: float + :param _always_apply: If True, always apply this transformation. + :type _always_apply: bool + :param _p: Probability of applying this transformation. + :type _p: float + """ + super(MinMaxNormalize, self).__init__(_always_apply, _p) + self.__normalization_factor = _normalization_factor + + @property + def normalization_factor(self): + return self.__normalization_factor + + def apply(self, _img, **kwargs): + """ + Apply min-max normalization to the image. + + :param _img: Input image as a numpy array. + :type _img: np.ndarray + :return: Min-max normalized image. + :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + """ + if isinstance(_img, np.ndarray): + return _img / self.normalization_factor + + else: + raise TypeError("Unsupported image type for transform (Should be a numpy array)") + + def invert(self, _img, **kwargs): + """ + Invert the min-max normalization. + + :param _img: Input image as a numpy array. + :type _img: np.ndarray + :return: Inverted image. + :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + """ + if isinstance(_img, np.ndarray): + return _img * self.normalization_factor + else: + raise TypeError("Unsupported image type for transform (Should be a numpy array)") diff --git a/transforms/PixelDepthTransform.py b/transforms/PixelDepthTransform.py new file mode 100644 index 0000000..29ef444 --- /dev/null +++ b/transforms/PixelDepthTransform.py @@ -0,0 +1,85 @@ +from albumentations import ImageOnlyTransform +import numpy as np + +class PixelDepthTransform(ImageOnlyTransform): + """ + Transform to convert images from a specified bit depth to another bit depth (e.g., 16-bit to 8-bit). + Automatically scales pixel values up or down to the target bit depth. + The only supported bit depths are 8, 16, and 32. + """ + + def __init__(self, + src_bit_depth: int = 16, + target_bit_depth: int = 8, + _always_apply: bool = True, + _p: float = 1.0): + """ + Initializes the PixelDepthTransform. + + :param src_bit_depth: Bit depth of the input image (e.g., 16 for 16-bit). + :type src_bit_depth: int + :param target_bit_depth: Bit depth to scale the image to (e.g., 8 for 8-bit). + :type target_bit_depth: int + :param _always_apply: Whether to always apply the transform. + :type _always_apply: bool + :param _p: Probability of applying the transform. + :type _p: float + :raises ValueError: If the source or target bit depth is not supported. + """ + if src_bit_depth not in [8, 16, 32]: + raise ValueError("Unsupported source bit depth (should be 8 or 16)") + if target_bit_depth not in [8, 16, 32]: + raise ValueError("Unsupported target bit depth (should be 8 or 16)") + + super(PixelDepthTransform, self).__init__(_always_apply, _p) + self.src_bit_depth = src_bit_depth + self.target_bit_depth = target_bit_depth + + def apply(self, img, **kwargs): + """ + Apply the bit depth transformation. + + :param img: Input image as a numpy array. + :type img: np.ndarray + :return: Transformed image scaled to the target bit depth. 
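+            For example (illustrative), a 16-bit value of 65535 maps to 255 when converting to 8-bit.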
+ :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + """ + if not isinstance(img, np.ndarray): + raise TypeError("Unsupported image type for transform (should be a numpy array)") + + # Maximum pixel value based on source and target bit depth + src_max_val = (2 ** self.src_bit_depth) - 1 + target_max_val = (2 ** self.target_bit_depth) - 1 + + if self.target_bit_depth == 32: + # Scale to the 32-bit integer range + return ((img / src_max_val) * target_max_val).astype(np.uint32) + else: + # Standard conversion for 8-bit or 16-bit integers + return ((img / src_max_val) * target_max_val).astype( + np.uint8 if self.target_bit_depth == 8 else np.uint16 + ) + + def invert(self, img, **kwargs): + """ + Optionally invert the bit depth transformation (useful for debugging or preprocessing). + + :param img: Transformed image as a numpy array. + :type img: np.ndarray + :return: Image restored to the original bit depth. + :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + """ + if not isinstance(img, np.ndarray): + raise TypeError("Unsupported image type for inversion (should be a numpy array)") + + target_max_val = (2 ** self.target_bit_depth) - 1 + src_max_val = (2 ** self.src_bit_depth) - 1 + + # Invert scaling back to original bit depth + img = (img / target_max_val) * src_max_val + return img.astype(np.uint16) if self.src_bit_depth == 16 else img + + def __repr__(self): + return f"PixelDepthTransform(src_bit_depth={self.src_bit_depth}, target_bit_depth={self.target_bit_depth})" \ No newline at end of file diff --git a/transforms/README.md b/transforms/README.md new file mode 100644 index 0000000..d3a5f1c --- /dev/null +++ b/transforms/README.md @@ -0,0 +1,3 @@ +Here lives the image transform class. + +For now all of these transforms are just used as input normalization. They all have an invert function to faciliate visualization. \ No newline at end of file diff --git a/transforms/ZScoreNormalize.py b/transforms/ZScoreNormalize.py new file mode 100644 index 0000000..cf91acb --- /dev/null +++ b/transforms/ZScoreNormalize.py @@ -0,0 +1,87 @@ +from albumentations import ImageOnlyTransform +import numpy as np + +""" +Wrote this to get z score normalizae to work with albumentations +""" +class ZScoreNormalize(ImageOnlyTransform): + """Z-score normalize each image""" + + def __init__(self, _mean=None, _std=None, _always_apply=False, _p=0.5): + """ + Initializes the ZScoreNormalize transform. + + :param _mean: Precomputed mean for normalization (optional). If None, compute per-image mean. + :type _mean: float, optional + :param _std: Precomputed standard deviation for normalization (optional). If None, compute per-image std. + :type _std: float, optional + :param _always_apply: If True, always apply this transformation. + :type _always_apply: bool + :param _p: Probability of applying this transformation. + :type _p: float + """ + super(ZScoreNormalize, self).__init__(_always_apply, _p) + self.__mean = _mean + self.__std = _std + + @property + def mean(self): + return self.__mean + + @property + def std(self): + return self.__std + + def apply(self, _img, **params): + """ + Apply z-score normalization to the image. + + :param _img: Input image as a numpy array. + :type _img: np.ndarray + :return: Z-score normalized image. + :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + :raises ValueError: If the standard deviation is zero. 
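+
+        If no global mean/std were supplied at construction, per-image statistics are used.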
+ """ + if not isinstance(_img, np.ndarray): + raise TypeError("Unsupported image type for transform (Should be a numpy array)") + + mean = self.__mean if self.__mean is not None else _img.mean() + std = self.__std if self.__std is not None else _img.std() + + if std == 0: + raise ValueError("Standard deviation is zero; cannot perform z-score normalization.") + + return (_img - mean) / std + + def invert(self, _img, **kwargs): + """ + Invert the z-score normalization. + If this transform is applied on image basis (without global mean and std) + Will simply return the z score transformed image back + + :param _img: Input image as a numpy array. + :type _img: np.ndarray + :return: Inverted image. + :rtype: np.ndarray + :raises TypeError: If the input image is not a numpy array. + :raises ValueError: If the standard deviation is zero. + """ + if not isinstance(_img, np.ndarray): + raise TypeError("Unsupported image type for transform (Should be a numpy array)") + + if self.__mean is None or self.__std is None: + mean = kwargs.get("mean", None) + std = kwargs.get("std", None) + if mean is None or std is None: + return _img + else: + return (_img * std) + mean + else: + mean = self.__mean if self.__mean is not None else _img.mean() + std = self.__std if self.__std is not None else _img.std() + + if std == 0: + raise ValueError("Standard deviation is zero; cannot perform z-score normalization.") + + return (_img * std) + mean