diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2201fd5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+# images and anything under mlflow
+*.png
+examples/example_train*/*
+
+# pycache
+*.pyc
\ No newline at end of file
diff --git a/callbacks/AbstractCallback.py b/callbacks/AbstractCallback.py
new file mode 100644
index 0000000..758d53e
--- /dev/null
+++ b/callbacks/AbstractCallback.py
@@ -0,0 +1,62 @@
+from abc import ABC
+
+class AbstractCallback(ABC):
+ """
+ Abstract class for callbacks in the training process.
+ Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc.
+ """
+
+ def __init__(self, name: str):
+ """
+ :param name: Name of the callback.
+ """
+ self._name = name
+ self._trainer = None
+
+ @property
+ def name(self):
+ """
+ Getter for callback name
+ """
+ return self._name
+
+ @property
+ def trainer(self):
+ """
+        Read-only access to the trainer this callback is attached to.
+ """
+ return self._trainer
+
+ def _set_trainer(self, trainer):
+ """
+        Helper method called by the trainer to attach itself to this callback.
+
+ :param trainer: trainer object
+ :type trainer: AbstractTrainer or subclass
+ """
+
+ self._trainer = trainer
+
+ def on_train_start(self):
+ """
+ Called at the start of training.
+ """
+ pass
+
+ def on_epoch_start(self):
+ """
+ Called at the start of each epoch.
+ """
+ pass
+
+ def on_epoch_end(self):
+ """
+ Called at the end of each epoch.
+ """
+ pass
+
+ def on_train_end(self):
+ """
+ Called at the end of training.
+ """
+ pass
\ No newline at end of file
diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py
new file mode 100644
index 0000000..ad2c9ea
--- /dev/null
+++ b/callbacks/IntermediatePlot.py
@@ -0,0 +1,117 @@
+import pathlib
+from typing import List, Union
+import random
+
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+
+from .AbstractCallback import AbstractCallback
+from ..datasets.PatchDataset import PatchDataset
+from ..evaluation.visualization_utils import plot_predictions_grid_from_model
+
+class IntermediatePlot(AbstractCallback):
+ """
+    Callback to plot model-generated outputs, ground truth, and input stained image
+    patches every ``every_n_epochs`` epochs and at the end of training.
+ """
+
+ def __init__(self,
+ name: str,
+ path: Union[pathlib.Path, str],
+ dataset: Union[Dataset, PatchDataset],
+ plot_n_patches: int=5,
+ indices: Union[List[int], None]=None,
+ plot_metrics: List[nn.Module]=None,
+ every_n_epochs: int=5,
+ random_seed: int=42,
+ **kwargs):
+ """
+ Initialize the IntermediatePlot callback.
+ Allows plots of predictions to be generated during training for monitoring of training progress.
+ Supports both PatchDataset and Dataset classes for plotting.
+        This callback, when passed to the trainer, plots model predictions on a subset of the provided dataset every ``every_n_epochs`` epochs and at the end of training.
+
+ :param name: Name of the callback.
+        :param path: Directory in which the generated plots are saved.
+ :type path: Union[pathlib.Path, str]
+ :param dataset: Dataset to be used for plotting intermediate results.
+ :type dataset: Union[Dataset, PatchDataset]
+ :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5.
+            The exact patches/images plotted can vary with the random seed and dataset size.
+            For full reproducibility, supply a fixed dataset and the ``indices`` argument instead.
+ :type plot_n_patches: int, optional
+ :param indices: Optional list of specific indices to subset the dataset before inference.
+ Overrides the plot_n_patches and random_seed arguments and uses the indices list to subset.
+ :type indices: Union[List[int], None]
+ :param plot_metrics: List of metrics to compute and display in plot title, defaults to None.
+ :type plot_metrics: List[nn.Module], optional
+ :param kwargs: Additional keyword arguments to be passed to plot_patches.
+ :type kwargs: dict
+        :param every_n_epochs: How frequently intermediate plots should be generated, defaults to 5.
+ :type every_n_epochs: int
+ :param random_seed: Random seed for reproducibility for random patch/image selection, defaults to 42.
+ :type random_seed: int
+        :raises TypeError: If the dataset is not an instance of Dataset (or PatchDataset).
+ """
+ super().__init__(name)
+ self._path = path
+        if not isinstance(dataset, Dataset):
+            raise TypeError(f"Expected a Dataset or PatchDataset, got {type(dataset)}")
+
+ self._dataset = dataset
+
+ # Additional kwargs passed to plot_patches
+ self.plot_metrics = plot_metrics
+ self.every_n_epochs = every_n_epochs
+ self.plot_kwargs = kwargs
+
+ if indices is not None:
+ # Check if indices are within bounds
+ for i in indices:
+ if i >= len(self._dataset):
+ raise ValueError(f"Index {i} out of bounds for dataset of size {len(self._dataset)}")
+ self._dataset_subset_indices = indices
+ else:
+ # Generate random indices to subset given seed and plot_n_patches
+ plot_n_patches = min(plot_n_patches, len(self._dataset))
+ random.seed(random_seed)
+ self._dataset_subset_indices = random.sample(range(len(self._dataset)), plot_n_patches)
+
+ def on_epoch_end(self):
+ """
+ Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
+ """
+        if (self.trainer.epoch + 1) % self.every_n_epochs == 0:
+ self._plot()
+
+ def on_train_end(self):
+ """
+ Called at the end of training. Plots if not already done in the last epoch.
+ """
+ if (self.trainer.epoch + 1) % self.every_n_epochs != 0:
+ self._plot()
+
+ def _plot(self):
+ """
+ Helper method to generate and save plots.
+        Plots model predictions on the selected subset of the dataset and saves the figure.
+        Called by the on_epoch_end and on_train_end methods.
+ """
+
+ original_device = next(self.trainer.model.parameters()).device
+
+ plot_predictions_grid_from_model(
+ model=self.trainer.model,
+ dataset=self._dataset,
+ indices=self._dataset_subset_indices,
+ metrics=self.plot_metrics,
+ save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
+ device=original_device,
+ show=False,
+ **self.plot_kwargs
+ )
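+
+# Usage sketch (hypothetical names; assumes a trainer implementation that accepts a list
+# of callbacks and exposes `model` and `epoch`, as this callback expects):
+#
+#   plot_cb = IntermediatePlot(
+#       name="intermediate_plot",
+#       path="example_train/plots",
+#       dataset=val_dataset,
+#       plot_n_patches=5,
+#       every_n_epochs=10,
+#   )
+#   trainer = MyTrainer(model, ..., callbacks=[plot_cb])
+#   trainer.train()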
\ No newline at end of file
diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py
new file mode 100644
index 0000000..7b7e1a4
--- /dev/null
+++ b/callbacks/MlflowLogger.py
@@ -0,0 +1,115 @@
+import os
+import pathlib
+import tempfile
+from typing import Union, Dict, Optional
+
+import mlflow
+import torch
+
+from .AbstractCallback import AbstractCallback
+
+class MlflowLogger(AbstractCallback):
+ """
+ Callback to log metrics to MLflow.
+ """
+
+    def __init__(self,
+                 name: str,
+                 artifact_name: str = 'best_model_weights.pth',
+                 mlflow_uri: Optional[Union[pathlib.Path, str]] = None,
+                 mlflow_experiment_name: Optional[str] = None,
+                 mlflow_start_run_args: Optional[dict] = None,
+                 mlflow_log_params_args: Optional[dict] = None,
+                 ):
+ """
+ Initialize the MlflowLogger callback.
+
+ :param name: Name of the callback.
+ :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'.
+ :param mlflow_uri: URI for the MLflow tracking server, defaults to None.
+            If a URI or path is supplied, the logger calls ``mlflow.set_tracking_uri`` with it.
+            If None (default), the tracking URI is left untouched so that runs are logged to
+            whatever MLflow server is configured globally outside of this class.
+ :type mlflow_uri: pathlib.Path or str, optional
+ :param mlflow_experiment_name: Name of the MLflow experiment, defaults to None, which will not call the
+ set_experiment method of mlflow and will use whichever experiment name that is globally configured. If a
+ name is provided, the logger class will call set_experiment to that supplied name.
+ :type mlflow_experiment_name: str, optional
+ :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None.
+ :type mlflow_start_run_args: dict, optional
+ :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to None.
+ :type mlflow_log_params_args: dict, optional
+ """
+ super().__init__(name)
+
+ if mlflow_uri is not None:
+ try:
+ mlflow.set_tracking_uri(mlflow_uri)
+ except Exception as e:
+ raise RuntimeError(f"Error setting MLflow tracking URI: {e}")
+
+ if mlflow_experiment_name is not None:
+ try:
+ mlflow.set_experiment(mlflow_experiment_name)
+ except Exception as e:
+ raise RuntimeError(f"Error setting MLflow experiment: {e}")
+
+ self._artifact_name = artifact_name
+ self._mlflow_start_run_args = mlflow_start_run_args
+ self._mlflow_log_params_args = mlflow_log_params_args
+
+ def on_train_start(self):
+ """
+ Called at the start of training.
+
+ Calls mlflow start run and logs params if provided
+ """
+
+ if self._mlflow_start_run_args is None:
+ pass
+ elif isinstance(self._mlflow_start_run_args, Dict):
+ mlflow.start_run(
+ **self._mlflow_start_run_args
+ )
+ else:
+ raise TypeError("mlflow_start_run_args must be None or a dictionary.")
+
+ if self._mlflow_log_params_args is None:
+ pass
+ elif isinstance(self._mlflow_log_params_args, Dict):
+ mlflow.log_params(
+ self._mlflow_log_params_args
+ )
+ else:
+ raise TypeError("mlflow_log_params_args must be None or a dictionary.")
+
+ def on_epoch_end(self):
+ """
+ Called at the end of each epoch.
+
+ Iterate over the most recent log items in trainer and call mlflow log metric
+ """
+        for key, values in self.trainer.log.items():
+            if values is not None and len(values) > 0:
+                mlflow.log_metric(key, values[-1], step=self.trainer.epoch)
+
+ def on_train_end(self):
+ """
+ Called at the end of training.
+
+        Saves the trainer's best model to a temporary directory, logs it as an MLflow artifact,
+        and then ends the run.
+ """
+ # Save weights to a temporary directory and log artifacts
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ weights_path = os.path.join(tmpdirname, self._artifact_name)
+ torch.save(self.trainer.best_model, weights_path)
+ mlflow.log_artifact(weights_path, artifact_path="models")
+
+ mlflow.end_run()
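+
+# Usage sketch (hypothetical URI, experiment, and trainer names; assumes a trainer that
+# accepts callbacks and exposes `log`, `epoch`, and `best_model`, as this callback expects):
+#
+#   logger_cb = MlflowLogger(
+#       name="mlflow_logger",
+#       mlflow_uri="./mlruns",
+#       mlflow_experiment_name="cp_gan_example",
+#       mlflow_start_run_args={"run_name": "baseline"},
+#       mlflow_log_params_args={"lr": 1e-4, "batch_size": 8},
+#   )
+#   trainer = MyTrainer(model, ..., callbacks=[logger_cb])
+#   trainer.train()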
\ No newline at end of file
diff --git a/callbacks/README.md b/callbacks/README.md
new file mode 100644
index 0000000..d776cf7
--- /dev/null
+++ b/callbacks/README.md
@@ -0,0 +1,3 @@
+This directory contains the callback classes that are passed into trainers to perform tasks such as plotting images every few epochs and logging metrics.
+
+Every callback class must inherit from the abstract class `AbstractCallback`.
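+
+As a rough sketch (assuming a trainer that exposes `epoch` and a `log` dict of metric histories, as the callbacks in this directory do), a minimal custom callback might look like:
+
+```python
+from callbacks.AbstractCallback import AbstractCallback
+
+class PrintLossCallback(AbstractCallback):
+    """Hypothetical example: print the most recent loss at the end of every epoch."""
+
+    def on_epoch_end(self):
+        losses = self.trainer.log.get("loss", [])
+        if losses:
+            print(f"Epoch {self.trainer.epoch}: loss = {losses[-1]:.4f}")
+```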
\ No newline at end of file
diff --git a/cp_gan_env.yml b/cp_gan_env.yml
new file mode 100644
index 0000000..4e68d25
--- /dev/null
+++ b/cp_gan_env.yml
@@ -0,0 +1,39 @@
+name: cp_gan_env
+channels:
+ - anaconda
+ - pytorch
+ - nvidia
+ - conda-forge
+dependencies:
+ - conda-forge::python=3.9
+ - conda-forge::pip
+ - pytorch::pytorch
+ - pytorch::torchvision
+ - pytorch::torchaudio
+ - pytorch::pytorch-cuda=12.1
+ - conda-forge::seaborn
+ - conda-forge::matplotlib
+ - conda-forge::jupyter
+ - conda-forge::pre_commit
+ - conda-forge::pandas
+ - conda-forge::pillow
+ - conda-forge::numpy
+ - conda-forge::pathlib2
+ - conda-forge::scikit-learn
+ - conda-forge::opencv
+ - conda-forge::pyarrow
+ - conda-forge::ipython
+ - conda-forge::notebook
+ - conda-forge::albumentations
+ - conda-forge::optuna
+ - conda-forge::mysqlclient
+ - conda-forge::openjdk
+ - conda-forge::gtk2
+ - conda-forge::typing-extensions
+ - conda-forge::Jinja2
+ - conda-forge::inflect
+ - conda-forge::wxpython
+ - conda-forge::sentry-sdk
+ - pip:
+ - mlflow
+ - cellprofiler==4.2.8
diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py
new file mode 100644
index 0000000..04bc237
--- /dev/null
+++ b/datasets/CachedDataset.py
@@ -0,0 +1,220 @@
+from typing import Optional
+
+from torch.utils.data import Dataset
+from collections import OrderedDict
+
+class CachedDataset(Dataset):
+ """
+    A dataset wrapper that caches items from an underlying dataset which loads its data
+    dynamically, reducing the overhead of repeatedly loading the same data during training
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ cache_size: Optional[int]=None,
+ prefill_cache: bool=False,
+ **kwargs
+ ):
+ """
+ Initialize the CachedDataset from a dataset object
+
+ :param dataset: Dataset object to cache data from
+ :type dataset: Dataset
+ :param cache_size: Size of the cache, if None, the cache
+ size is set to the length of the dataset.
+ :type cache_size: int
+ :param prefill_cache: Whether to prefill the cache
+ :type prefill_cache: bool
+ """
+
+ if len(dataset) == 0:
+ raise ValueError("Dataset is empty")
+
+ self.__dataset = dataset
+
+ self.__cache_size = cache_size if cache_size is not None else len(dataset)
+ self.__cache = OrderedDict()
+
+ # cache for metadata
+ self.__cache_input_names = OrderedDict()
+ self.__cache_target_names = OrderedDict()
+
+ # pointer to the current patch index
+ self._current_idx = None
+
+ if prefill_cache:
+ self.populate_cache()
+
+ """Overriden methods for Dataset class"""
+ def __len__(self):
+ """
+ Return the length of the dataset
+ """
+ return len(self.__dataset)
+
+ def __getitem__(self, _idx: int):
+ """
+ Get the data from the dataset object at the given index
+ If the data is not in the cache, load it from the dataset object and update the cache
+
+ :param _idx: Index of the data to get
+ :type _idx: int
+ """
+ self._current_idx = _idx
+
+ if _idx in self.__cache:
+ # cache hit
+ return self.__cache[_idx]
+ else:
+            # cache miss, load from the wrapped dataset and update the cache
+ self._push_cache(_idx)
+ return self.__cache[_idx]
+
+ """Setters"""
+
+ def set_cache_size(self, cache_size: int):
+ """
+ Set the cache size. Does not automatically repopulate the cache but
+ will pop the cache if the size is exceeded
+
+ :param cache_size: Size of the cache
+ :type cache_size: int
+ """
+ self.__cache_size = cache_size
+ # pop the cache if the size is exceeded
+ while len(self.__cache) > self.__cache_size:
+ self._pop_cache()
+
+ """Properties to remain accessible"""
+ @property
+ def input_names(self):
+ """
+ Get the input names from the dataset object
+ """
+ if self._current_idx is not None:
+ if self._current_idx in self.__cache_input_names:
+ return self.__cache_input_names[self._current_idx]
+ else:
+ _ = self.__dataset[self._current_idx]
+ return self.__dataset.input_names
+ else:
+ raise ValueError("No current index set")
+
+ @property
+ def target_names(self):
+ """
+ Get the target names from the dataset object
+ """
+ if self._current_idx is not None:
+ ## TODO: need to think over if this is at all necessary
+ if self._current_idx in self.__cache_target_names:
+ return self.__cache_target_names[self._current_idx]
+ else:
+ _ = self.__dataset[self._current_idx]
+ return self.__dataset.target_names
+ else:
+ raise ValueError("No current index set")
+
+ @property
+ def input_channel_keys(self):
+ """
+ Get the input channel keys from the dataset object
+ """
+ try:
+ return self.__dataset.input_channel_keys
+ except AttributeError:
+ return None
+
+ @property
+ def target_channel_keys(self):
+ """
+ Get the target channel keys from the dataset object
+ """
+ try:
+ return self.__dataset.target_channel_keys
+ except AttributeError:
+ return None
+
+ @property
+ def input_transform(self):
+ """
+ Get the input transform from the dataset object
+ """
+ return self.__dataset.input_transform
+
+ @property
+ def target_transform(self):
+ """
+ Get the target transform from the dataset object
+ """
+ return self.__dataset.target_transform
+
+ @property
+ def dataset(self):
+ """
+ Get the dataset object
+ """
+ return self.__dataset
+
+ """Cache method"""
+ def populate_cache(self):
+ """
+ Populates/clears the current cache and re-populate the cache with data from the dataset object
+ Iteratively calls the _push_cache method on a sequence of indices
+ """
+ self._clear_cache()
+ for _idx in range(min(self.__cache_size, len(self.__dataset))):
+ self._push_cache(_idx)
+
+ """Internal helper methods"""
+
+ def _push_cache(self, _idx: int):
+ """
+ Update the cache with a single item retrieved from the dataset object.
+ Calls the update cache metadata method as well to sync data and metadata
+ Pops the cache if the cache size is exceeded on a first in, first out basis
+
+ :param _idx: Index of the data to cache
+ :type _idx: int
+ """
+ self._current_idx = _idx
+ self.__cache[_idx] = self.__dataset[_idx]
+        if len(self.__cache) > self.__cache_size:
+            self._pop_cache()
+ self._push_cache_metadata(_idx)
+
+ def _pop_cache(self):
+ """
+ Helper method to pop the cache on a first in, first out basis
+ """
+ self.__cache.popitem(last=False)
+
+ def _push_cache_metadata(self, _idx: int):
+ """
+ Update the cache metadata with data from the dataset object
+ Meant to be called by _push_cache method
+
+ :param _idx: Index of the data to cache
+ :type _idx: int
+ """
+ self.__cache_input_names[_idx] = self.__dataset.input_names
+ self.__cache_target_names[_idx] = self.__dataset.target_names
+
+        if len(self.__cache_input_names) > self.__cache_size:
+            self._pop_cache_metadata()
+
+ def _pop_cache_metadata(self):
+ """
+ Helper method to pop the cache metadata on a first in, first out basis
+ """
+ self.__cache_input_names.popitem(last=False)
+ self.__cache_target_names.popitem(last=False)
+
+ def _clear_cache(self):
+ """
+ Clear the cache and cache metadata
+ """
+ self.__cache.clear()
+ self.__cache_input_names.clear()
+ self.__cache_target_names.clear()
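+
+# Usage sketch (hypothetical file and channel names): wrap a dynamically-loading dataset
+# so that repeated accesses hit the in-memory cache instead of re-reading from disk.
+#
+#   base_dataset = ImageDataset(_loaddata_csv="loaddata.csv",
+#                               _input_channel_keys="OrigBrightfield",
+#                               _target_channel_keys="OrigDNA")
+#   cached = CachedDataset(base_dataset, cache_size=256, prefill_cache=True)
+#   loader = torch.utils.data.DataLoader(cached, batch_size=8, shuffle=True)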
\ No newline at end of file
diff --git a/datasets/GenericImageDataset.py b/datasets/GenericImageDataset.py
new file mode 100644
index 0000000..71713f7
--- /dev/null
+++ b/datasets/GenericImageDataset.py
@@ -0,0 +1,299 @@
+import logging
+import pathlib
+import re
+from collections import defaultdict
+from typing import List, Optional, Union, Tuple, Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from albumentations import ImageOnlyTransform
+from albumentations.core.composition import Compose
+from torch.utils.data import Dataset
+
+
+class GenericImageDataset(Dataset):
+ """
+ A generic image dataset that automatically associates images under a supplied path
+ with sites and channels based on two separate regex patterns for site and channel detection.
+ """
+
+ def __init__(
+ self,
+ image_dir: Union[str, pathlib.Path],
+ site_pattern: str,
+ channel_pattern: str,
+ _input_channel_keys: Optional[Union[str, List[str]]] = None,
+ _target_channel_keys: Optional[Union[str, List[str]]] = None,
+ _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None,
+ _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None,
+ _PIL_image_mode: str = 'I;16',
+ verbose: bool = False,
+ check_exists: bool = True,
+ **kwargs
+ ):
+ """
+ Initialize the dataset.
+
+ :param image_dir: Directory containing the images.
+ :param site_pattern: Regex pattern to extract site identifiers.
+ :param channel_pattern: Regex pattern to extract channel identifiers.
+ :param _input_channel_keys: List of channel names to use as inputs.
+ :param _target_channel_keys: List of channel names to use as targets.
+ :param _input_transform: Transformations to apply to input images.
+ :param _target_transform: Transformations to apply to target images.
+ :param _PIL_image_mode: Mode for loading images.
+ :param check_exists: Whether to check if all referenced image files exist.
+ """
+
+ self._initialize_logger(verbose)
+ self.image_dir = pathlib.Path(image_dir).resolve()
+ self.site_pattern = re.compile(site_pattern)
+ self.channel_pattern = re.compile(channel_pattern)
+ self._PIL_image_mode = _PIL_image_mode
+
+ if not self.image_dir.exists():
+ raise FileNotFoundError(f"Image directory {self.image_dir} not found")
+
+ # Parse images and organize by site
+ self._channel_keys = []
+ self.__image_paths = self._get_image_paths(check_exists)
+
+ # Set input and target channel keys
+ self._input_channel_keys = self.__check_channel_keys(_input_channel_keys)
+ self._target_channel_keys = self.__check_channel_keys(_target_channel_keys)
+
+ self.set_input_transform(_input_transform)
+ self.set_target_transform(_target_transform)
+
+ # Index patches and images
+ self.__iter_image_id = list(range(len(self.__image_paths)))
+
+ # Initialize cache
+ self.__input_cache = {}
+ self.__target_cache = {}
+ self.__cache_image_id = None
+
+ # Initialize the current input and target names
+ self.__current_input_names = None
+ self.__current_target_names = None
+
+ """
+ Properties
+ """
+
+ @property
+ def image_paths(self):
+ return self.__image_paths
+
+ @property
+ def input_transform(self):
+ return self._input_transform
+
+ @property
+ def target_transform(self):
+ return self._target_transform
+
+ @property
+ def input_channel_keys(self):
+ return self._input_channel_keys
+
+ @property
+ def target_channel_keys(self):
+        return self._target_channel_keys
+
+    @property
+ def input_names(self):
+ return self.__current_input_names
+
+ @property
+ def target_names(self):
+ return self.__current_target_names
+
+ """
+ Setters
+ """
+
+ def set_input_transform(self, _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None):
+ """Sets the input image transform."""
+ self.logger.debug("Setting input transform ...")
+ self._input_transform = _input_transform
+
+ def set_target_transform(self, _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None):
+ """Sets the target image transform."""
+ self.logger.debug("Setting target transform ...")
+ self._target_transform = _target_transform
+
+ def set_input_channel_keys(self, _input_channel_keys: Union[str, List[str]]):
+ """
+ Set the input channel keys
+
+ :param _input_channel_keys: The input channel keys
+ :type _input_channel_keys: str or list of str
+ """
+ self._input_channel_keys = self.__check_channel_keys(_input_channel_keys)
+ self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}")
+
+ # clear the cache
+ self.__cache_image_id = None
+
+ def set_target_channel_keys(self, _target_channel_keys: Union[str, List[str]]):
+ """
+ Set the target channel keys
+
+ :param _target_channel_keys: The target channel keys
+ :type _target_channel_keys: str or list of str
+ """
+ self._target_channel_keys = self.__check_channel_keys(_target_channel_keys)
+ self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}")
+
+ # clear the cache
+ self.__cache_image_id = None
+
+ """
+ Logging and Debugging
+ """
+
+ def _initialize_logger(self, verbose: bool):
+ """Initializes the logger."""
+ self.logger = logging.getLogger(f"{__name__}.{id(self)}")
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ handler.setFormatter(formatter)
+ self.logger.addHandler(handler)
+ self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING)
+
+ """
+ Internal helper functions
+ """
+
+ def _get_image_paths(self, check_exists: bool):
+
+ # sets for all unique sites and channels
+ sites = set()
+ channels = set()
+ image_files = list(self.image_dir.glob("*"))
+
+ site_to_channels = defaultdict(dict)
+ for file in image_files:
+            site_match = self.site_pattern.search(file.name)
+            if site_match is None:
+                continue
+            site = site_match.group(1)
+            sites.add(site)
+
+            channel_match = self.channel_pattern.search(file.name)
+            if channel_match is None:
+                continue
+            channel = channel_match.group(1)
+            channels.add(channel)
+
+            site_to_channels[site][channel] = file
+
+ # format as list of dicts
+ image_paths = []
+ for site, channel_to_file in site_to_channels.items():
+ ## Keep only sites with all channels
+ if all([c in site_to_channels[site] for c in channels]):
+ if check_exists and not all(path.exists() for path in channel_to_file.values()):
+ continue
+ image_paths.append(channel_to_file)
+
+ self.logger.debug(f"Channel keys: {channels} detected")
+ self._channel_keys = list(channels)
+
+ return image_paths
+
+ def __len__(self):
+ return len(self.__image_paths)
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Retrieves the input and target images for a given index.
+
+ :param idx: The index of the image.
+ :return: Tuple of input and target images as tensors.
+ """
+ if idx >= len(self) or idx < 0:
+ raise IndexError("Index out of bounds")
+
+        if self._input_channel_keys is None or self._target_channel_keys is None:
+            raise ValueError("Input and target channel keys must be set to access data")
+
+        site_id = self.__iter_image_id[idx]
+ self._cache_image(site_id)
+
+ # Stack input and target images
+ input_images = np.stack([self.__input_cache[key] for key in self._input_channel_keys], axis=0)
+ target_images = np.stack([self.__target_cache[key] for key in self._target_channel_keys], axis=0)
+
+ # Apply transformations
+ if self._input_transform:
+ input_images = self._input_transform(image=input_images)['image']
+ if self._target_transform:
+ target_images = self._target_transform(image=target_images)['image']
+
+ return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float()
+
+ def _cache_image(self, site_id: str) -> None:
+ """
+ Loads and caches images for a given site ID.
+
+ :param site_id: The site ID.
+ """
+ if self.__cache_image_id != site_id:
+ self.__cache_image_id = site_id
+ self.__input_cache = {}
+ self.__target_cache = {}
+
+ ## Update target and input names (which are just file path(s))
+ self.__current_input_names = [self.__image_paths[site_id][key] for key in self._input_channel_keys]
+ self.__current_target_names = [self.__image_paths[site_id][key] for key in self._target_channel_keys]
+
+ for key in self._input_channel_keys:
+ self.__input_cache[key] = self._read_convert_image(self.__image_paths[site_id][key])
+ for key in self._target_channel_keys:
+ self.__target_cache[key] = self._read_convert_image(self.__image_paths[site_id][key])
+
+ def _read_convert_image(self, image_path: pathlib.Path) -> np.ndarray:
+ """
+ Reads and converts an image to a numpy array.
+
+ :param image_path: The image file path.
+ :return: The image as a numpy array.
+ """
+ return np.array(Image.open(image_path).convert(self._PIL_image_mode))
+
+ def __check_channel_keys(
+ self,
+ channel_keys: Optional[Union[str, List[str]]]
+ ) -> List[str]:
+ """
+        Checks user-supplied channel keys against the ones inferred from the image filenames
+
+ :param channel_keys: user supplied list or single object of string channel keys
+ :type channel_keys: string or list of strings
+ """
+ if channel_keys is None:
+ self.logger.debug("No channel keys specified, skip")
+ return None
+ elif isinstance(channel_keys, str):
+ channel_keys = [channel_keys]
+ elif isinstance(channel_keys, list):
+ if not all([isinstance(key, str) for key in channel_keys]):
+ raise ValueError('Channel keys must be a string or a list of strings.')
+ else:
+ raise ValueError('Channel keys must be a string or a list of strings.')
+
+ ## Check supplied channel keys against inferred ones
+ filtered_channel_keys = []
+ for key in channel_keys:
+            if key not in self._channel_keys:
+                self.logger.debug(
+                    f"ignoring channel key {key} as it does not match any channel inferred from the image filenames"
+                )
+            else:
+                filtered_channel_keys.append(key)
+
+        if len(filtered_channel_keys) == 0:
+            raise ValueError('None of the supplied channel keys match the channels inferred from the image filenames')
+
+ return filtered_channel_keys
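+
+# Usage sketch (hypothetical file naming scheme and regex patterns): for files named like
+# "r01c02f03-ch1.tiff", a site pattern capturing "r01c02f03" and a channel pattern
+# capturing "ch1" could look like:
+#
+#   dataset = GenericImageDataset(
+#       image_dir="images/",
+#       site_pattern=r"(r\d+c\d+f\d+)",
+#       channel_pattern=r"-(ch\d+)\.tiff",
+#       _input_channel_keys=["ch1"],
+#       _target_channel_keys=["ch2"],
+#   )
+#   input_tensor, target_tensor = dataset[0]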
\ No newline at end of file
diff --git a/datasets/ImageDataset.py b/datasets/ImageDataset.py
new file mode 100644
index 0000000..bf6033f
--- /dev/null
+++ b/datasets/ImageDataset.py
@@ -0,0 +1,459 @@
+import logging
+import pathlib
+from random import randint
+from typing import List, Optional, Union, Tuple
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+from albumentations import ImageOnlyTransform
+from albumentations.core.composition import Compose
+import torch
+from torch.utils.data import Dataset
+
+class ImageDataset(Dataset):
+ """
+ Image Dataset Class from pe2loaddata generated cellprofiler loaddata csv
+ """
+ def __init__(
+ self,
+ _loaddata_csv,
+ _input_channel_keys: Optional[Union[str, List[str]]] = None,
+ _target_channel_keys: Optional[Union[str, List[str]]] = None,
+ _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None,
+ _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None,
+ _PIL_image_mode: str = 'I;16',
+ verbose: bool = False,
+ file_column_prefix: str = 'FileName_',
+ path_column_prefix: str = 'PathName_',
+ check_exists: bool = False,
+ **kwargs
+ ):
+ """
+ Initialize the ImageDataset.
+
+ :param _loaddata_csv: The dataframe or path to a csv file containing the image paths and labels.
+ :type _loaddata_csv: Union[pd.DataFrame, str]
+ :param _input_channel_keys: Keys for input channels. Can be a single key or a list of keys.
+ :type _input_channel_keys: Optional[Union[str, List[str]]]
+ :param _target_channel_keys: Keys for target channels. Can be a single key or a list of keys.
+ :type _target_channel_keys: Optional[Union[str, List[str]]]
+ :param _input_transform: Transformations to apply to the input images.
+ :type _input_transform: Optional[Union[Compose, ImageOnlyTransform]]
+ :param _target_transform: Transformations to apply to the target images.
+ :type _target_transform: Optional[Union[Compose, ImageOnlyTransform]]
+ :param _PIL_image_mode: Mode to use when loading images with PIL. Default is 'I;16'.
+ :type _PIL_image_mode: str
+ :param kwargs: Additional keyword arguments.
+ """
+
+ self._initialize_logger(verbose)
+ self._loaddata_df = self._load_loaddata(_loaddata_csv, **kwargs)
+ self._channel_keys = list(self.__infer_channel_keys(file_column_prefix, path_column_prefix))
+
+ # Initialize the cache for the input and target images
+ self.__input_cache = {}
+ self.__target_cache = {}
+ self.__cache_image_id = None
+
+ # Set input/target channels
+ self.logger.debug("Setting input channel(s) ...")
+ self._input_channel_keys = self.__check_channel_keys(_input_channel_keys)
+ self.logger.debug("Setting target channel(s) ...")
+ self._target_channel_keys = self.__check_channel_keys(_target_channel_keys)
+
+ self.set_input_transform(_input_transform)
+ self.set_target_transform(_target_transform)
+
+ self._PIL_image_mode = _PIL_image_mode
+
+ # Obtain image paths
+ self.__image_paths = self._get_image_paths(
+ file_column_prefix=file_column_prefix,
+ path_column_prefix=path_column_prefix,
+ check_exists=check_exists,
+ **kwargs
+ )
+ # Index patches and images
+ self.__iter_image_id = list(range(len(self.__image_paths)))
+
+ # Initialize the current input and target names
+ self.__current_input_names = None
+ self.__current_target_names = None
+
+ """
+ Overridden Iterator functions
+ """
+ def __len__(self):
+ return len(self.__image_paths)
+
+ def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Return the input and target images
+ :param _idx: The index of the image
+ :type _idx: int
+ :return: The input and target images, each with dimension [n_channels, height, width]
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
+ """
+
+ if _idx >= len(self) or _idx < 0:
+ raise IndexError("Index out of bounds")
+
+ if self._input_channel_keys is None or self._target_channel_keys is None:
+ raise ValueError("Input and target channel keys must be set to access data")
+
+ image_id = self.__iter_image_id[_idx]
+ self._cache_image(image_id)
+
+ ## Retrieve relevant channels as specified by input and target channel keys and stack
+ input_images = np.stack(
+ [self.__input_cache[key] for key in self._input_channel_keys],
+ axis=0)
+ target_images = np.stack(
+ [self.__target_cache[key] for key in self._target_channel_keys],
+ axis=0)
+
+ ## Apply transform
+ if self._input_transform:
+ input_images = self._input_transform(image=input_images)['image']
+ if self._target_transform:
+ target_images = self._target_transform(image=target_images)['image']
+
+ ## Cast to torch tensor and return
+ return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float()
+
+ """
+ Properties
+ """
+
+ @property
+ def image_paths(self):
+ return self.__image_paths
+
+ @property
+ def input_transform(self):
+ return self._input_transform
+
+ @property
+ def target_transform(self):
+ return self._target_transform
+
+ @property
+ def input_channel_keys(self):
+ return self._input_channel_keys
+
+ @property
+ def target_channel_keys(self):
+        return self._target_channel_keys
+
+    @property
+ def input_names(self):
+ return self.__current_input_names
+
+ @property
+ def target_names(self):
+ return self.__current_target_names
+
+ """
+ Setters
+ """
+
+ def set_input_transform(self, _input_transform: Optional[Union[Compose, ImageOnlyTransform]]=None):
+ """
+ Set the input transform
+
+ :param _input_transform: The input transform
+ :type _input_transform: Compose or ImageOnlyTransform
+ """
+ # Check and set input/target transforms
+ self.logger.debug("Setting input transform ...")
+ if self.__check_transforms(_input_transform):
+ self._input_transform = _input_transform
+
+
+ def set_target_transform(self, _target_transform: Optional[Union[Compose, ImageOnlyTransform]]=None):
+ """
+ Set the target transform
+
+ :param _target_transform: The target transform
+ :type _target_transform: Compose or ImageOnlyTransform
+ """
+ # Check and set input/target transforms
+ self.logger.debug("Setting target transform ...")
+ if self.__check_transforms(_target_transform):
+ self._target_transform = _target_transform
+
+ def set_input_channel_keys(self, _input_channel_keys: Union[str, List[str]]):
+ """
+ Set the input channel keys
+
+ :param _input_channel_keys: The input channel keys
+ :type _input_channel_keys: str or list of str
+ """
+ self._input_channel_keys = self.__check_channel_keys(_input_channel_keys)
+ self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}")
+
+ # clear the cache
+ self.__cache_image_id = None
+
+ def set_target_channel_keys(self, _target_channel_keys: Union[str, List[str]]):
+ """
+ Set the target channel keys
+
+ :param _target_channel_keys: The target channel keys
+ :type _target_channel_keys: str or list of str
+ """
+ self._target_channel_keys = self.__check_channel_keys(_target_channel_keys)
+ self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}")
+
+ # clear the cache
+ self.__cache_image_id = None
+
+ """
+ Internal Helper functions
+ """
+ def _initialize_logger(self, verbose: bool):
+ """
+ Initialize logger instance
+ """
+ self.logger = logging.getLogger(f"{__name__}.{id(self)}")
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ handler.setFormatter(formatter)
+ self.logger.addHandler(handler)
+ self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING)
+
+ def _load_loaddata(
+ self,
+ _loaddata_csv: Union[pd.DataFrame, pathlib.Path],
+ **kwargs
+ ) -> pd.DataFrame:
+ """
+ Read loaddata csv file, also does type checking
+
+ :param _loaddata_csv: The path to the loaddata CSV file or a DataFrame.
+ :type _loaddata_csv: Union[pd.DataFrame, pathlib.Path]
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: dict
+ :raises ValueError: If no loaddata CSV is supplied or if the file type is not supported.
+ :raises FileNotFoundError: If the specified file does not exist.
+ :return: The loaded data as a DataFrame.
+ :rtype: pd.DataFrame
+ """
+
+ if _loaddata_csv is None:
+ raise ValueError("No loaddata csv supplied")
+ elif isinstance(_loaddata_csv, pd.DataFrame):
+ self.logger.debug("Dataframe supplied for loaddata_csv, using as is")
+ return _loaddata_csv
+ else:
+ self.logger.debug("Loading loaddata csv from file")
+ ## Convert string to pathlib Path
+            if not isinstance(_loaddata_csv, pathlib.Path):
+                _loaddata_csv = pathlib.Path(_loaddata_csv)
+
+ ## Handle file not exist
+ if not _loaddata_csv.exists():
+ raise FileNotFoundError(f"File {_loaddata_csv} not found")
+
+ ## Determine file extension and load accordingly
+ if _loaddata_csv.suffix == '.csv':
+ return pd.read_csv(_loaddata_csv)
+ elif _loaddata_csv.suffix == '.parquet':
+ return pd.read_parquet(_loaddata_csv)
+ else:
+ raise ValueError(f"File type {_loaddata_csv.suffix} not supported")
+
+ def __infer_channel_keys(
+ self,
+ file_column_prefix: str,
+ path_column_prefix: str
+ ) -> set[str]:
+ """
+ Infer channel names from the columns of loaddata csv.
+ This method identifies and returns the set of channel keys by comparing
+ the columns in the dataframe that start with the specified file and path
+ column prefixes. The channel keys are the suffixes of these columns after
+ removing the prefixes.
+
+ :param file_column_prefix: The prefix for columns that indicate filenames.
+ :type file_column_prefix: str
+ :param path_column_prefix: The prefix for columns that indicate paths.
+ :type path_column_prefix: str
+ :return: A set of channel keys inferred from the loaddata csv.
+ :rtype: set[str]
+ :raises ValueError: If no path or file columns are found, or if no matching
+ channel keys are found between file and path columns.
+ """
+
+ self.logger.debug("Inferring channel keys from loaddata csv")
+ # Retrieve columns that indicate path and filename to image files
+ file_columns = [col for col in self._loaddata_df.columns if col.startswith(file_column_prefix)]
+ path_columns = [col for col in self._loaddata_df.columns if col.startswith(path_column_prefix)]
+
+ if len(file_columns) == 0 or len(path_columns) == 0:
+ raise ValueError('No path or file columns found in loaddata csv.')
+
+ # Anything following the prefix should be the channel names
+ file_channel_keys = [col.replace(file_column_prefix, '') for col in file_columns]
+ path_channel_keys = [col.replace(path_column_prefix, '') for col in path_columns]
+ channel_keys = set(file_channel_keys).intersection(set(path_channel_keys))
+
+ if len(channel_keys) == 0:
+ raise ValueError('No matching channel keys found between file and path columns.')
+
+ self.logger.debug(f"Channel keys: {channel_keys} inferred from loaddata csv")
+
+ return channel_keys
+
+ def __check_channel_keys(
+ self,
+ channel_keys: Optional[Union[str, List[str]]]
+ ) -> List[str]:
+ """
+ Checks user supplied channel key against the inferred ones from the file
+
+ :param channel_keys: user supplied list or single object of string channel keys
+ :type channel_keys: string or list of strings
+ """
+ if channel_keys is None:
+ self.logger.debug("No channel keys specified, skip")
+ return None
+ elif isinstance(channel_keys, str):
+ channel_keys = [channel_keys]
+ elif isinstance(channel_keys, list):
+ if not all([isinstance(key, str) for key in channel_keys]):
+ raise ValueError('Channel keys must be a string or a list of strings.')
+ else:
+ raise ValueError('Channel keys must be a string or a list of strings.')
+
+ ## Check supplied channel keys against inferred ones
+ filtered_channel_keys = []
+ for key in channel_keys:
+            if key not in self._channel_keys:
+ self.logger.debug(
+ f"ignoring channel key {key} as it does not match loaddata csv file"
+ )
+ else:
+ filtered_channel_keys.append(key)
+
+ if len(filtered_channel_keys) == 0:
+            raise ValueError('None of the supplied channel keys match the loaddata csv file')
+
+ return filtered_channel_keys
+
+ def __check_transforms(
+ self,
+ transforms: Optional[Union[Compose, ImageOnlyTransform]]
+ ) -> bool:
+ """
+        Checks if the supplied image-only transform is of a valid type; if so, returns True
+
+ :param transforms: Transform
+ :type transforms: ImageOnlyTransform or Compose of ImageOnlyTransforms
+ :return: Boolean indicator of success
+ :rtype: bool
+ """
+ if transforms is None:
+ pass
+ elif isinstance(transforms, Compose):
+ pass
+ elif isinstance(transforms, ImageOnlyTransform):
+ pass
+ else:
+ raise TypeError('Invalid image transform type')
+
+ return True
+
+ def _get_image_paths(self,
+ file_column_prefix: str,
+ path_column_prefix: str,
+ check_exists: bool = False,
+ **kwargs,
+ ) -> List[dict]:
+ """
+        From loaddata csv, extract the paths to all image channels corresponding to each view/site
+
+ :param check_exists: check if every individual image file exist and remove those that do not
+ :type check_exists: bool
+ :return: A list of dictionaries containing the paths to the image channels
+ :rtype: List[dict]
+ """
+
+ # Define helper function to get the image file paths from all channels
+ # in a single row of loaddata csv (single view/site), organized into a dict
+ def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]:
+
+ missing = False
+
+ multi_channel_paths = {}
+ for channel_key in self._channel_keys:
+ file_column = f"{file_column_prefix}{channel_key}"
+ path_column = f"{path_column_prefix}{channel_key}"
+
+ if file_column in row and path_column in row:
+ file = pathlib.Path(
+ row[path_column]
+ ) / row[file_column]
+ if (not check_exists) or file.exists():
+ multi_channel_paths[channel_key] = file
+ else:
+ missing = True
+
+ return multi_channel_paths, missing
+
+ image_paths = []
+ self.logger.debug(
+ "Extracting image channel paths of site/view and associated"\
+ "cell coordinates (if applicable) from loaddata csv")
+
+ for _, row in self._loaddata_df.iterrows():
+ multi_channel_paths, missing = get_channel_paths(row)
+ if not missing:
+ image_paths.append(multi_channel_paths)
+
+ self.logger.debug(f"Extracted images of all input and target channels for {len(image_paths)} unique sites/view")
+ return image_paths
+
+ def _read_convert_image(self, _image_path: pathlib.Path)->np.ndarray:
+ """
+ Read and convert the image to a numpy array
+
+ :param _image_path: The path to the image
+ :type _image_path: pathlib.Path
+ :return: The image as a numpy array
+ :rtype: np.ndarray
+ """
+ return np.array(Image.open(_image_path).convert(self._PIL_image_mode))
+
+ def _cache_image(self, _id: int)->None:
+ """
+ Determines if cached images need to be updated and updates the self.__input_cache and self.__target_cache
+ Meant to be called by __getitem__ method in dynamic patch cropping
+
+ :param _id: The index of the image
+ :type _id: int
+ :return: None
+ :rtype: None
+ """
+
+ if self.__cache_image_id is None or self.__cache_image_id != _id:
+ self.__cache_image_id = _id
+ self.__input_cache = {}
+ self.__target_cache = {}
+
+ ## Update target and input names (which are just file path(s))
+ self.__current_input_names = [self.__image_paths[_id][key] for key in self._input_channel_keys]
+ self.__current_target_names = [self.__image_paths[_id][key] for key in self._target_channel_keys]
+
+ for key in self._input_channel_keys:
+ self.__input_cache[key] = self._read_convert_image(self.__image_paths[_id][key])
+ for key in self._target_channel_keys:
+ self.__target_cache[key] = self._read_convert_image(self.__image_paths[_id][key])
+ else:
+ # No need to update the cache
+ pass
+
+ return None
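+
+# Usage sketch (hypothetical channel names): for a pe2loaddata-style csv with columns
+# FileName_OrigDNA / PathName_OrigDNA / FileName_OrigBrightfield / PathName_OrigBrightfield,
+# the inferred channel keys would be "OrigDNA" and "OrigBrightfield":
+#
+#   dataset = ImageDataset(
+#       _loaddata_csv="loaddata.csv",
+#       _input_channel_keys="OrigBrightfield",
+#       _target_channel_keys=["OrigDNA"],
+#   )
+#   input_tensor, target_tensor = dataset[0]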
\ No newline at end of file
diff --git a/datasets/PatchDataset.py b/datasets/PatchDataset.py
new file mode 100644
index 0000000..08cecae
--- /dev/null
+++ b/datasets/PatchDataset.py
@@ -0,0 +1,677 @@
+import math
+import pathlib
+import random
+from random import randint
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from pyarrow import parquet as pq
+import torch
+
+from .ImageDataset import ImageDataset
+
+class PatchDataset(ImageDataset):
+ """
+ Patch Dataset Class from pe2loaddata generated cellprofiler loaddata csv and sc features
+ """
+ def __init__(
+ self,
+ _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]] = None,
+ patch_size: int = 64,
+ patch_generation_method: str = 'random',
+ patch_generation_random_seed: Optional[int] = None,
+ patch_generation_max_attempts: int = 1_000,
+ n_expected_patches_per_img: int = 5,
+ candidate_x: str = 'Metadata_Cells_Location_Center_X',
+ candidate_y: str = 'Metadata_Cells_Location_Center_Y',
+ **kwargs
+ ):
+ """
+ Initialize the PatchDataset.
+ This method initializes the PatchDataset by setting up patch size, coordinates, and other parameters.
+ It also generates patches and initializes caches for input and target images.
+
+ :param _sc_feature: Single-cell feature data or path to the data, by default None.
+ :type _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]], optional
+ :param patch_size: Size of the patches to generate, by default 64.
+ :type patch_size: int, optional
+ :param patch_generation_method: Method to generate patches ('random' or other methods), by default 'random'.
+ :type patch_generation_method: str, optional
+ :param patch_generation_random_seed: Random seed for patch generation, by default None.
+ :type patch_generation_random_seed: Optional[int], optional
+ :param patch_generation_max_attempts: Maximum number of attempts to generate patches, by default 1,000.
+ :type patch_generation_max_attempts: int, optional
+ :param n_expected_patches_per_img: Number of expected patches per image, by default 5.
+ :type n_expected_patches_per_img: int, optional
+ :param candidate_x: Column name for x-coordinates of candidate cells, by default 'Metadata_Cells_Location_Center_X'.
+ :type candidate_x: str, optional
+ :param candidate_y: Column name for y-coordinates of candidate cells, by default 'Metadata_Cells_Location_Center_Y'.
+ :type candidate_y: str, optional
+ :param kwargs: Additional keyword arguments. Namely those required by ImageDataset
+ :type kwargs: dict
+ """
+
+ self._patch_size = patch_size
+ self._merge_fields = None
+ self._x_col = None
+ self._y_col = None
+ self.__cell_coords = []
+
+        # This initializes the channel keys, loaddata loading, and image mode;
+        # the overridden methods further merge the loaddata with sc features
+ super().__init__(_sc_feature=_sc_feature,
+ candidate_x=candidate_x,
+ candidate_y=candidate_y,
+ **kwargs)
+
+ ## Generates patches with the specified arguments
+ self.__patch_coords = self._generate_patches(
+ _patch_size=self._patch_size,
+ patch_generation_method=patch_generation_method,
+ patch_generation_random_seed=patch_generation_random_seed,
+ n_expected_patches_per_img=n_expected_patches_per_img,
+ max_attempts=patch_generation_max_attempts,
+ consistent_img_size=kwargs.get('consistent_img_size', True),
+ )
+
+ # Index patches and images
+ self.__iter_image_id = []
+ self.__iter_patch_id = []
+ for i, _patch_coords in enumerate(self.__patch_coords):
+ for j, _ in enumerate(_patch_coords):
+ self.__iter_image_id.append(i)
+ self.__iter_patch_id.append(j)
+
+ # Initialize the cache for the input and target images
+ self.__input_cache = {}
+ self.__target_cache = {}
+ self.__cache_image_id = None
+
+ # Initialize the current input and target names and patch coordinates
+ self.__current_input_names = None
+ self.__current_target_names = None
+ self.__current_patch_coords = None
+
+ """
+ Overridden Iterator functions
+ """
+ def __len__(self):
+ return len(self.__iter_patch_id)
+
+ def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Return the input and target images
+
+ :param _idx: The index of the image
+ :type _idx: int
+ :return: The input and target images, each with dimension [n_channels, height, width]
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
+ """
+
+ if _idx >= len(self) or _idx < 0:
+ raise IndexError("Index out of bounds")
+
+ if self._input_channel_keys is None or self._target_channel_keys is None:
+ raise ValueError("Input and target channel keys must be set to access data")
+
+ image_id = self.__iter_image_id[_idx]
+ patch_id = self.__iter_patch_id[_idx]
+ self.__current_patch_coords = self.__patch_coords[image_id][patch_id]
+
+ self._cache_image(image_id)
+
+ ## Retrieve relevant channels as specified by input and target channel keys and stack
+ ## And further crop the patches with __current_patch_coords
+ input_images = np.stack(
+ [self._ImageDataset__input_cache[key][
+ self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size,
+ self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size
+ ] for key in self._input_channel_keys],
+ axis=0)
+ target_images = np.stack(
+ [self._ImageDataset__target_cache[key][
+ self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size,
+ self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size
+ ] for key in self._target_channel_keys],
+ axis=0)
+
+ ## Apply transform
+ if self._input_transform:
+ input_images = self._input_transform(image=input_images)['image']
+ if self._target_transform:
+ target_images = self._target_transform(image=target_images)['image']
+
+ ## Cast to torch tensor and return
+ return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float()
+
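+    # Usage sketch (hypothetical file and column names; assumes the loaddata csv and the
+    # single-cell feature table share merge columns such as Metadata_Well / Metadata_Site):
+    #
+    #   dataset = PatchDataset(
+    #       _loaddata_csv="loaddata.csv",
+    #       _sc_feature="sc_features.parquet",
+    #       patch_size=64,
+    #       _input_channel_keys="OrigBrightfield",
+    #       _target_channel_keys=["OrigDNA"],
+    #   )
+    #   input_patch, target_patch = dataset[0]
+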
+ """
+ Properties
+ """
+
+ @property
+ def patch_size(self):
+ return self._patch_size
+
+ @property
+ def cell_coords(self):
+ return self.__cell_coords
+
+ @property
+ def all_patch_coords(self):
+ return self.__patch_coords
+
+ @property
+ def patch_coords(self):
+ return self.__current_patch_coords
+
+ @property
+ def raw_input(self):
+ """
+ Returns a tuple of input, target raw images where the current patch is cropped
+ from. Relies on the parent class _ImageDataset__input_cache and _ImageDataset__target_cache.
+ Returns None when the cache is empty. Raises an error if the input channel keys are not set.
+
+ :return: Tuple of input, target raw images
+ :rtype: Tuple[np.ndarray, np
+ """
+
+ if self._input_channel_keys is None:
+ raise ValueError("Input channel keys not set")
+
+ return np.stack(
+ [self._ImageDataset__input_cache[key] for key in self._input_channel_keys],
+ axis=0) if self._input_channel_keys is not None else None
+
+ @property
+ def raw_target(self):
+ """
+        Returns the raw target image (all target channels stacked) from which the current
+        patch is cropped. Relies on the parent class _ImageDataset__target_cache.
+        Raises an error if the target channel keys are not set.
+
+        :return: Stacked raw target channel images
+        :rtype: np.ndarray
+ """
+
+ if self._target_channel_keys is None:
+ raise ValueError("Target channel keys not set")
+
+ return np.stack(
+ [self._ImageDataset__target_cache[key] for key in self._target_channel_keys],
+ axis=0) if self._target_channel_keys is not None else None
+
+ """
+ Internal Helper functions
+ """
+
+ def __preload_sc_feature(self,
+ _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]]=None) -> List[str]:
+ """
+ Preload the sc feature dataframe/parquet file limiting only to the column headers
+ If a dataframe is supplied, use as is and return the column names
+ If a path to a csv file is supplied, load only the header row
+ If a path to a parquet file is supplied, load only the parquet schema name
+
+ :param _sc_feature: The path to a csv file containing the cell profiler sc features
+ :type _sc_feature: str or pathlib.Path
+ :return: List of column names of dataframe/csv/parquet file
+ :rtype: List of strings
+ """
+
+ if _sc_feature is None:
+            # No sc feature supplied, cell coordinates not available, patch generation falls back to random
+ self.logger.debug("No sc feature supplied, patch generation will be random")
+ self._patch_generation_method = 'random'
+ return None
+
+ elif isinstance(_sc_feature, pd.DataFrame):
+ self.logger.debug("Dataframe supplied for sc_feature, using as is")
+ return _sc_feature.columns.tolist()
+
+ else:
+ self.logger.debug("Preloading sc feature from file")
+            if not isinstance(_sc_feature, pathlib.Path):
+                _sc_feature = pathlib.Path(_sc_feature)
+
+ if not _sc_feature.exists():
+ raise FileNotFoundError(f"File {_sc_feature} not found")
+
+ if _sc_feature.suffix == '.csv':
+ self.logger.debug("Preloading sc feature from csv file")
+ return pd.read_csv(_sc_feature, nrows=0).columns.tolist()
+ elif _sc_feature.suffix == '.parquet':
+ pq_file = pq.ParquetFile(_sc_feature)
+ return pq_file.schema.names
+ else:
+ raise ValueError(f"File type {_sc_feature.suffix} not supported")
+
+ def __infer_merge_fields(self,
+ _loaddata_df,
+ _sc_col_names: List[str]
+ ) -> Union[List[str], None]:
+ """
+        Find the columns that are common to both dataframes for use in an inner join.
+        Meant to be used to associate loaddata_csv with sc features.
+
+        :param _loaddata_df: The loaddata dataframe
+        :type _loaddata_df: pd.DataFrame
+        :param _sc_col_names: Column names of the sc feature dataframe/file
+        :type _sc_col_names: List[str]
+ :return: The columns that are common to both dataframes
+ :rtype: List[str]
+ """
+ if _sc_col_names is None:
+ return None
+
+ self.logger.debug("Both loaddata_csv and sc_feature supplied, " \
+ "inferring merge fields to associate the two dataframes")
+ merge_fields = list(set(_loaddata_df.columns).intersection(set(_sc_col_names)))
+ if len(merge_fields) == 0:
+ raise ValueError("No common columns found between loaddata_csv and sc_feature")
+ self.logger.debug(f"Merge fields inferred: {merge_fields}")
+
+ return merge_fields
+
+ def __infer_x_y_columns(self,
+ _loaddata_df,
+ _sc_col_names: List[str],
+ candidate_x: str,
+ candidate_y: str) -> Tuple[str, str]:
+ """
+        Infer the columns that contain the x and y coordinates of the cell centers.
+        Looks for the user-specified column names first; when these are not present,
+        this function falls back to the (alphabetically) first columns ending with _x and _y
+
+ :param candidate_x: The candidate column name for the x coordinates
+ :type candidate_x: str
+ :param candidate_y: The candidate column name for the y coordinates
+ :type candidate_y: str
+ :return: The columns that contain the x and y coordinates of the patches
+ :rtype: Tuple[str, str]
+ """
+
+        if _sc_col_names is None:
+            return None, None
+
+ if candidate_x not in _sc_col_names or candidate_y not in _sc_col_names:
+ self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} not detected in sc_features, attempting to infer from sc_feature dataframe")
+
+ # infer the columns that contain the x and y coordinates
+ x_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_x')]
+ y_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_y')]
+
+ if len(x_col_candidates) == 0 or len(y_col_candidates) == 0:
+ raise ValueError("No columns found containing the x and y coordinates")
+ else:
+ # sort x_col and y_col candidates
+ x_col_candidates.sort()
+ y_col_candidates.sort()
+ x_col_detected = x_col_candidates[0]
+ y_col_detected = y_col_candidates[0]
+ self.logger.debug(f"X and Y columns {x_col_detected}, {y_col_detected} detected in sc_feature dataframe, using as the coordinates for cell centers")
+ return x_col_detected, y_col_detected
+ else:
+ self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} detected in sc_feature dataframe, using as the coordinates for cell centers")
+ return candidate_x, candidate_y
+
+ def __load_sc_feature(self,
+ _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]],
+ _merge_fields: List[str],
+ _x_col: str,
+ _y_col: str
+ ) -> Union[pd.DataFrame, None]:
+ """
+ Load the actual sc feature as a dataframe, limiting the columns
+ to the merge fields and the x and y coordinates
+
+ :param _sc_feature: The path to a csv file containing the cell profiler sc features
+ :type _sc_feature: str or pathlib.Path
+ :return: The dataframe containing the cell profiler sc features
+ :rtype: pd.DataFrame
+ """
+
+ if _sc_feature is None:
+ return None
+ elif isinstance(_sc_feature, pd.DataFrame):
+ self.logger.debug("Dataframe supplied for sc_feature, using as is")
+ return _sc_feature
+ else:
+ self.logger.debug("Loading sc feature from file")
+            if not isinstance(_sc_feature, pathlib.Path):
+                _sc_feature = pathlib.Path(_sc_feature)
+
+ if not _sc_feature.exists():
+ raise FileNotFoundError(f"File {_sc_feature} not found")
+
+ if _sc_feature.suffix == '.csv':
+ return pd.read_csv(_sc_feature,
+ usecols=_merge_fields + [_x_col, _y_col])
+ elif _sc_feature.suffix == '.parquet':
+ return pq.read_table(_sc_feature, columns=_merge_fields + [_x_col, _y_col]).to_pandas()
+ else:
+ raise ValueError(f"File type {_sc_feature.suffix} not supported")
+
+ """
+    Overridden parent class helper functions
+ """
+
+ def _load_loaddata(self,
+ _loaddata_csv: Union[pd.DataFrame, pathlib.Path],
+ _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]],
+ candidate_x: str,
+ candidate_y: str,
+ ):
+ """
+ Overridden function from parent class
+ Calls the parent class to get the loaddata df and then merges it with sc_feature
+
+ :param _loaddata_csv: The path to the loaddata CSV file or a DataFrame.
+ :type _loaddata_csv: Union[pd.DataFrame, pathlib.Path]
+        :param _sc_feature: The path to the single cell feature parquet file, csv file, or DataFrame
+ :type _sc_feature: Optional[Union[pd.DataFrame, pathlib.Path]]
+ :param candidate_x: User specified column to access cell x coords
+ :type candidate_x: str
+ :param candidate_y: User specified column to access cell y coords
+ :type candidate_y: str
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: dict
+ :raises ValueError: If no loaddata CSV is supplied or if the file type is not supported.
+ :raises FileNotFoundError: If the specified file does not exist.
+ :return: The loaded data as a DataFrame.
+ :rtype: pd.DataFrame
+ """
+
+ ## First calls the parent class to get the full loaddata df
+ loaddata_df = super()._load_loaddata(_loaddata_csv)
+
+ ## Obtain column names of sc features first to avoid needing to read in the whole
+ ## Parquet file as only a very small number of columns are needed
+ sc_feature_col_names = self.__preload_sc_feature(_sc_feature)
+
+ ## Infer columns corresponding to x and y coordinates of cells
+ self._x_col, self._y_col = self.__infer_x_y_columns(
+ loaddata_df, sc_feature_col_names, candidate_x, candidate_y)
+
+ ## Infer merge fields between the sc features and loaddata
+ self._merge_fields = self.__infer_merge_fields(loaddata_df, sc_feature_col_names)
+
+ ## Load sc features
+ sc_feature_df = self.__load_sc_feature(
+ _sc_feature, self._merge_fields, self._x_col, self._y_col)
+
+ ## Perform the merge and return the merged dataframe (which is loaddata plus columns for x and y coordinates)
+ return loaddata_df.merge(sc_feature_df, on=self._merge_fields, how='inner')
+
+ def _get_image_paths(self,
+ file_column_prefix: str,
+ path_column_prefix: str,
+ check_exists: bool = False,
+ **kwargs,
+ ) -> List[dict]:
+ """
+ Overridden function
+ From loaddata csv, extract the paths to all image channels corresponding to each view/site
+
+ :param file_column_prefix: Prefix of the loaddata columns holding image file names (e.g. "FileName_")
+ :type file_column_prefix: str
+ :param path_column_prefix: Prefix of the loaddata columns holding image directories (e.g. "PathName_")
+ :type path_column_prefix: str
+ :param check_exists: check that every individual image file exists and drop sites with missing files
+ :type check_exists: bool
+ :return: A list of dictionaries containing the paths to the image channels
+ :rtype: List[dict]
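+
+ Each returned dict maps a channel key to the assembled image path for one site/view.
+ An illustrative (not literal) example for a brightfield/DNA pair::
+
+     {"OrigBrightfield": pathlib.Path("<images_dir>/r06c22f01p01-ch1sk1fk1fl1.tiff"),
+      "OrigDNA": pathlib.Path("<images_dir>/r06c22f01p01-ch5sk1fk1fl1.tiff")}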
+ """
+
+ # Define helper function to get the image file paths from all channels
+ # in a single row of loaddata csv (single view/site), organized into a dict
+ def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]:
+
+ missing = False
+
+ multi_channel_paths = {}
+ for channel_key in self._channel_keys:
+ file_column = f"{file_column_prefix}{channel_key}"
+ path_column = f"{path_column_prefix}{channel_key}"
+
+ if file_column in row and path_column in row:
+ file = pathlib.Path(
+ row[path_column]
+ ) / row[file_column]
+ if (not check_exists) or file.exists():
+ multi_channel_paths[channel_key] = file
+ else:
+ missing = True
+
+ return multi_channel_paths, missing
+
+ # Define helper function to get the coordinates associated with a condition
+ def get_associated_coords(group):
+
+ try:
+ return group.loc[:, [self._x_col, self._y_col]].values
+ except KeyError:
+ # coordinate columns are absent, e.g. when no sc_feature was supplied
+ return None
+
+ image_paths = []
+ cell_coords = []
+ self.logger.debug(
+ "Extracting image channel paths of site/view and associated "
+ "cell coordinates (if applicable) from loaddata csv")
+
+ n_cells = 0
+ ## Group by identifier of site/view
+ grouped = self._loaddata_df.groupby(self._merge_fields)
+ for _, group in grouped:
+ ## Retrieve image file paths for each channel from the first row
+ ## (any row within group should have the same filenames)
+ _, row = next(group.iterrows())
+ multi_channel_paths, missing = get_channel_paths(row)
+ if not missing:
+ image_paths.append(multi_channel_paths)
+ # Get cell coords associated with each site/view (None if unavailable)
+ coords = get_associated_coords(group)
+ n_cells += len(coords) if coords is not None else 0
+ cell_coords.append(coords)
+
+ self.logger.debug("Extracted images of all input and target channels for " \
+ f"{len(image_paths)} unique sites/views and {n_cells} cells")
+
+ ## Save cell coords (a list of np 2d array) as attribute
+ self.__cell_coords = cell_coords
+
+ ## Image paths will be saved as attribute by parent class init
+ return image_paths
+
+ """
+ Patch generation helper functions
+ """
+
+ def _generate_patches(self,
+ _patch_size: int,
+ patch_generation_method: str,
+ patch_generation_random_seed: int,
+ n_expected_patches_per_img=5,
+ max_attempts=1_000,
+ consistent_img_size=True,
+ ) -> List[List[Tuple[int, int]]]:
+ """
+ Generate patches for each image in the dataset
+
+ :param _patch_size: The side length of the square patches to generate
+ :type _patch_size: int
+ :param patch_generation_method: The method to use for generating patches ('random_cell' or 'random')
+ :type patch_generation_method: str
+ :param patch_generation_random_seed: The random seed to use for patch generation
+ :type patch_generation_random_seed: int
+ :param n_expected_patches_per_img: The number of patches to generate per image
+ :type n_expected_patches_per_img: int
+ :param max_attempts: The maximum number of attempts to generate a patch
+ :type max_attempts: int
+ :param consistent_img_size: Whether the images are consistent in size.
+ If True, the image size is inferred once from the first input channel of the first image and reused for all images
+ If False, the image size is re-detected for each image
+ :type consistent_img_size: bool
+ :return: The coordinates of the patches, one list per image
+ :rtype: List[List[Tuple[int, int]]]
+ """
+ if patch_generation_method == 'random_cell':
+ if self.__cell_coords is None:
+ raise ValueError("Cell coordinates not available for generating cell containing patches")
+ else:
+ self.logger.debug("Generating patches that contain cells")
+ def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts):
+ return self.__generate_cell_containing_patches_unit(
+ image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts)
+ elif patch_generation_method == 'random':
+ self.logger.debug("Generating random patches")
+ def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts):
+ # cell_coords is not used in this case
+ return self.__generate_random_patches_unit(image_size, patch_size, n_expected_patches_per_img, max_attempts)
+ else:
+ raise ValueError(f"Patch generation method {patch_generation_method} not supported")
+
+ # Generate patches for each image
+ image_size = None
+ patch_count = 0
+ patch_coords = []
+
+ # set random seed
+ if patch_generation_random_seed is not None:
+ random.seed(patch_generation_random_seed)
+ for channel_paths, cell_coords in zip(self._ImageDataset__image_paths, self.__cell_coords):
+ if consistent_img_size:
+ if image_size is None:
+ try:
+ image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0]
+ self.logger.debug(
+ f"Image size inferred: {image_size}, reused for all images; "
+ "to force re-detection of image size for each view/site set consistent_img_size=False"
+ )
+ except Exception as e:
+ raise ValueError("Error reading image size") from e
+ else:
+ try:
+ image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0]
+ except Exception as e:
+ raise ValueError("Error reading image size") from e
+
+ coords = patch_fn(
+ image_size=image_size,
+ patch_size=_patch_size,
+ cell_coords=cell_coords,
+ n_expected_patches_per_img=n_expected_patches_per_img,
+ max_attempts=max_attempts
+ )
+ patch_coords.append(coords)
+ patch_count += len(coords)
+
+ self.logger.debug(f"Generated {patch_count} patches for {len(self._ImageDataset__image_paths)} sites/views")
+ return patch_coords
+
+ @staticmethod
+ def __generate_cell_containing_patches_unit(
+ image_size,
+ patch_size,
+ cell_coords,
+ expected_n_patches=5,
+ max_attempts=1_000):
+ """
+ Static helper function to generate patches that contain the cell coordinates
+
+ :param image_size: The size of the image (square)
+ :type image_size: int
+ :param patch_size: The size of the square patches to generate
+ :type patch_size: int
+ :param cell_coords: The coordinates of the cells
+ :type cell_coords: List[Tuple[int, int]]
+ :param expected_n_patches: The number of patches to generate
+ :type expected_n_patches: int
+ :param max_attempts: The maximum number of placement attempts before giving up
+ :type max_attempts: int
+ :return: The (x, y) top-left coordinates of the retained patches
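+
+ As a worked example with the values from the sample notebook (image_size=1080, patch_size=256):
+ unit_size = gcd(1080, 256) = 8, so candidate patches are placed on a 135x135 grid of units
+ and each patch spans 32x32 units. A placement is kept only if it covers at least one
+ cell-containing unit and does not overlap a previously placed patch.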
+ """
+
+ unit_size = math.gcd(image_size, patch_size)
+ tile_size_units = patch_size // unit_size
+ grid_size_units = image_size // unit_size
+
+ cell_containing_units = {(x // unit_size, y // unit_size) for x, y in cell_coords}
+ placed_tiles = set()
+ retained_tiles = []
+
+ attempts = 0
+ n_tiles = 0
+ while attempts < max_attempts:
+ top_left_x = randint(0, grid_size_units - tile_size_units)
+ top_left_y = randint(0, grid_size_units - tile_size_units)
+
+ tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units)
+ for y in range(top_left_y, top_left_y + tile_size_units)}
+
+ if any(tile_units & placed_tile for placed_tile in placed_tiles):
+ attempts += 1
+ continue
+
+ if tile_units & cell_containing_units:
+ retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size))
+ placed_tiles.add(frozenset(tile_units))
+ n_tiles += 1
+
+ attempts += 1
+ if n_tiles >= expected_n_patches:
+ break
+
+ return retained_tiles
+
+ @staticmethod
+ def __generate_random_patches_unit(
+ image_size,
+ patch_size,
+ expected_n_patches=5,
+ max_attempts=1_000):
+ """
+ Static helper function to generate random patches
+
+ :param image_size: The size of the image (square)
+ :type image_size: int
+ :param patch_size: The size of the square patches to generate
+ :type patch_size: int
+ :param expected_n_patches: The number of patches to generate
+ :type expected_n_patches: int
+ :param max_attempts: The maximum number of placement attempts before giving up
+ :type max_attempts: int
+ :return: The (x, y) top-left coordinates of the retained patches
+ """
+ unit_size = math.gcd(image_size, patch_size)
+ tile_size_units = patch_size // unit_size
+ grid_size_units = image_size // unit_size
+
+ placed_tiles = set()
+ retained_tiles = []
+
+ attempts = 0
+ n_tiles = 0
+ while attempts < max_attempts:
+ top_left_x = randint(0, grid_size_units - tile_size_units)
+ top_left_y = randint(0, grid_size_units - tile_size_units)
+
+ # check for overlap with already placed tiles
+ tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units)
+ for y in range(top_left_y, top_left_y + tile_size_units)}
+
+ if any(tile_units & placed_tile for placed_tile in placed_tiles):
+ attempts += 1
+ continue
+
+ # no overlap, add the tile to the list of retained tiles
+ retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size))
+ placed_tiles.add(frozenset(tile_units))
+ n_tiles += 1
+
+ attempts += 1
+ if n_tiles >= expected_n_patches:
+ break
+
+ return retained_tiles
\ No newline at end of file
diff --git a/datasets/README.md b/datasets/README.md
new file mode 100644
index 0000000..6dd6d9e
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1,2 @@
+Here live the dataset classes for interacting with Cell Painting images.
+The datasets currently depend entirely on the pe2loaddata-generated CSV files.
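+
+A minimal usage sketch (illustrative only; the loaddata and single-cell feature inputs are placeholders and should point at your own pe2loaddata CSV and CellProfiler features):
+
+```python
+import pandas as pd
+from virtual_stain_flow.datasets.PatchDataset import PatchDataset
+
+# pe2loaddata-generated csv and CellProfiler single-cell features (placeholder paths)
+loaddata_df = pd.read_csv("path/to/loaddata.csv")
+sc_features = pd.read_parquet("path/to/sc_features.parquet")
+
+pds = PatchDataset(
+    _loaddata_csv=loaddata_df,
+    _sc_feature=sc_features,
+    _input_channel_keys=None,   # channel keys can be left unset and configured later
+    _target_channel_keys=None,
+    _input_transform=None,
+    _target_transform=None,
+    patch_size=256,
+    patch_generation_method="random_cell",  # only keep patches that contain cells
+    patch_generation_random_seed=42,
+)
+```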
\ No newline at end of file
diff --git a/evaluation/README.md b/evaluation/README.md
new file mode 100644
index 0000000..a698a2a
--- /dev/null
+++ b/evaluation/README.md
@@ -0,0 +1,3 @@
+Here lives a collection of functions useful for evaluating model performance and plotting.
+
+This code is still fairly messy and will need major revisions.
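+
+A rough sketch of the intended workflow (the `model` and `dataset` below are placeholders; any dataset yielding `(input_tensor, target_tensor)` pairs and a compatible `torch.nn.Module` should work):
+
+```python
+import torch
+from virtual_stain_flow.evaluation.predict_utils import predict_image
+from virtual_stain_flow.evaluation.evaluation_utils import evaluate_per_image_metric
+
+# run inference over the dataset (or a subset via `indices`)
+targets, predictions = predict_image(dataset, model, batch_size=4, device="cpu")
+
+# per-image metrics as a DataFrame, one column per metric class
+metrics_df = evaluate_per_image_metric(predictions, targets, metrics=[torch.nn.MSELoss()])
+print(metrics_df.describe())
+```
+
+`visualization_utils.plot_predictions_grid_from_eval` and `plot_predictions_grid_from_model` can then plot grids of input / target / prediction panels from these outputs.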
\ No newline at end of file
diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py
new file mode 100644
index 0000000..e9d5c48
--- /dev/null
+++ b/evaluation/evaluation_utils.py
@@ -0,0 +1,42 @@
+from typing import List, Optional
+
+import pandas as pd
+import torch
+from torch.nn import Module
+from torch.utils.data import DataLoader
+
+def evaluate_per_image_metric(
+ predictions: torch.Tensor,
+ targets: torch.Tensor,
+ metrics: List[Module],
+ indices: Optional[List[int]] = None
+) -> pd.DataFrame:
+ """
+ Computes a set of metrics on a per-image basis and returns the results as a pandas DataFrame.
+
+ :param predictions: Predicted images, shape (N, C, H, W).
+ :type predictions: torch.Tensor
+ :param targets: Target images, shape (N, C, H, W).
+ :type targets: torch.Tensor
+ :param metrics: List of metric functions to evaluate.
+ :type metrics: List[torch.nn.Module]
+ :param indices: Optional list of indices to subset the dataset before inference. If None, all images are evaluated.
+ :type indices: Optional[List[int]], optional
+
+ :return: A DataFrame where each row corresponds to an image and each column corresponds to a metric.
+ :rtype: pd.DataFrame
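+
+ Example (illustrative; random tensors stand in for real predictions and targets)::
+
+     import torch
+     preds = torch.rand(4, 1, 64, 64)
+     tgts = torch.rand(4, 1, 64, 64)
+     df = evaluate_per_image_metric(preds, tgts, metrics=[torch.nn.L1Loss()])
+     # df has 4 rows and a single "L1Loss" column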
+ """
+ if predictions.shape != targets.shape:
+ raise ValueError(f"Shape mismatch: predictions {predictions.shape} vs targets {targets.shape}")
+
+ results = []
+
+ if indices is None:
+ indices = range(predictions.shape[0])
+
+ for i in indices: # Iterate over images/subset
+ pred, target = predictions[i].unsqueeze(0), targets[i].unsqueeze(0) # Keep batch dimension
+ metric_scores = {metric.__class__.__name__: metric.forward(target, pred).item() for metric in metrics}
+ results.append(metric_scores)
+
+ return pd.DataFrame(results)
\ No newline at end of file
diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py
new file mode 100644
index 0000000..bba623b
--- /dev/null
+++ b/evaluation/predict_utils.py
@@ -0,0 +1,103 @@
+from typing import Optional, List, Tuple, Callable
+
+import torch
+import numpy as np
+from torch.utils.data import DataLoader, Dataset, Subset
+from albumentations import ImageOnlyTransform, Compose
+
+def predict_image(
+ dataset: Dataset,
+ model: torch.nn.Module,
+ batch_size: int = 1,
+ device: str = "cpu",
+ num_workers: int = 0,
+ indices: Optional[List[int]] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Runs a model on a dataset, performing a forward pass on all (or a subset of) input images
+ in evaluation mode and returning a stacked tensor of predictions.
+ DOES NOT check if the dataset dimensions are compatible with the model.
+
+ :param dataset: A dataset that returns (input_tensor, target_tensor) tuples,
+ where input_tensor has shape (C, H, W).
+ :type dataset: torch.utils.data.Dataset
+ :param model: A PyTorch model that is compatible with the dataset inputs.
+ :type model: torch.nn.Module
+ :param batch_size: The number of samples per batch (default is 1).
+ :type batch_size: int, optional
+ :param device: The device to run inference on, e.g., "cpu" or "cuda".
+ :type device: str, optional
+ :param num_workers: Number of workers for the DataLoader (default is 0).
+ :type num_workers: int, optional
+ :param indices: Optional list of dataset indices to subset the dataset before inference.
+ :type indices: Optional[List[int]], optional
+
+ :return: Tuple of stacked target and prediction tensors.
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
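+
+ Example (minimal sketch; `dataset` and `model` are assumed to be a compatible dataset/model pair)::
+
+     targets, predictions = predict_image(dataset, model, batch_size=2, device="cpu", indices=[0, 1, 2])
+     # targets and predictions are stacked along the batch dimension, ready for metric evaluation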
+ """
+ # Subset the dataset if indices are provided
+ if indices is not None:
+ dataset = Subset(dataset, indices)
+
+ # Create DataLoader for efficient batch processing
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+ model.to(device)
+ model.eval()
+
+ predictions, targets = [], []
+
+ with torch.no_grad():
+ for inputs, target in dataloader: # Unpacking (input_tensor, target_tensor)
+ inputs = inputs.to(device) # Move input data to the specified device
+
+ # Forward pass
+ outputs = model(inputs)
+
+ # output both target and prediction tensors for metric
+ targets.append(target.cpu())
+ predictions.append(outputs.cpu()) # Move to CPU for stacking
+
+ return torch.cat(targets, dim=0), torch.cat(predictions, dim=0)
+
+def process_tensor_image(
+ img_tensor: torch.Tensor,
+ dtype: Optional[np.dtype] = None,
+ dataset: Optional[Dataset] = None,
+ invert_function: Optional[Callable] = None
+) -> np.ndarray:
+ """
+ Processes model output/other image tensor by casting to numpy, applying an optional dtype casting,
+ and inverting target transformations if a dataset with `target_transform` is provided.
+
+ :param img_tensor: Tensor stack of model-predicted images with shape (N, C, H, W).
+ :type img_tensor: torch.Tensor
+ :param dtype: Optional numpy dtype to cast the output array (default: None).
+ :type dtype: Optional[np.dtype], optional
+ :param dataset: Optional dataset object with `target_transform` to invert transformations.
+ :type dataset: Optional[torch.utils.data.Dataset], optional
+ :param invert_function: Optional function to invert transformations applied to the images.
+ If provided, overrides the invert function call from dataset transform.
+ :type invert_function: Optional[Callable], optional
+
+ :return: Processed numpy array of images with shape (N, C, H, W).
+ :rtype: np.ndarray
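+
+ Example (illustrative; the lambda stands in for a real de-normalization/invert function)::
+
+     import numpy as np, torch
+     preds = torch.rand(2, 1, 8, 8)
+     imgs = process_tensor_image(preds, dtype=np.float32, invert_function=lambda img: img * 65535.0)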
+ """
+ # Convert img_tensor to CPU and NumPy
+ output_images = img_tensor.cpu().numpy()
+
+ # Optionally cast to specified dtype
+ if dtype is not None:
+ output_images = output_images.astype(dtype)
+
+ # Prefer the supplied invert function; otherwise fall back to the dataset's target_transform invert (if available)
+ if invert_function is not None and isinstance(invert_function, Callable):
+ output_images = np.array([invert_function(img) for img in output_images])
+ elif dataset is not None and hasattr(dataset, "target_transform"):
+ # Apply inverted target transformation if available
+ target_transform = dataset.target_transform
+ if isinstance(target_transform, (ImageOnlyTransform, Compose)):
+ # Apply the transformation on each image
+ output_images = np.array([target_transform.invert(img) for img in output_images])
+
+ return output_images
\ No newline at end of file
diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py
new file mode 100644
index 0000000..1b27699
--- /dev/null
+++ b/evaluation/visualization_utils.py
@@ -0,0 +1,174 @@
+from typing import List, Union, Optional
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+
+from ..datasets.PatchDataset import PatchDataset
+from ..evaluation.predict_utils import predict_image, process_tensor_image
+from ..evaluation.evaluation_utils import evaluate_per_image_metric
+
+def _plot_predictions_grid(
+ inputs: Union[np.ndarray, torch.Tensor],
+ targets: Union[np.ndarray, torch.Tensor],
+ predictions: Union[np.ndarray, torch.Tensor],
+ raw_images: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ patch_coords: Optional[List[tuple]] = None,
+ metrics_df: Optional[pd.DataFrame] = None,
+ save_path: Optional[str] = None,
+ **kwargs
+):
+ """
+ Generalized function to plot a grid of images with predictions and optional raw images.
+ The batch dimensions of raw_images (if given), inputs, targets, and predictions must match, as must the length of metrics_df.
+
+ :param inputs: Input images (N, C, H, W) or (N, H, W).
+ :param targets: Target images (N, C, H, W) or (N, H, W).
+ :param predictions: Model predictions (N, C, H, W) or (N, H, W).
+ :param raw_images: Optional raw images for PatchDataset (N, H, W).
+ :param patch_coords: Optional list of (x, y) coordinates for patches.
+ Only used if raw_images is provided. Length must match the first dimension of inputs/targets/predictions.
+ :param metrics_df: Optional DataFrame with per-image metrics.
+ :param save_path: If provided, saves figure.
+ :param kwargs: Additional keyword arguments to pass to plt.subplots.
+ """
+
+ cmap = kwargs.get("cmap", "gray")
+ panel_width = kwargs.get("panel_width", 5)
+ show_plot = kwargs.get("show_plot", True)
+ fig_size = kwargs.get("fig_size", None)
+
+ num_samples = len(inputs)
+ is_patch_dataset = raw_images is not None
+ num_cols = 4 if is_patch_dataset else 3 # (Raw | Input | Target | Prediction) vs (Input | Target | Prediction)
+
+ fig_size = (panel_width * num_cols, panel_width * num_samples) if fig_size is None else fig_size
+ # squeeze=False keeps axes 2D even when a single sample is plotted
+ fig, axes = plt.subplots(num_samples, num_cols, figsize=fig_size, squeeze=False)
+ column_titles = ["Raw Image", "Input", "Target", "Prediction"] if is_patch_dataset else ["Input", "Target", "Prediction"]
+
+ for row_idx in range(num_samples):
+ img_set = [raw_images[row_idx]] if is_patch_dataset else []
+ img_set.extend([inputs[row_idx], targets[row_idx], predictions[row_idx]])
+
+ for col_idx, img in enumerate(img_set):
+ ax = axes[row_idx, col_idx]
+ ax.imshow(img.squeeze(), cmap=cmap)
+ ax.set_title(column_titles[col_idx])
+ ax.axis("off")
+
+ # Draw rectangle on raw image if PatchDataset
+ if is_patch_dataset and col_idx == 0 and patch_coords is not None:
+ patch_x, patch_y = patch_coords[row_idx] # (x, y) coordinates
+ patch_size = targets.shape[-1] # Assume square patches from target size
+ rect = Rectangle((patch_x, patch_y), patch_size, patch_size, linewidth=2, edgecolor="r", facecolor="none")
+ ax.add_patch(rect)
+
+ # Display metrics if provided
+ if metrics_df is not None:
+ metric_values = metrics_df.iloc[row_idx]
+ metric_text = "\n".join([f"{key}: {value:.3f}" for key, value in metric_values.items()])
+ axes[row_idx, -1].set_title(
+ axes[row_idx, -1].get_title() + "\n" + metric_text, fontsize=10, pad=10)
+
+ # Save and/or show the plot
+ if save_path:
+ plt.savefig(save_path, bbox_inches="tight", dpi=300)
+ if show_plot:
+ plt.show()
+ else:
+ plt.close()
+
+def plot_predictions_grid_from_eval(
+ dataset: Dataset,
+ predictions: Union[torch.Tensor, np.ndarray],
+ indices: List[int],
+ metrics_df: Optional[pd.DataFrame] = None,
+ save_path: Optional[str] = None,
+ **kwargs
+):
+ """
+ Wrapper function to extract dataset samples and call `_plot_predictions_grid`.
+ This function operates on the outputs of `predict_image` and
+ `evaluate_per_image_metric` to avoid an unnecessary forward pass.
+
+ :param dataset: Dataset (either normal or PatchDataset).
+ :param predictions: Full tensor/NumPy array of predictions; `indices` selects the rows to plot.
+ :param indices: Indices corresponding to the subset.
+ :param metrics_df: DataFrame with per-image metrics for the subset.
+ :param save_path: If provided, saves figure.
+ :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`.
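+
+ Example (sketch; `dataset`, `predictions`, and `metrics_df` are assumed to come from
+ `predict_image` and `evaluate_per_image_metric`, with `metrics_df` covering the plotted subset)::
+
+     plot_predictions_grid_from_eval(
+         dataset, predictions, indices=[0, 1, 2],
+         metrics_df=metrics_df, save_path="predictions_grid.png")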
+ """
+
+ is_patch_dataset = isinstance(dataset, PatchDataset)
+
+ # Extract input, target, and (optional) raw images & patch coordinates
+ raw_images, inputs, targets, patch_coords = [], [], [], []
+ for i in indices:
+ inputs.append(dataset[i][0])
+ targets.append(dataset[i][1])
+ if is_patch_dataset:
+ raw_images.append(dataset.raw_input)
+ patch_coords.append(dataset.patch_coords) # Get patch location
+
+ inputs_numpy = process_tensor_image(torch.stack(inputs), invert_function=dataset.input_transform.invert)
+ targets_numpy = process_tensor_image(torch.stack(targets), invert_function=dataset.target_transform.invert)
+
+ # Pass everything to the core grid function
+ _plot_predictions_grid(
+ inputs_numpy, targets_numpy, predictions[indices],
+ raw_images if is_patch_dataset else None,
+ patch_coords if is_patch_dataset else None,
+ metrics_df, save_path, **kwargs
+ )
+
+def plot_predictions_grid_from_model(
+ model: torch.nn.Module,
+ dataset: Dataset,
+ indices: List[int],
+ metrics: List[torch.nn.Module],
+ device: str = "cuda",
+ save_path: Optional[str] = None,
+ **kwargs
+):
+ """
+ Wrapper plot function that internally performs inference and evaluation with the following steps:
+ 1. Perform inference on a subset of the dataset given the model.
+ 2. Compute per-image metrics on that subset.
+ 3. Plot the results with core `_plot_predictions_grid` function.
+
+ :param model: PyTorch model for inference.
+ :param dataset: The dataset to use for evaluation and plotting.
+ :param indices: List of dataset indices to evaluate and visualize.
+ :param metrics: List of metric functions to evaluate.
+ :param device: Device to run inference on ("cpu" or "cuda").
+ :param save_path: Optional path to save the plot.
+ :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`.
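+
+ Example (sketch; `model` and `dataset` are placeholders for a trained model and a compatible dataset)::
+
+     import torch
+     plot_predictions_grid_from_model(
+         model, dataset, indices=[0, 1, 2],
+         metrics=[torch.nn.MSELoss()], device="cpu",
+         save_path="intermediate_plot.png")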
+ """
+ # Step 1: Run inference on the selected subset
+ targets, predictions = predict_image(dataset, model, indices=indices, device=device)
+
+ # Step 2: Compute per-image metrics for the subset
+ metrics_df = evaluate_per_image_metric(predictions, targets, metrics)
+
+ # Step 3: Extract subset of inputs & targets and plot
+ is_patch_dataset = isinstance(dataset, PatchDataset)
+ raw_images, inputs, targets, patch_coords = [], [], [], []
+ for i in indices:
+ inputs.append(dataset[i][0])
+ targets.append(dataset[i][1])
+ if is_patch_dataset:
+ raw_images.append(dataset.raw_input)
+ patch_coords.append(dataset.patch_coords) # Get patch location
+
+ _plot_predictions_grid(
+ torch.stack(inputs),
+ torch.stack(targets),
+ predictions,
+ raw_images=raw_images if is_patch_dataset else None,
+ patch_coords=patch_coords if is_patch_dataset else None,
+ metrics_df=metrics_df,
+ save_path=save_path,
+ **kwargs)
\ No newline at end of file
diff --git a/examples/generate_sample_patch_dataset.ipynb b/examples/generate_sample_patch_dataset.ipynb
new file mode 100644
index 0000000..bef2c71
--- /dev/null
+++ b/examples/generate_sample_patch_dataset.ipynb
@@ -0,0 +1,2262 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/weishanli/Waylab\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "import pathlib\n",
+ "\n",
+ "import imageio\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "print(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "\n",
+ "from virtual_stain_flow.datasets.PatchDataset import PatchDataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-02-20 00:15:46,794 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n",
+ "2025-02-20 00:15:46,794 - DEBUG - Dataframe supplied for sc_feature, using as is\n",
+ "2025-02-20 00:15:46,795 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n",
+ "2025-02-20 00:15:46,795 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n",
+ "2025-02-20 00:15:46,795 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Plate', 'Metadata_Well']\n",
+ "2025-02-20 00:15:46,795 - DEBUG - Dataframe supplied for sc_feature, using as is\n",
+ "2025-02-20 00:15:46,816 - DEBUG - Inferring channel keys from loaddata csv\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Channel keys: {'OrigBrightfield', 'OrigAGP', 'OrigDNA', 'OrigRNA', 'OrigMito', 'OrigER'} inferred from loaddata csv\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Setting input channel(s) ...\n",
+ "2025-02-20 00:15:46,817 - DEBUG - No channel keys specified, skip\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Setting target channel(s) ...\n",
+ "2025-02-20 00:15:46,817 - DEBUG - No channel keys specified, skip\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Setting input transform ...\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Setting target transform ...\n",
+ "2025-02-20 00:15:46,817 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n",
+ "2025-02-20 00:15:46,834 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n",
+ "2025-02-20 00:15:46,835 - DEBUG - Generating patches that contain cells\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-02-20 00:15:46,848 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n",
+ "2025-02-20 00:15:47,226 - DEBUG - Generated 461 patches for 93 site/view\n"
+ ]
+ }
+ ],
+ "source": [
+ "## REPLACE WITH YOUR OWN PATHS\n",
+ "analysis_home_path = pathlib.Path('/home/weishanli/Waylab/ALSF_pilot/ALSF_img2img_prototyping')\n",
+ "sc_features_parquet_path = pathlib.Path(\n",
+ " '/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/preprocessed_profiles_SN0313537/single_cell_profiles'\n",
+ ")\n",
+ "\n",
+ "loaddata_csv_path = analysis_home_path \\\n",
+ " / '0.data_analysis_and_preprocessing' / 'loaddata_csvs'\n",
+ "\n",
+ "if loaddata_csv_path.exists():\n",
+ " try:\n",
+ " loaddata_csv = next(loaddata_csv_path.glob('*.csv'))\n",
+ " except:\n",
+ " raise FileNotFoundError(\"No loaddata csv found\")\n",
+ "else:\n",
+ " raise ValueError(\"Incorrect loaddata csv path\")\n",
+ "\n",
+ "loaddata_df = pd.read_csv(loaddata_csv)\n",
+ "# subsample to reduce runtime\n",
+ "loaddata_df = loaddata_df.sample(n=100, random_state=42)\n",
+ "\n",
+ "sc_features = pd.DataFrame()\n",
+ "for plate in loaddata_df['Metadata_Plate'].unique():\n",
+ " sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'\n",
+ " if not sc_features_parquet.exists():\n",
+ " print(f'{sc_features_parquet} does not exist, skipping...')\n",
+ " continue \n",
+ " else:\n",
+ " sc_features = pd.concat([\n",
+ " sc_features, \n",
+ " pd.read_parquet(\n",
+ " sc_features_parquet,\n",
+ " columns=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site', 'Metadata_Cells_Location_Center_X', 'Metadata_Cells_Location_Center_Y']\n",
+ " )\n",
+ " ])\n",
+ "\n",
+ "PATCH_SIZE = 256\n",
+ "\n",
+ "channel_names = [\n",
+ " \"OrigBrightfield\",\n",
+ " \"OrigDNA\",\n",
+ " \"OrigER\",\n",
+ " \"OrigMito\",\n",
+ " \"OrigRNA\",\n",
+ " \"OrigAGP\",\n",
+ "]\n",
+ "input_channel_name = \"OrigBrightfield\"\n",
+ "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]\n",
+ "\n",
+ "pds = PatchDataset(\n",
+ " _loaddata_csv=loaddata_df,\n",
+ " _sc_feature=sc_features,\n",
+ " _input_channel_keys=None,\n",
+ " _target_channel_keys=None,\n",
+ " _input_transform=None,\n",
+ " _target_transform=None,\n",
+ " patch_size=PATCH_SIZE,\n",
+ " verbose=True,\n",
+ " patch_generation_method=\"random_cell\",\n",
+ " patch_generation_random_seed=42\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "OrigBrightfield/r06c12f02p01-ch1sk1fk1fl1_8_680.tiff"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.microsoft.datawrangler.viewer.v0+json": {
+ "columns": [
+ {
+ "name": "index",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "FileName_OrigBrightfield",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigBrightfield",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "FileName_OrigER",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigER",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "FileName_OrigAGP",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigAGP",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "FileName_OrigMito",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigMito",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "FileName_OrigDNA",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigDNA",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "FileName_OrigRNA",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "PathName_OrigRNA",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "Metadata_Plate",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "Metadata_Well",
+ "rawType": "object",
+ "type": "string"
+ },
+ {
+ "name": "Metadata_Site",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_AbsPositionZ",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "Metadata_ChannelID",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_Col",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_FieldID",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_PlaneID",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_PositionX",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "Metadata_PositionY",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "Metadata_PositionZ",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "Metadata_Row",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "Metadata_Reimaged",
+ "rawType": "bool",
+ "type": "boolean"
+ }
+ ],
+ "conversionMethod": "pd.DataFrame",
+ "ref": "9503fb45-c7d0-43d4-8b8c-183f24a0d822",
+ "rows": [
+ [
+ "2079",
+ "r06c22f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c22f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c22f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c22f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c22f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c22f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F22",
+ "1",
+ "0.134357795",
+ "6",
+ "22",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "668",
+ "r05c09f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E09",
+ "3",
+ "0.1344053",
+ "6",
+ "9",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "2073",
+ "r05c22f04p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c22f04p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c22f04p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c22f04p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c22f04p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c22f04p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E22",
+ "4",
+ "0.134365693",
+ "6",
+ "22",
+ "4",
+ "1",
+ "0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "1113",
+ "r06c13f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c13f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c13f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c13f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c13f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c13f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F13",
+ "7",
+ "0.134346902",
+ "6",
+ "13",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "788",
+ "r06c10f06p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c10f06p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c10f06p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c10f06p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c10f06p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c10f06p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F10",
+ "6",
+ "0.134381399",
+ "6",
+ "10",
+ "6",
+ "1",
+ "-0.000645814",
+ "0.0",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "1780",
+ "r08c19f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c19f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c19f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c19f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c19f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c19f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "H19",
+ "8",
+ "0.134326905",
+ "6",
+ "19",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "8",
+ "False"
+ ],
+ [
+ "1672",
+ "r08c18f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "H18",
+ "8",
+ "0.134322807",
+ "6",
+ "18",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "8",
+ "False"
+ ],
+ [
+ "1717",
+ "r13c18f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c18f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c18f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c18f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c18f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c18f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "M18",
+ "8",
+ "0.134328201",
+ "6",
+ "18",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "13",
+ "False"
+ ],
+ [
+ "926",
+ "r09c11f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "r09c11f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "r09c11f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "r09c11f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "r09c11f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "r09c11f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_SK-N-MC_Re-imaged/BR00143976__2024-07-17T19_15_31-Measurement 4/Images",
+ "BR00143976",
+ "I11",
+ "9",
+ "0.134355694",
+ "6",
+ "11",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-3e-06",
+ "9",
+ "True"
+ ],
+ [
+ "2157",
+ "r14c22f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c22f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c22f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c22f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c22f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c22f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "N22",
+ "7",
+ "0.134358302",
+ "6",
+ "22",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "14",
+ "False"
+ ],
+ [
+ "674",
+ "r05c09f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c09f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E09",
+ "9",
+ "0.134399906",
+ "6",
+ "9",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "2011",
+ "r10c21f05p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c21f05p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c21f05p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c21f05p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c21f05p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c21f05p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "J21",
+ "5",
+ "0.134337202",
+ "6",
+ "21",
+ "5",
+ "1",
+ "0.000645814",
+ "0.0",
+ "-6e-06",
+ "10",
+ "False"
+ ],
+ [
+ "96",
+ "r13c03f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c03f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c03f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c03f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c03f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c03f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "M03",
+ "7",
+ "0.134531394",
+ "6",
+ "3",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "13",
+ "False"
+ ],
+ [
+ "1164",
+ "r12c13f04p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r12c13f04p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r12c13f04p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r12c13f04p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r12c13f04p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r12c13f04p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "L13",
+ "4",
+ "0.134333",
+ "6",
+ "13",
+ "4",
+ "1",
+ "0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "12",
+ "False"
+ ],
+ [
+ "567",
+ "r06c08f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c08f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c08f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c08f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c08f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c08f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F08",
+ "1",
+ "0.134413004",
+ "6",
+ "8",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "135",
+ "r06c04f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c04f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c04f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c04f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c04f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c04f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F04",
+ "1",
+ "0.134507298",
+ "6",
+ "4",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "29",
+ "r06c03f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c03f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c03f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c03f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c03f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c03f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F03",
+ "3",
+ "0.134534299",
+ "6",
+ "3",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "1641",
+ "r05c18f04p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c18f04p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c18f04p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c18f04p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c18f04p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c18f04p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E18",
+ "4",
+ "0.134341702",
+ "6",
+ "18",
+ "4",
+ "1",
+ "0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "1845",
+ "r04c20f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c20f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c20f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c20f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c20f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c20f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "BR00143976",
+ "D20",
+ "1",
+ "0.134361506",
+ "6",
+ "20",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-4e-06",
+ "4",
+ "True"
+ ],
+ [
+ "1940",
+ "r14c20f06p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f06p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f06p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f06p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f06p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f06p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "N20",
+ "6",
+ "0.134341404",
+ "6",
+ "20",
+ "6",
+ "1",
+ "-0.000645814",
+ "0.0",
+ "-6e-06",
+ "14",
+ "False"
+ ],
+ [
+ "621",
+ "r12c08f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c08f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c08f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c08f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c08f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c08f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "BR00143976",
+ "L08",
+ "1",
+ "0.134399801",
+ "6",
+ "8",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-4e-06",
+ "12",
+ "True"
+ ],
+ [
+ "1204",
+ "r04c14f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c14f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c14f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c14f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c14f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c14f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "BR00143976",
+ "D14",
+ "8",
+ "0.134359106",
+ "6",
+ "14",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-4e-06",
+ "4",
+ "True"
+ ],
+ [
+ "1455",
+ "r08c16f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c16f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c16f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c16f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c16f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c16f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "H16",
+ "7",
+ "0.134320304",
+ "6",
+ "16",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "8",
+ "False"
+ ],
+ [
+ "1057",
+ "r12c12f05p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c12f05p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c12f05p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c12f05p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c12f05p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c12f05p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "BR00143976",
+ "L12",
+ "5",
+ "0.134343997",
+ "6",
+ "12",
+ "5",
+ "1",
+ "0.000645814",
+ "0.0",
+ "-4e-06",
+ "12",
+ "True"
+ ],
+ [
+ "350",
+ "r05c06f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c06f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c06f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c06f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c06f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c06f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E06",
+ "9",
+ "0.134458706",
+ "6",
+ "6",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "1586",
+ "r11c17f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c17f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c17f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c17f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c17f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c17f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "K17",
+ "3",
+ "0.134314403",
+ "6",
+ "17",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-6e-06",
+ "11",
+ "False"
+ ],
+ [
+ "1000",
+ "r06c12f02p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c12f02p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c12f02p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c12f02p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c12f02p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c12f02p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F12",
+ "2",
+ "0.134358004",
+ "6",
+ "12",
+ "2",
+ "1",
+ "-0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "251",
+ "r06c05f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c05f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c05f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c05f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c05f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c05f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F05",
+ "9",
+ "0.134474307",
+ "6",
+ "5",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "1420",
+ "r04c16f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c16f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c16f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c16f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c16f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r04c16f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "BR00143976",
+ "D16",
+ "8",
+ "0.134352103",
+ "6",
+ "16",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-4e-06",
+ "4",
+ "True"
+ ],
+ [
+ "965",
+ "r14c11f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r14c11f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r14c11f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r14c11f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r14c11f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r14c11f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "BR00143976",
+ "N11",
+ "3",
+ "0.134372801",
+ "6",
+ "11",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-2e-06",
+ "14",
+ "True"
+ ],
+ [
+ "1189",
+ "r03c14f02p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r03c14f02p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r03c14f02p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r03c14f02p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r03c14f02p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "r03c14f02p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_A-673_Re-imaged/BR00143976__2024-07-17T18_37_33-Measurement 3/Images",
+ "BR00143976",
+ "C14",
+ "2",
+ "0.134374201",
+ "6",
+ "14",
+ "2",
+ "1",
+ "-0.000645814",
+ "0.000645814",
+ "-4e-06",
+ "3",
+ "True"
+ ],
+ [
+ "2002",
+ "r09c21f05p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r09c21f05p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r09c21f05p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r09c21f05p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r09c21f05p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r09c21f05p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "I21",
+ "5",
+ "0.134339705",
+ "6",
+ "21",
+ "5",
+ "1",
+ "0.000645814",
+ "0.0",
+ "-6e-06",
+ "9",
+ "False"
+ ],
+ [
+ "111",
+ "r03c04f04p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c04f04p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c04f04p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c04f04p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c04f04p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c04f04p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "C04",
+ "4",
+ "0.134528205",
+ "6",
+ "4",
+ "4",
+ "1",
+ "0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "3",
+ "False"
+ ],
+ [
+ "184",
+ "r11c04f05p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c04f05p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c04f05p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c04f05p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c04f05p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c04f05p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "BR00143976",
+ "K04",
+ "5",
+ "0.134485096",
+ "6",
+ "4",
+ "5",
+ "1",
+ "0.000645814",
+ "0.0",
+ "-4e-06",
+ "11",
+ "True"
+ ],
+ [
+ "1726",
+ "r14c18f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c18f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c18f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c18f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c18f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c18f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "N18",
+ "8",
+ "0.134337395",
+ "6",
+ "18",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "14",
+ "False"
+ ],
+ [
+ "1747",
+ "r05c19f02p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c19f02p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c19f02p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c19f02p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c19f02p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r05c19f02p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "E19",
+ "2",
+ "0.134342298",
+ "6",
+ "19",
+ "2",
+ "1",
+ "-0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "5",
+ "False"
+ ],
+ [
+ "1667",
+ "r08c18f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r08c18f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "H18",
+ "3",
+ "0.134323403",
+ "6",
+ "18",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-6e-06",
+ "8",
+ "False"
+ ],
+ [
+ "2022",
+ "r11c21f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c21f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c21f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c21f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c21f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r11c21f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "K21",
+ "7",
+ "0.134334505",
+ "6",
+ "21",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "11",
+ "False"
+ ],
+ [
+ "422",
+ "r13c06f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c06f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c06f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c06f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c06f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c06f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "BR00143976",
+ "M06",
+ "9",
+ "0.1344482",
+ "6",
+ "6",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-2e-06",
+ "13",
+ "True"
+ ],
+ [
+ "508",
+ "r11c07f05p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c07f05p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c07f05p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c07f05p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c07f05p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r11c07f05p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "BR00143976",
+ "K07",
+ "5",
+ "0.134413093",
+ "6",
+ "7",
+ "5",
+ "1",
+ "0.000645814",
+ "0.0",
+ "-4e-06",
+ "11",
+ "True"
+ ],
+ [
+ "1331",
+ "r06c15f09p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c15f09p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c15f09p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c15f09p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c15f09p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c15f09p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F15",
+ "9",
+ "0.134334505",
+ "6",
+ "15",
+ "9",
+ "1",
+ "0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ],
+ [
+ "1469",
+ "r10c16f03p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c16f03p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c16f03p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c16f03p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c16f03p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r10c16f03p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "J16",
+ "3",
+ "0.134315595",
+ "6",
+ "16",
+ "3",
+ "1",
+ "0.0",
+ "0.000645814",
+ "-6e-06",
+ "10",
+ "False"
+ ],
+ [
+ "1935",
+ "r14c20f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c20f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "N20",
+ "1",
+ "0.134343296",
+ "6",
+ "20",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-6e-06",
+ "14",
+ "False"
+ ],
+ [
+ "630",
+ "r13c08f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c08f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c08f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c08f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c08f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "r13c08f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240718_SH-SY5Y_Re-imaged/BR00143976__2024-07-18T17_56_52-Measurement 6/Images",
+ "BR00143976",
+ "M08",
+ "1",
+ "0.134411901",
+ "6",
+ "8",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-2e-06",
+ "13",
+ "True"
+ ],
+ [
+ "1393",
+ "r13c15f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c15f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c15f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c15f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c15f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c15f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "M15",
+ "8",
+ "0.134330496",
+ "6",
+ "15",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "13",
+ "False"
+ ],
+ [
+ "1933",
+ "r13c20f08p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c20f08p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c20f08p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c20f08p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c20f08p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r13c20f08p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "M20",
+ "8",
+ "0.134337306",
+ "6",
+ "20",
+ "8",
+ "1",
+ "0.0",
+ "-0.000645814",
+ "-6e-06",
+ "13",
+ "False"
+ ],
+ [
+ "1293",
+ "r14c14f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c14f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c14f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c14f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c14f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r14c14f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "N14",
+ "7",
+ "0.134347007",
+ "6",
+ "14",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "14",
+ "False"
+ ],
+ [
+ "973",
+ "r03c12f02p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c12f02p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c12f02p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c12f02p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c12f02p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r03c12f02p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "C12",
+ "2",
+ "0.134390101",
+ "6",
+ "12",
+ "2",
+ "1",
+ "-0.000645814",
+ "0.000645814",
+ "-6e-06",
+ "3",
+ "False"
+ ],
+ [
+ "297",
+ "r12c05f01p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c05f01p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c05f01p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c05f01p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c05f01p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "r12c05f01p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/20240717_CHP-212_Re-imaged/BR00143976__2024-07-17T19_50_33-Measurement 5/Images",
+ "BR00143976",
+ "L05",
+ "1",
+ "0.134464994",
+ "6",
+ "5",
+ "1",
+ "1",
+ "0.0",
+ "0.0",
+ "-4e-06",
+ "12",
+ "True"
+ ],
+ [
+ "1761",
+ "r06c19f07p01-ch1sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c19f07p01-ch2sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c19f07p01-ch3sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c19f07p01-ch4sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c19f07p01-ch5sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "r06c19f07p01-ch6sk1fk1fl1.tiff",
+ "/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/SN0313537/BR00143976__2024-07-04T16_04_45-Measurement 2/Images",
+ "BR00143976",
+ "F19",
+ "7",
+ "0.134334296",
+ "6",
+ "19",
+ "7",
+ "1",
+ "-0.000645814",
+ "-0.000645814",
+ "-6e-06",
+ "6",
+ "False"
+ ]
+ ],
+ "shape": {
+ "columns": 25,
+ "rows": 100
+ }
+ },
+ "text/html": [
+       "<div>100 rows × 25 columns (HTML rendering of the same DataFrame shown in the text/plain output below)</div>"
+ ],
+ "text/plain": [
+ " FileName_OrigBrightfield \\\n",
+ "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch1sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n",
+ "... ... \n",
+ "1730 r03c19f03p01-ch1sk1fk1fl1.tiff \n",
+ "196 r12c04f08p01-ch1sk1fk1fl1.tiff \n",
+ "367 r07c06f08p01-ch1sk1fk1fl1.tiff \n",
+ "650 r03c09f03p01-ch1sk1fk1fl1.tiff \n",
+ "2064 r04c22f04p01-ch1sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigBrightfield \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "... ... \n",
+ "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigER \\\n",
+ "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n",
+ "... ... \n",
+ "1730 r03c19f03p01-ch2sk1fk1fl1.tiff \n",
+ "196 r12c04f08p01-ch2sk1fk1fl1.tiff \n",
+ "367 r07c06f08p01-ch2sk1fk1fl1.tiff \n",
+ "650 r03c09f03p01-ch2sk1fk1fl1.tiff \n",
+ "2064 r04c22f04p01-ch2sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigER \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "... ... \n",
+ "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigAGP \\\n",
+ "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n",
+ "... ... \n",
+ "1730 r03c19f03p01-ch4sk1fk1fl1.tiff \n",
+ "196 r12c04f08p01-ch4sk1fk1fl1.tiff \n",
+ "367 r07c06f08p01-ch3sk1fk1fl1.tiff \n",
+ "650 r03c09f03p01-ch3sk1fk1fl1.tiff \n",
+ "2064 r04c22f04p01-ch4sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigAGP \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "... ... \n",
+ "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigMito \\\n",
+ "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n",
+ "... ... \n",
+ "1730 r03c19f03p01-ch3sk1fk1fl1.tiff \n",
+ "196 r12c04f08p01-ch3sk1fk1fl1.tiff \n",
+ "367 r07c06f08p01-ch4sk1fk1fl1.tiff \n",
+ "650 r03c09f03p01-ch4sk1fk1fl1.tiff \n",
+ "2064 r04c22f04p01-ch3sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigMito \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "... ... \n",
+ "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigDNA \\\n",
+ "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n",
+ "... ... \n",
+ "1730 r03c19f03p01-ch6sk1fk1fl1.tiff \n",
+ "196 r12c04f08p01-ch6sk1fk1fl1.tiff \n",
+ "367 r07c06f08p01-ch5sk1fk1fl1.tiff \n",
+ "650 r03c09f03p01-ch5sk1fk1fl1.tiff \n",
+ "2064 r04c22f04p01-ch6sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigDNA ... \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "... ... ... \n",
+ "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "\n",
+ " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n",
+ "2079 0.134358 6 22 1 \n",
+ "668 0.134405 6 9 3 \n",
+ "2073 0.134366 6 22 4 \n",
+ "1113 0.134347 6 13 7 \n",
+ "788 0.134381 6 10 6 \n",
+ "... ... ... ... ... \n",
+ "1730 0.134366 6 19 3 \n",
+ "196 0.134491 6 4 8 \n",
+ "367 0.134447 6 6 8 \n",
+ "650 0.134428 6 9 3 \n",
+ "2064 0.134379 6 22 4 \n",
+ "\n",
+ " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n",
+ "2079 1 0.000000 0.000000 \n",
+ "668 1 0.000000 0.000646 \n",
+ "2073 1 0.000646 0.000646 \n",
+ "1113 1 -0.000646 -0.000646 \n",
+ "788 1 -0.000646 0.000000 \n",
+ "... ... ... ... \n",
+ "1730 1 0.000000 0.000646 \n",
+ "196 1 0.000000 -0.000646 \n",
+ "367 1 0.000000 -0.000646 \n",
+ "650 1 0.000000 0.000646 \n",
+ "2064 1 0.000646 0.000646 \n",
+ "\n",
+ " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n",
+ "2079 -0.000006 6 False \n",
+ "668 -0.000006 5 False \n",
+ "2073 -0.000006 5 False \n",
+ "1113 -0.000006 6 False \n",
+ "788 -0.000006 6 False \n",
+ "... ... ... ... \n",
+ "1730 -0.000004 3 True \n",
+ "196 -0.000004 12 True \n",
+ "367 -0.000006 7 False \n",
+ "650 -0.000006 3 False \n",
+ "2064 -0.000004 4 True \n",
+ "\n",
+ "[100 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "loaddata_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "EXAMPLE_PATCH_DATA_EXPORT_PATH = pathlib.Path('.').absolute().parent.parent / 'example_patch_data'\n",
+ "EXAMPLE_PATCH_DATA_EXPORT_PATH.mkdir(exist_ok=True)\n",
+ "INPUT_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / input_channel_name\n",
+ "INPUT_EXPORT_PATH.mkdir(exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-02-20 00:15:47,252 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-02-20 00:15:47,252 - DEBUG - Set target channel(s) as ['OrigDNA']\n",
+ "2025-02-20 00:15:47,475 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-02-20 00:15:47,475 - DEBUG - Set target channel(s) as ['OrigER']\n",
+ "2025-02-20 00:15:47,676 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-02-20 00:15:47,677 - DEBUG - Set target channel(s) as ['OrigMito']\n",
+ "2025-02-20 00:15:47,850 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-02-20 00:15:47,850 - DEBUG - Set target channel(s) as ['OrigRNA']\n",
+ "2025-02-20 00:15:48,034 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-02-20 00:15:48,035 - DEBUG - Set target channel(s) as ['OrigAGP']\n"
+ ]
+ }
+ ],
+ "source": [
+ "for j, channel_name in enumerate(target_channel_names):\n",
+ "\n",
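+    "    # point the patch dataset at the brightfield input and one target channel per pass\n",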
+ " pds.set_input_channel_keys([input_channel_name])\n",
+ " pds.set_target_channel_keys([channel_name])\n",
+ "\n",
+ " CHANNEL_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / channel_name\n",
+ " CHANNEL_EXPORT_PATH.mkdir(exist_ok=True)\n",
+ "\n",
+ " for i in range(len(pds)):\n",
+ " input, target = pds[i]\n",
+ " input_name = pds.input_names\n",
+ " target_name = pds.target_names\n",
+ " patch_coord = pds.patch_coords\n",
+ "\n",
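+    "        # the brightfield input patch does not depend on the target channel, so write it only on the first pass (j == 0)\n",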
+ " if j == 0:\n",
+ " imageio.imwrite(\n",
+ " INPUT_EXPORT_PATH / f'{input_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n",
+ " input[0].numpy().astype(np.uint16))\n",
+ "\n",
+ " imageio.imwrite(\n",
+ " CHANNEL_EXPORT_PATH / f'{target_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n",
+ " target[0].numpy().astype(np.uint16))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "speckle_analysis",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb
new file mode 100644
index 0000000..d3842c5
--- /dev/null
+++ b/examples/minimal_example.ipynb
@@ -0,0 +1,1019 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# A minimal example demonstrating how the FNet and wGAN-GP trainers, their callbacks, and the patch dataset work together\n",
+    "\n",
+    "Depends on the files produced by `1.illumination_correction/0.create_loaddata_csvs` in the ALSF pilot data repo: https://github.com/WayScience/pediatric_cancer_atlas_profiling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/weishanli/Waylab\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/weishanli/anaconda3/envs/cp_gan_env/lib/python3.9/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.5' (you have '2.0.4'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
+ " check_for_updates()\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "import pathlib\n",
+ "\n",
+ "import pandas as pd\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "\n",
+ "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "print(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "\n",
+ "## Dataset\n",
+ "from virtual_stain_flow.datasets.PatchDataset import PatchDataset\n",
+ "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n",
+ "\n",
+ "## FNet training\n",
+ "from virtual_stain_flow.models.fnet import FNet\n",
+ "from virtual_stain_flow.trainers.Trainer import Trainer\n",
+ "\n",
+    "## wGAN training\n",
+ "from virtual_stain_flow.models.unet import UNet\n",
+ "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n",
+ "from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer\n",
+ "\n",
+    "## wGAN losses\n",
+ "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n",
+ "from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss\n",
+ "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n",
+ "\n",
+ "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n",
+ "\n",
+ "## Metrics\n",
+ "from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper\n",
+ "from virtual_stain_flow.metrics.PSNR import PSNR\n",
+ "from virtual_stain_flow.metrics.SSIM import SSIM\n",
+ "\n",
+ "## callback\n",
+ "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n",
+ "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specify train output paths"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train'\n",
+ "EXAMPLE_DIR.mkdir(exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
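+    "# clear outputs from any previous run of this example\n",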
+ "!rm -rf example_train/*"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "PLOT_DIR = EXAMPLE_DIR / 'plot'\n",
+ "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+    "MLFLOW_DIR = EXAMPLE_DIR / 'mlflow'\n",
+ "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specify paths to loaddata and read single cell features"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " FileName_OrigBrightfield \\\n",
+ "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch1sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigBrightfield \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigER \\\n",
+ "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigER \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigAGP \\\n",
+ "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigAGP \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigMito \\\n",
+ "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigMito \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n",
+ "\n",
+ " FileName_OrigDNA \\\n",
+ "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n",
+ "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n",
+ "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n",
+ "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n",
+ "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n",
+ "\n",
+ " PathName_OrigDNA ... \\\n",
+ "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n",
+ "\n",
+ " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n",
+ "2079 0.134358 6 22 1 \n",
+ "668 0.134405 6 9 3 \n",
+ "2073 0.134366 6 22 4 \n",
+ "1113 0.134347 6 13 7 \n",
+ "788 0.134381 6 10 6 \n",
+ "\n",
+ " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n",
+ "2079 1 0.000000 0.000000 \n",
+ "668 1 0.000000 0.000646 \n",
+ "2073 1 0.000646 0.000646 \n",
+ "1113 1 -0.000646 -0.000646 \n",
+ "788 1 -0.000646 0.000000 \n",
+ "\n",
+ " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n",
+ "2079 -0.000006 6 False \n",
+ "668 -0.000006 5 False \n",
+ "2073 -0.000006 5 False \n",
+ "1113 -0.000006 6 False \n",
+ "788 -0.000006 6 False \n",
+ "\n",
+ "[5 rows x 25 columns]\n",
+ " Metadata_Plate Metadata_Well Metadata_Site \\\n",
+ "0 BR00143976 C03 2 \n",
+ "1 BR00143976 C03 6 \n",
+ "2 BR00143976 C03 9 \n",
+ "3 BR00143976 C03 5 \n",
+ "4 BR00143976 C03 7 \n",
+ "\n",
+ " Metadata_Cells_Location_Center_X Metadata_Cells_Location_Center_Y \n",
+ "0 629.552987 62.017799 \n",
+ "1 279.951864 56.588228 \n",
+ "2 876.508878 205.794360 \n",
+ "3 479.254866 45.496581 \n",
+ "4 866.557068 205.908787 \n"
+ ]
+ }
+ ],
+ "source": [
+ "## REPLACE WITH YOUR OWN PATHS\n",
+ "loaddata_csv_path = pathlib.Path(\n",
+ " '/REPLACE/WITH/YOUR/PATH'\n",
+ " )\n",
+ "sc_features_parquet_path = pathlib.Path(\n",
+ " '/REPLACE/WITH/YOUR/PATH'\n",
+ " )\n",
+ "\n",
+ "if loaddata_csv_path.exists():\n",
+ " try:\n",
+ " loaddata_csv = next(loaddata_csv_path.glob('*.csv'))\n",
+ " except:\n",
+ " raise FileNotFoundError(\"No loaddata csv found\")\n",
+ "else:\n",
+ " raise ValueError(\"Incorrect loaddata csv path\")\n",
+ "\n",
+ "loaddata_df = pd.read_csv(loaddata_csv)\n",
+ "# subsample to reduce runtime\n",
+ "loaddata_df = loaddata_df.sample(n=100, random_state=42)\n",
+ "\n",
+ "sc_features = pd.DataFrame()\n",
+ "for plate in loaddata_df['Metadata_Plate'].unique():\n",
+ " sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'\n",
+ " if not sc_features_parquet.exists():\n",
+ " print(f'{sc_features_parquet} does not exist, skipping...')\n",
+ " continue \n",
+ " else:\n",
+ " sc_features = pd.concat([\n",
+ " sc_features, \n",
+ " pd.read_parquet(\n",
+ " sc_features_parquet,\n",
+ " columns=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site', 'Metadata_Cells_Location_Center_X', 'Metadata_Cells_Location_Center_Y']\n",
+ " )\n",
+ " ])\n",
+ "\n",
+ "print(loaddata_df.head())\n",
+ "print(sc_features.head())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configure Patch size and channels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "PATCH_SIZE = 256\n",
+ "\n",
+ "channel_names = [\n",
+ " \"OrigBrightfield\",\n",
+ " \"OrigDNA\",\n",
+ " \"OrigER\",\n",
+ " \"OrigMito\",\n",
+ " \"OrigRNA\",\n",
+ " \"OrigAGP\",\n",
+ "]\n",
+ "input_channel_name = \"OrigBrightfield\"\n",
+ "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prep Patch dataset and Cache"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-03-03 20:31:52,602 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n",
+ "2025-03-03 20:31:52,603 - DEBUG - Dataframe supplied for sc_feature, using as is\n",
+ "2025-03-03 20:31:52,604 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n",
+ "2025-03-03 20:31:52,605 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n",
+ "2025-03-03 20:31:52,605 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Well', 'Metadata_Site']\n",
+ "2025-03-03 20:31:52,605 - DEBUG - Dataframe supplied for sc_feature, using as is\n",
+ "2025-03-03 20:31:52,673 - DEBUG - Inferring channel keys from loaddata csv\n",
+ "2025-03-03 20:31:52,674 - DEBUG - Channel keys: {'OrigBrightfield', 'OrigMito', 'OrigAGP', 'OrigER', 'OrigRNA', 'OrigDNA'} inferred from loaddata csv\n",
+ "2025-03-03 20:31:52,674 - DEBUG - Setting input channel(s) ...\n",
+ "2025-03-03 20:31:52,674 - DEBUG - No channel keys specified, skip\n",
+ "2025-03-03 20:31:52,675 - DEBUG - Setting target channel(s) ...\n",
+ "2025-03-03 20:31:52,675 - DEBUG - No channel keys specified, skip\n",
+ "2025-03-03 20:31:52,675 - DEBUG - Setting input transform ...\n",
+ "2025-03-03 20:31:52,676 - DEBUG - Setting target transform ...\n",
+ "2025-03-03 20:31:52,676 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-03-03 20:31:53,018 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n",
+ "2025-03-03 20:31:53,018 - DEBUG - Generating patches that contain cells\n",
+ "2025-03-03 20:31:53,052 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n",
+ "2025-03-03 20:31:53,681 - DEBUG - Generated 461 patches for 93 site/view\n",
+ "2025-03-03 20:31:53,682 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-03-03 20:31:53,682 - DEBUG - Set target channel(s) as ['OrigDNA']\n"
+ ]
+ }
+ ],
+ "source": [
+ "pds = PatchDataset(\n",
+ " _loaddata_csv=loaddata_df,\n",
+ " _sc_feature=sc_features,\n",
+ " _input_channel_keys=None,\n",
+ " _target_channel_keys=None,\n",
+ " _input_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n",
+ " _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n",
+ " patch_size=PATCH_SIZE,\n",
+ " verbose=True,\n",
+ " patch_generation_method=\"random_cell\",\n",
+ " patch_generation_random_seed=42\n",
+ ")\n",
+ "\n",
+ "## Set input and target channels\n",
+ "pds.set_input_channel_keys([input_channel_name])\n",
+ "pds.set_target_channel_keys('OrigDNA')\n",
+ "\n",
+ "## Cache for faster training \n",
+ "cds = CachedDataset(\n",
+ " pds,\n",
+ " prefill_cache=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# FNet trainer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train model without callback and check logs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = FNet(depth=4)\n",
+ "lr = 3e-4\n",
+ "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=None,\n",
+ " early_termination_metric = 'L1Loss',\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda'\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.microsoft.datawrangler.viewer.v0+json": {
+ "columns": [
+ {
+ "name": "index",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "epoch",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "L1Loss",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_L1Loss",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "psnr",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "ssim",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_psnr",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_ssim",
+ "rawType": "float64",
+ "type": "float"
+ }
+ ],
+ "conversionMethod": "pd.DataFrame",
+ "ref": "0e376d8b-d69e-4ca5-a37b-ad7cc30c4534",
+ "rows": [
+ [
+ "0",
+ "1",
+ "0.30758210165160044",
+ "0.367111998796463",
+ "10.07480525970459",
+ "0.026603518053889275",
+ "8.649313926696777",
+ "0.03540125489234924"
+ ],
+ [
+ "1",
+ "2",
+ "0.1586258773292814",
+ "0.15944983959197997",
+ "14.711766242980957",
+ "0.04422195255756378",
+ "15.713263511657715",
+ "0.06773694604635239"
+ ],
+ [
+ "2",
+ "3",
+ "0.09979256739219029",
+ "0.09274622052907944",
+ "17.329845428466797",
+ "0.07210950553417206",
+ "20.035673141479492",
+ "0.11819823086261749"
+ ],
+ [
+ "3",
+ "4",
+ "0.06892516552692368",
+ "0.050482964515686034",
+ "19.144039154052734",
+ "0.1051420047879219",
+ "24.395761489868164",
+ "0.2566310465335846"
+ ],
+ [
+ "4",
+ "5",
+ "0.04841376912026178",
+ "0.04219272881746292",
+ "24.528972625732422",
+ "0.30562901496887207",
+ "25.807828903198242",
+ "0.3405899703502655"
+ ],
+ [
+ "5",
+ "6",
+ "0.03341594719815822",
+ "0.027642089501023294",
+ "27.669620513916016",
+ "0.45510202646255493",
+ "28.5550594329834",
+ "0.5035992860794067"
+ ],
+ [
+ "6",
+ "7",
+ "0.026435295385973796",
+ "0.025301840528845786",
+ "28.809947967529297",
+ "0.5208737254142761",
+ "29.219507217407227",
+ "0.563963770866394"
+ ],
+ [
+ "7",
+ "8",
+ "0.02121852764061519",
+ "0.017991484329104423",
+ "29.607341766357422",
+ "0.5826537609100342",
+ "30.165979385375977",
+ "0.6275455355644226"
+ ],
+ [
+ "8",
+ "9",
+ "0.018385388267536957",
+ "0.017969632521271706",
+ "29.834793090820312",
+ "0.6006285548210144",
+ "29.763931274414062",
+ "0.582464873790741"
+ ],
+ [
+ "9",
+ "10",
+ "0.015868896308044594",
+ "0.018145012110471724",
+ "30.2650203704834",
+ "0.6225405931472778",
+ "29.08074951171875",
+ "0.48871317505836487"
+ ]
+ ],
+ "shape": {
+ "columns": 7,
+ "rows": 10
+ }
+ },
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epoch | \n",
+ " L1Loss | \n",
+ " val_L1Loss | \n",
+ " psnr | \n",
+ " ssim | \n",
+ " val_psnr | \n",
+ " val_ssim | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0.307582 | \n",
+ " 0.367112 | \n",
+ " 10.074805 | \n",
+ " 0.026604 | \n",
+ " 8.649314 | \n",
+ " 0.035401 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 0.158626 | \n",
+ " 0.159450 | \n",
+ " 14.711766 | \n",
+ " 0.044222 | \n",
+ " 15.713264 | \n",
+ " 0.067737 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 0.099793 | \n",
+ " 0.092746 | \n",
+ " 17.329845 | \n",
+ " 0.072110 | \n",
+ " 20.035673 | \n",
+ " 0.118198 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 0.068925 | \n",
+ " 0.050483 | \n",
+ " 19.144039 | \n",
+ " 0.105142 | \n",
+ " 24.395761 | \n",
+ " 0.256631 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 0.048414 | \n",
+ " 0.042193 | \n",
+ " 24.528973 | \n",
+ " 0.305629 | \n",
+ " 25.807829 | \n",
+ " 0.340590 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 0.033416 | \n",
+ " 0.027642 | \n",
+ " 27.669621 | \n",
+ " 0.455102 | \n",
+ " 28.555059 | \n",
+ " 0.503599 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 0.026435 | \n",
+ " 0.025302 | \n",
+ " 28.809948 | \n",
+ " 0.520874 | \n",
+ " 29.219507 | \n",
+ " 0.563964 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 0.021219 | \n",
+ " 0.017991 | \n",
+ " 29.607342 | \n",
+ " 0.582654 | \n",
+ " 30.165979 | \n",
+ " 0.627546 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 9 | \n",
+ " 0.018385 | \n",
+ " 0.017970 | \n",
+ " 29.834793 | \n",
+ " 0.600629 | \n",
+ " 29.763931 | \n",
+ " 0.582465 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 10 | \n",
+ " 0.015869 | \n",
+ " 0.018145 | \n",
+ " 30.265020 | \n",
+ " 0.622541 | \n",
+ " 29.080750 | \n",
+ " 0.488713 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " epoch L1Loss val_L1Loss psnr ssim val_psnr val_ssim\n",
+ "0 1 0.307582 0.367112 10.074805 0.026604 8.649314 0.035401\n",
+ "1 2 0.158626 0.159450 14.711766 0.044222 15.713264 0.067737\n",
+ "2 3 0.099793 0.092746 17.329845 0.072110 20.035673 0.118198\n",
+ "3 4 0.068925 0.050483 19.144039 0.105142 24.395761 0.256631\n",
+ "4 5 0.048414 0.042193 24.528973 0.305629 25.807829 0.340590\n",
+ "5 6 0.033416 0.027642 27.669621 0.455102 28.555059 0.503599\n",
+ "6 7 0.026435 0.025302 28.809948 0.520874 29.219507 0.563964\n",
+ "7 8 0.021219 0.017991 29.607342 0.582654 30.165979 0.627546\n",
+ "8 9 0.018385 0.017970 29.834793 0.600629 29.763931 0.582465\n",
+ "9 10 0.015869 0.018145 30.265020 0.622541 29.080750 0.488713"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(trainer.log)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train model with alternative early termination metric"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Early termination at epoch 6 with best validation metric 9.887849807739258\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = FNet(depth=4)\n",
+ "lr = 3e-4\n",
+ "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=None,\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda',\n",
+ " early_termination_metric = 'psnr' # set early termination metric as psnr for the sake of demonstration\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train with mlflow logger callbacks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'lr': 3e-4\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "del trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=[mlflow_logger_callback],\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda',\n",
+ " early_termination_metric = 'L1Loss'\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# wGaN GP example with mlflow logger callback and plot callback"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator = UNet(\n",
+ " n_channels=1,\n",
+ " n_classes=1\n",
+ ")\n",
+ "\n",
+ "discriminator = GlobalDiscriminator(\n",
+ " n_in_channels = 2,\n",
+ " n_in_filters = 64,\n",
+ " _conv_depth = 4,\n",
+ " _pool_before_fc = True,\n",
+ " batch_norm=True\n",
+ ")\n",
+ "\n",
+ "generator_optimizer = optim.Adam(generator.parameters(), \n",
+ " lr=0.0002, \n",
+ " betas=(0., 0.9))\n",
+ "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n",
+ " lr=0.00002, \n",
+ " betas=(0., 0.9),\n",
+ " weight_decay=0.001)\n",
+ "\n",
+ "gp_loss = GradientPenaltyLoss(\n",
+ " _metric_name='gp_loss',\n",
+ " discriminator=discriminator,\n",
+ " weight=10.0,\n",
+ ")\n",
+ "\n",
+ "gen_loss = GeneratorLoss(\n",
+ " _metric_name='gen_loss'\n",
+ ")\n",
+ "\n",
+ "disc_loss = WassersteinLoss(\n",
+ " _metric_name='disc_loss'\n",
+ ")\n",
+ "\n",
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'gen_lr': 0.0002,\n",
+ " 'disc_lr': 0.00002\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "plot_callback = IntermediatePlot(\n",
+ " name='plotter',\n",
+ " path=PLOT_DIR,\n",
+ " dataset=pds, # give it the patch dataset as opposed to the cached dataset\n",
+ " indices=[1,3,5,7,9], # plot 5 selected patches images from the dataset\n",
+ " plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],\n",
+ " figsize=(20, 25),\n",
+ " show_plot=False,\n",
+ ")\n",
+ "\n",
+ "wgan_trainer = WGANTrainer(\n",
+ " dataset=cds,\n",
+ " batch_size=16,\n",
+ " epochs=20,\n",
+ " patience=20, # setting this to prevent unwanted early termination here\n",
+ " device='cuda',\n",
+ " generator=generator,\n",
+ " discriminator=discriminator,\n",
+ " gen_optimizer=generator_optimizer,\n",
+ " disc_optimizer=discriminator_optimizer,\n",
+ " generator_loss_fn=gen_loss,\n",
+ " discriminator_loss_fn=disc_loss,\n",
+ " gradient_penalty_fn=gp_loss,\n",
+ " discriminator_update_freq=1,\n",
+ " generator_update_freq=2,\n",
+ " callbacks=[mlflow_logger_callback, plot_callback],\n",
+ " metrics={'ssim': SSIM(_metric_name='ssim'), \n",
+ " 'psnr': PSNR(_metric_name='psnr')\n",
+ " },\n",
+ " early_termination_metric = 'GeneratorLoss'\n",
+ ")\n",
+ "\n",
+ "wgan_trainer.train()\n",
+ "\n",
+ "del generator\n",
+ "del wgan_trainer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## # wGaN GP example with mlflow logger callback and alternative early termination loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator = UNet(\n",
+ " n_channels=1,\n",
+ " n_classes=1\n",
+ ")\n",
+ "\n",
+ "discriminator = GlobalDiscriminator(\n",
+ " n_in_channels = 2,\n",
+ " n_in_filters = 64,\n",
+ " _conv_depth = 4,\n",
+ " _pool_before_fc = True,\n",
+ " batch_norm=True\n",
+ ")\n",
+ "\n",
+ "generator_optimizer = optim.Adam(generator.parameters(), \n",
+ " lr=0.0002, \n",
+ " betas=(0., 0.9))\n",
+ "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n",
+ " lr=0.00002, \n",
+ " betas=(0., 0.9),\n",
+ " weight_decay=0.001)\n",
+ "\n",
+ "gp_loss = GradientPenaltyLoss(\n",
+ " _metric_name='gp_loss',\n",
+ " discriminator=discriminator,\n",
+ " weight=10.0,\n",
+ ")\n",
+ "\n",
+ "gen_loss = GeneratorLoss(\n",
+ " _metric_name='gen_loss'\n",
+ ")\n",
+ "\n",
+ "disc_loss = WassersteinLoss(\n",
+ " _metric_name='disc_loss'\n",
+ ")\n",
+ "\n",
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'gen_lr': 0.0002,\n",
+ " 'disc_lr': 0.00002\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "wgan_trainer = WGANTrainer(\n",
+ " dataset=cds,\n",
+ " batch_size=16,\n",
+ " epochs=20,\n",
+ " patience=5, # lower patience here\n",
+ " device='cuda',\n",
+ " generator=generator,\n",
+ " discriminator=discriminator,\n",
+ " gen_optimizer=generator_optimizer,\n",
+ " disc_optimizer=discriminator_optimizer,\n",
+ " generator_loss_fn=gen_loss,\n",
+ " discriminator_loss_fn=disc_loss,\n",
+ " gradient_penalty_fn=gp_loss,\n",
+ " discriminator_update_freq=1,\n",
+ " generator_update_freq=2,\n",
+ " callbacks=[mlflow_logger_callback],\n",
+ " metrics={'ssim': SSIM(_metric_name='ssim'), \n",
+ " 'psnr': PSNR(_metric_name='psnr'),\n",
+ " 'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()) # use a wrapper for torch nn L1Loss\n",
+ " },\n",
+ " early_termination_metric = 'mae' # update early temrination loss with the supplied L1Loss/mae metric instead of the default GaN generator loss\n",
+ ")\n",
+ "\n",
+ "wgan_trainer.train()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "cp_gan_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.21"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/minimal_example_generic_dataset.ipynb b/examples/minimal_example_generic_dataset.ipynb
new file mode 100644
index 0000000..6eb2b1c
--- /dev/null
+++ b/examples/minimal_example_generic_dataset.ipynb
@@ -0,0 +1,840 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# A minimal example to demonstrate how the trainer for FNet and wGAN GP plus the callbacks works along with patched dataset\n",
+ "\n",
+ "Is will not be dependent on the pe2loaddata generated index file from the ALSF pilot data repo unlike the other example notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/weishanli/Waylab\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/weishanli/anaconda3/envs/cp_gan_env/lib/python3.9/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.5' (you have '2.0.4'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
+ " check_for_updates()\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "import pathlib\n",
+ "\n",
+ "import pandas as pd\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "\n",
+ "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "print(str(pathlib.Path('.').absolute().parent.parent))\n",
+ "\n",
+ "## Dataset\n",
+ "from virtual_stain_flow.datasets.GenericImageDataset import GenericImageDataset\n",
+ "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n",
+ "\n",
+ "## FNet training\n",
+ "from virtual_stain_flow.models.fnet import FNet\n",
+ "from virtual_stain_flow.trainers.Trainer import Trainer\n",
+ "\n",
+ "## wGaN training\n",
+ "from virtual_stain_flow.models.unet import UNet\n",
+ "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n",
+ "from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer\n",
+ "\n",
+ "## wGaN losses\n",
+ "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n",
+ "from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss\n",
+ "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n",
+ "\n",
+ "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n",
+ "\n",
+ "## Metrics\n",
+ "from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper\n",
+ "from virtual_stain_flow.metrics.PSNR import PSNR\n",
+ "from virtual_stain_flow.metrics.SSIM import SSIM\n",
+ "\n",
+ "## callback\n",
+ "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n",
+ "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specify train data and output paths"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "parameters"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "EXAMPLE_PATCH_DATA_EXPORT_PATH = '/REPLACE/WITH/PATH/TO/DATA'\n",
+ "\n",
+ "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train_generic_dataset'\n",
+ "EXAMPLE_DIR.mkdir(exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!rm -rf example_train_generic_dataset/*"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "PLOT_DIR = EXAMPLE_DIR / 'plot'\n",
+ "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "MLFLOW_DIR =EXAMPLE_DIR / 'mlflow'\n",
+ "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configure channels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "channel_names = [\n",
+ " \"OrigBrightfield\",\n",
+ " \"OrigDNA\",\n",
+ " \"OrigER\",\n",
+ " \"OrigMito\",\n",
+ " \"OrigRNA\",\n",
+ " \"OrigAGP\",\n",
+ "]\n",
+ "input_channel_name = \"OrigBrightfield\"\n",
+ "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prep Patch dataset and Cache"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-03-03 20:30:03,223 - DEBUG - Channel keys: {'OrigMito', 'OrigBrightfield', 'OrigDNA', 'OrigRNA', 'OrigER', 'OrigAGP'} detected\n",
+ "2025-03-03 20:30:03,225 - DEBUG - No channel keys specified, skip\n",
+ "2025-03-03 20:30:03,226 - DEBUG - No channel keys specified, skip\n",
+ "2025-03-03 20:30:03,226 - DEBUG - Setting input transform ...\n",
+ "2025-03-03 20:30:03,226 - DEBUG - Setting target transform ...\n",
+ "2025-03-03 20:30:03,226 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n",
+ "2025-03-03 20:30:03,227 - DEBUG - Set target channel(s) as ['OrigDNA']\n"
+ ]
+ }
+ ],
+ "source": [
+ "ds = GenericImageDataset(\n",
+ " image_dir=EXAMPLE_PATCH_DATA_EXPORT_PATH,\n",
+ " site_pattern=r\"^([^_]+_[^_]+_[^_]+)\",\n",
+ " channel_pattern=r\"_([^_]+)\\.tiff$\",\n",
+ " verbose=True\n",
+ ")\n",
+ "\n",
+ "## Set input and target channels\n",
+ "ds.set_input_channel_keys([input_channel_name])\n",
+ "ds.set_target_channel_keys('OrigDNA')\n",
+ "\n",
+ "## Cache for faster training \n",
+ "cds = CachedDataset(\n",
+ " ds,\n",
+ " prefill_cache=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# FNet trainer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train model without callback and check logs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = FNet(depth=4)\n",
+ "lr = 3e-4\n",
+ "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=None,\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda',\n",
+ " early_termination_metric = None\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.microsoft.datawrangler.viewer.v0+json": {
+ "columns": [
+ {
+ "name": "index",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "epoch",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "L1Loss",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_L1Loss",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "psnr",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "ssim",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_psnr",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "val_ssim",
+ "rawType": "float64",
+ "type": "float"
+ }
+ ],
+ "conversionMethod": "pd.DataFrame",
+ "ref": "57db911c-095d-4f8b-8f57-ab73118f2b71",
+ "rows": [
+ [
+ "0",
+ "1",
+ "1556.9512803819443",
+ "1610.1253051757812",
+ "-70.28494262695312",
+ "9.063276795728825e-10",
+ "-70.56584167480469",
+ "-8.826770980796539e-10"
+ ],
+ [
+ "1",
+ "2",
+ "1702.884250217014",
+ "1610.0806274414062",
+ "-70.88795471191406",
+ "-3.853541929998983e-09",
+ "-70.56578063964844",
+ "-1.0126930405363055e-09"
+ ],
+ [
+ "2",
+ "3",
+ "1515.5343560112847",
+ "1609.9841918945312",
+ "-70.11585998535156",
+ "-4.895769567525576e-09",
+ "-70.56565856933594",
+ "-1.7035615140770233e-09"
+ ],
+ [
+ "3",
+ "4",
+ "1559.3525390625",
+ "1609.861083984375",
+ "-70.39810943603516",
+ "-3.283770810824649e-09",
+ "-70.56550598144531",
+ "-1.8262883427766496e-09"
+ ],
+ [
+ "4",
+ "5",
+ "1545.5415174696182",
+ "1609.8245849609375",
+ "-70.3438491821289",
+ "-3.253234126532334e-09",
+ "-70.56546020507812",
+ "-1.8821437741678437e-09"
+ ],
+ [
+ "5",
+ "6",
+ "1542.607638888889",
+ "1609.786376953125",
+ "-70.49755859375",
+ "-8.243815630137874e-10",
+ "-70.56541442871094",
+ "-1.5619402438105112e-09"
+ ],
+ [
+ "6",
+ "7",
+ "1501.347873263889",
+ "1609.7578735351562",
+ "-69.79637908935547",
+ "3.219659261421981e-10",
+ "-70.56538391113281",
+ "-1.5038892353658184e-09"
+ ],
+ [
+ "7",
+ "8",
+ "1526.6503228081597",
+ "1609.74462890625",
+ "-70.21217346191406",
+ "-1.515618464065227e-10",
+ "-70.56536102294922",
+ "-9.222221319937773e-10"
+ ],
+ [
+ "8",
+ "9",
+ "1527.799533420139",
+ "1609.7286987304688",
+ "-70.20182800292969",
+ "-8.36740537968339e-11",
+ "-70.56534576416016",
+ "-1.539568472708197e-09"
+ ],
+ [
+ "9",
+ "10",
+ "1627.6031358506943",
+ "1609.71826171875",
+ "-70.63406372070312",
+ "-1.5872374595216066e-11",
+ "-70.5653305053711",
+ "-7.939728874362117e-10"
+ ]
+ ],
+ "shape": {
+ "columns": 7,
+ "rows": 10
+ }
+ },
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epoch | \n",
+ " L1Loss | \n",
+ " val_L1Loss | \n",
+ " psnr | \n",
+ " ssim | \n",
+ " val_psnr | \n",
+ " val_ssim | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 1556.951280 | \n",
+ " 1610.125305 | \n",
+ " -70.284943 | \n",
+ " 9.063277e-10 | \n",
+ " -70.565842 | \n",
+ " -8.826771e-10 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 1702.884250 | \n",
+ " 1610.080627 | \n",
+ " -70.887955 | \n",
+ " -3.853542e-09 | \n",
+ " -70.565781 | \n",
+ " -1.012693e-09 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 1515.534356 | \n",
+ " 1609.984192 | \n",
+ " -70.115860 | \n",
+ " -4.895770e-09 | \n",
+ " -70.565659 | \n",
+ " -1.703562e-09 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 1559.352539 | \n",
+ " 1609.861084 | \n",
+ " -70.398109 | \n",
+ " -3.283771e-09 | \n",
+ " -70.565506 | \n",
+ " -1.826288e-09 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 1545.541517 | \n",
+ " 1609.824585 | \n",
+ " -70.343849 | \n",
+ " -3.253234e-09 | \n",
+ " -70.565460 | \n",
+ " -1.882144e-09 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 1542.607639 | \n",
+ " 1609.786377 | \n",
+ " -70.497559 | \n",
+ " -8.243816e-10 | \n",
+ " -70.565414 | \n",
+ " -1.561940e-09 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 1501.347873 | \n",
+ " 1609.757874 | \n",
+ " -69.796379 | \n",
+ " 3.219659e-10 | \n",
+ " -70.565384 | \n",
+ " -1.503889e-09 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 1526.650323 | \n",
+ " 1609.744629 | \n",
+ " -70.212173 | \n",
+ " -1.515618e-10 | \n",
+ " -70.565361 | \n",
+ " -9.222221e-10 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 9 | \n",
+ " 1527.799533 | \n",
+ " 1609.728699 | \n",
+ " -70.201828 | \n",
+ " -8.367405e-11 | \n",
+ " -70.565346 | \n",
+ " -1.539568e-09 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 10 | \n",
+ " 1627.603136 | \n",
+ " 1609.718262 | \n",
+ " -70.634064 | \n",
+ " -1.587237e-11 | \n",
+ " -70.565331 | \n",
+ " -7.939729e-10 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " epoch L1Loss val_L1Loss psnr ssim val_psnr \\\n",
+ "0 1 1556.951280 1610.125305 -70.284943 9.063277e-10 -70.565842 \n",
+ "1 2 1702.884250 1610.080627 -70.887955 -3.853542e-09 -70.565781 \n",
+ "2 3 1515.534356 1609.984192 -70.115860 -4.895770e-09 -70.565659 \n",
+ "3 4 1559.352539 1609.861084 -70.398109 -3.283771e-09 -70.565506 \n",
+ "4 5 1545.541517 1609.824585 -70.343849 -3.253234e-09 -70.565460 \n",
+ "5 6 1542.607639 1609.786377 -70.497559 -8.243816e-10 -70.565414 \n",
+ "6 7 1501.347873 1609.757874 -69.796379 3.219659e-10 -70.565384 \n",
+ "7 8 1526.650323 1609.744629 -70.212173 -1.515618e-10 -70.565361 \n",
+ "8 9 1527.799533 1609.728699 -70.201828 -8.367405e-11 -70.565346 \n",
+ "9 10 1627.603136 1609.718262 -70.634064 -1.587237e-11 -70.565331 \n",
+ "\n",
+ " val_ssim \n",
+ "0 -8.826771e-10 \n",
+ "1 -1.012693e-09 \n",
+ "2 -1.703562e-09 \n",
+ "3 -1.826288e-09 \n",
+ "4 -1.882144e-09 \n",
+ "5 -1.561940e-09 \n",
+ "6 -1.503889e-09 \n",
+ "7 -9.222221e-10 \n",
+ "8 -1.539568e-09 \n",
+ "9 -7.939729e-10 "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(trainer.log)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train model with alternative early termination metric"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Early termination at epoch 6 with best validation metric -69.82839965820312\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = FNet(depth=4)\n",
+ "lr = 3e-4\n",
+ "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=None,\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda',\n",
+ " early_termination_metric = 'psnr' # set early termination metric as psnr for the sake of demonstration\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train with mlflow logger callbacks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'lr': 3e-4\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "del trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model = model,\n",
+ " optimizer = optimizer,\n",
+ " backprop_loss = nn.L1Loss(),\n",
+ " dataset = cds,\n",
+ " batch_size = 16,\n",
+ " epochs = 10,\n",
+ " patience = 5,\n",
+ " callbacks=[mlflow_logger_callback],\n",
+ " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n",
+ " device = 'cuda'\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# wGaN GP example with mlflow logger callback and plot callback"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator = UNet(\n",
+ " n_channels=1,\n",
+ " n_classes=1\n",
+ ")\n",
+ "\n",
+ "discriminator = GlobalDiscriminator(\n",
+ " n_in_channels = 2,\n",
+ " n_in_filters = 64,\n",
+ " _conv_depth = 4,\n",
+ " _pool_before_fc = True\n",
+ ")\n",
+ "\n",
+ "generator_optimizer = optim.Adam(generator.parameters(), \n",
+ " lr=0.0002, \n",
+ " betas=(0., 0.9))\n",
+ "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n",
+ " lr=0.00002, \n",
+ " betas=(0., 0.9),\n",
+ " weight_decay=0.001)\n",
+ "\n",
+ "gp_loss = GradientPenaltyLoss(\n",
+ " _metric_name='gp_loss',\n",
+ " discriminator=discriminator,\n",
+ " weight=10.0,\n",
+ ")\n",
+ "\n",
+ "gen_loss = GeneratorLoss(\n",
+ " _metric_name='gen_loss'\n",
+ ")\n",
+ "\n",
+ "disc_loss = WassersteinLoss(\n",
+ " _metric_name='disc_loss'\n",
+ ")\n",
+ "\n",
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'gen_lr': 0.0002,\n",
+ " 'disc_lr': 0.00002\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "plot_callback = IntermediatePlot(\n",
+ " name='plotter',\n",
+ " path=PLOT_DIR,\n",
+ " dataset=ds, # give it the patch dataset as opposed to the cached dataset\n",
+ " indices=[1,3,5,7,9], # plot 5 selected patches images from the dataset\n",
+ " plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],\n",
+ " figsize=(20, 25),\n",
+ " show_plot=False,\n",
+ ")\n",
+ "\n",
+ "wgan_trainer = WGANTrainer(\n",
+ " dataset=cds,\n",
+ " batch_size=16,\n",
+ " epochs=20,\n",
+ " patience=20, # setting this to prevent unwanted early termination here\n",
+ " device='cuda',\n",
+ " generator=generator,\n",
+ " discriminator=discriminator,\n",
+ " gen_optimizer=generator_optimizer,\n",
+ " disc_optimizer=discriminator_optimizer,\n",
+ " generator_loss_fn=gen_loss,\n",
+ " discriminator_loss_fn=disc_loss,\n",
+ " gradient_penalty_fn=gp_loss,\n",
+ " discriminator_update_freq=1,\n",
+ " generator_update_freq=2,\n",
+ " callbacks=[mlflow_logger_callback, plot_callback],\n",
+ " metrics={'ssim': SSIM(_metric_name='ssim'), \n",
+ " 'psnr': PSNR(_metric_name='psnr')}\n",
+ ")\n",
+ "\n",
+ "wgan_trainer.train()\n",
+ "\n",
+ "del generator\n",
+ "del wgan_trainer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## # wGaN GP example with mlflow logger callback and alternative early termination loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator = UNet(\n",
+ " n_channels=1,\n",
+ " n_classes=1\n",
+ ")\n",
+ "\n",
+ "discriminator = GlobalDiscriminator(\n",
+ " n_in_channels = 2,\n",
+ " n_in_filters = 64,\n",
+ " _conv_depth = 4,\n",
+ " _pool_before_fc = True\n",
+ ")\n",
+ "\n",
+ "generator_optimizer = optim.Adam(generator.parameters(), \n",
+ " lr=0.0002, \n",
+ " betas=(0., 0.9))\n",
+ "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n",
+ " lr=0.00002, \n",
+ " betas=(0., 0.9),\n",
+ " weight_decay=0.001)\n",
+ "\n",
+ "gp_loss = GradientPenaltyLoss(\n",
+ " _metric_name='gp_loss',\n",
+ " discriminator=discriminator,\n",
+ " weight=10.0,\n",
+ ")\n",
+ "\n",
+ "gen_loss = GeneratorLoss(\n",
+ " _metric_name='gen_loss'\n",
+ ")\n",
+ "\n",
+ "disc_loss = WassersteinLoss(\n",
+ " _metric_name='disc_loss'\n",
+ ")\n",
+ "\n",
+ "mlflow_logger_callback = MlflowLogger(\n",
+ " name='mlflow_logger',\n",
+ " mlflow_uri=MLFLOW_DIR / 'mlruns',\n",
+ " mlflow_experiment_name='Default',\n",
+ " mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},\n",
+ " mlflow_log_params_args={\n",
+ " 'gen_lr': 0.0002,\n",
+ " 'disc_lr': 0.00002\n",
+ " },\n",
+ " )\n",
+ "\n",
+ "wgan_trainer = WGANTrainer(\n",
+ " dataset=cds,\n",
+ " batch_size=16,\n",
+ " epochs=20,\n",
+ " patience=5, # lower patience here\n",
+ " device='cuda',\n",
+ " generator=generator,\n",
+ " discriminator=discriminator,\n",
+ " gen_optimizer=generator_optimizer,\n",
+ " disc_optimizer=discriminator_optimizer,\n",
+ " generator_loss_fn=gen_loss,\n",
+ " discriminator_loss_fn=disc_loss,\n",
+ " gradient_penalty_fn=gp_loss,\n",
+ " discriminator_update_freq=1,\n",
+ " generator_update_freq=2,\n",
+ " callbacks=[mlflow_logger_callback],\n",
+ " metrics={'ssim': SSIM(_metric_name='ssim'), \n",
+ " 'psnr': PSNR(_metric_name='psnr'),\n",
+ " 'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()) # use a wrapper for torch nn L1Loss\n",
+ " },\n",
+ " early_termination_metric = 'mae' # update early temrination loss with the supplied L1Loss/mae metric instead of the default GaN generator loss\n",
+ ")\n",
+ "\n",
+ "wgan_trainer.train()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "cp_gan_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.21"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/losses/AbstractLoss.py b/losses/AbstractLoss.py
new file mode 100644
index 0000000..c96c70b
--- /dev/null
+++ b/losses/AbstractLoss.py
@@ -0,0 +1,50 @@
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class AbstractLoss(nn.Module, ABC):
+ """Abstract class for metrics"""
+
+ def __init__(self, _metric_name: str):
+
+ super(AbstractLoss, self).__init__()
+
+ self._metric_name = _metric_name
+ self._trainer = None
+
+ @property
+ def trainer(self):
+ return self._trainer
+
+ @trainer.setter
+ def trainer(self, value):
+ """
+ Setter of trainer meant to be called by the trainer class during initialization
+ """
+ self._trainer = value
+
+ @property
+ def metric_name(self):
+ """Defines the metric name returned by the class."""
+ return self._metric_name
+
+ @abstractmethod
+ def forward(self, truth: torch.Tensor, generated: torch.Tensor
+ ) -> float:
+ """
+ Computes the metric given information about the data
+
+ :param truth: The tensor containing the ground truth image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type truth: torch.Tensor
+ :param generated: The tensor containing model generated image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type generated: torch.Tensor
+ :return: The computed metric as a float value.
+ :rtype: float
+ """
+ pass
\ No newline at end of file
diff --git a/losses/DiscriminatorLoss.py b/losses/DiscriminatorLoss.py
new file mode 100644
index 0000000..84ee7ca
--- /dev/null
+++ b/losses/DiscriminatorLoss.py
@@ -0,0 +1,38 @@
+import torch
+
+from .AbstractLoss import AbstractLoss
+
+class WassersteinLoss(AbstractLoss):
+ """
+ This class implements the loss function for the discriminator in a Wasserstein Generative Adversarial Network (wGAN).
+ The discriminator loss measures how well the discriminator distinguishes between real (expected_truth)
+ images and fake (expected_generated) images produced by the generator.
+ """
+ def __init__(self, _metric_name):
+ super().__init__(_metric_name)
+
+ def forward(self, expected_truth, expected_generated):
+ """
+ Computes the Wasserstein Discriminator Loss given probability scores expected_truth and expected_generated from the discriminator
+
+ :param expected_truth: The tensor containing the ground expected_truth
+ probability score predicted by the discriminator over a batch of real images (input target pair),
+ should be of shape [batch_size, 1].
+ :type expected_truth: torch.Tensor
+ :param expected_generated: The tensor containing model expected_generated
+ probability score predicted by the discriminator over a batch of generated images (input generated pair),
+ should be of shape [batch_size, 1].
+ :type expected_generated: torch.Tensor
+ :return: The computed metric as a float value.
+ :rtype: float
+ """
+
+ # If the discriminator output is more than a per-sample scalar, average over the extra dimensions
+ # For compatibility with both a Discriminator class that would output a scalar probability (currently implemented)
+ # and a Discriminator class that would output a 2d matrix of probabilities (currently not implemented)
+ if expected_truth.dim() >= 3:
+ expected_truth = torch.mean(expected_truth, tuple(range(2, expected_truth.dim())))
+ if expected_generated.dim() >= 3:
+ expected_generated = torch.mean(expected_generated, tuple(range(2, expected_generated.dim())))
+
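+ # Wasserstein critic objective: E[D(generated)] - E[D(real)];
+ # minimizing it pushes scores for real images up and for generated images down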
+ return (expected_generated - expected_truth).mean()
\ No newline at end of file
diff --git a/losses/GeneratorLoss.py b/losses/GeneratorLoss.py
new file mode 100644
index 0000000..9ec2bd5
--- /dev/null
+++ b/losses/GeneratorLoss.py
@@ -0,0 +1,65 @@
+from typing import Optional
+
+import torch
+from torch.nn import L1Loss
+
+from .AbstractLoss import AbstractLoss
+
+class GeneratorLoss(AbstractLoss):
+ """
+ Computes the loss for the GAN generator.
+ Combines an adversarial loss component with an image reconstruction loss.
+ """
+ def __init__(self,
+ _metric_name: str,
+ reconstruction_loss: Optional[torch.nn.Module] = L1Loss(),
+ reconstruction_weight: float = 1.0
+ ):
+ """
+ :param reconstruction_loss: The image reconstruction loss,
+ defaults to L1Loss()
+ :type reconstruction_loss: torch.nn.Module
+ :param reconstruction_weight: The weight for the image reconstruction loss, defaults to 1.0
+ :type reconstruction_weight: float
+ """
+
+ super().__init__(_metric_name)
+
+ self._reconstruction_loss = reconstruction_loss
+ if isinstance(reconstruction_weight, float):
+ self._reconstruction_weight = reconstruction_weight
+ else:
+ raise ValueError("reconstruction_weight must be a float value")
+
+ def forward(self,
+ discriminator_probs: torch.tensor,
+ truth: torch.tensor,
+ generated: torch.tensor,
+ epoch: int = 0
+ ):
+ """
+ Computes the loss for the GAN generator.
+
+ :param discriminator_probs: The probabilities of the discriminator for the fake images being real.
+ :type discriminator_probs: torch.tensor
+ :param truth: The tensor containing the ground truth image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type truth: torch.Tensor
+ :param generated: The tensor containing model generated image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type generated: torch.Tensor
+ :param epoch: The current epoch number.
+ Used for a smoothing weight for the adversarial loss component
+ Defaults to 0.
+ :type epoch: int
+ :return: The computed metric as a float value.
+ :rtype: float
+ """
+
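+ # Total objective: 0.01 / (epoch + 1) * (-E[D(generated)]) + reconstruction_weight * reconstruction_loss(generated, truth)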
+ # Adversarial loss
+ adversarial_loss = -torch.mean(discriminator_probs)
+ adversarial_loss = 0.01 * adversarial_loss/(epoch + 1)
+
+ image_loss = self._reconstruction_loss(generated, truth)
+
+ return adversarial_loss + self._reconstruction_weight * image_loss.mean()
\ No newline at end of file
diff --git a/losses/GradientPenaltyLoss.py b/losses/GradientPenaltyLoss.py
new file mode 100644
index 0000000..fa96a00
--- /dev/null
+++ b/losses/GradientPenaltyLoss.py
@@ -0,0 +1,44 @@
+import torch
+import torch.autograd as autograd
+
+from .AbstractLoss import AbstractLoss
+
+class GradientPenaltyLoss(AbstractLoss):
+ def __init__(self, _metric_name, discriminator, weight=10.0):
+ super().__init__(_metric_name)
+
+ self._discriminator = discriminator
+ self._weight = weight
+
+ def forward(self, truth, generated):
+ """
+ Computes the gradient penalty loss for wGAN GP
+
+ :param truth: The tensor containing the ground truth image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type truth: torch.Tensor
+ :param generated: The tensor containing model generated image,
+ should be of shape [batch_size, channel_number, img_height, img_width].
+ :type generated: torch.Tensor
+ :return: The computed metric as a float value.
+ :rtype: float
+ """
+
+ device = self.trainer.device
+
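+ # Interpolate randomly between real and generated images; the WGAN-GP penalty
+ # pushes the norm of the critic's gradient at these points towards 1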
+ batch_size = truth.size(0)
+ eta = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(truth)
+ interpolated = (eta * truth + (1 - eta) * generated).requires_grad_(True)
+ prob_interpolated = self._discriminator(interpolated)
+
+ gradients = autograd.grad(
+ outputs=prob_interpolated,
+ inputs=interpolated,
+ grad_outputs=torch.ones_like(prob_interpolated),
+ create_graph=True,
+ retain_graph=True,
+ )[0]
+
+ gradients = gradients.view(batch_size, -1)
+ gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+ return self._weight * gradient_penalty
\ No newline at end of file
diff --git a/losses/README.md b/losses/README.md
new file mode 100644
index 0000000..2fe0a00
--- /dev/null
+++ b/losses/README.md
@@ -0,0 +1 @@
+Here live the loss functions used by wGAN GP.
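+
+A minimal sketch of how the three losses fit together in a wGAN GP step. The toy generator/critic, tensor shapes, and the detaching of the generated image are illustrative assumptions; `WGANTrainer` wires up the real models and losses for you and also attaches itself to `GradientPenaltyLoss`, which is faked below with a namespace that only carries a device.
+
+```python
+import types
+
+import torch
+import torch.nn as nn
+
+from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss
+from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss
+from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
+
+# Toy stand-ins for UNet / GlobalDiscriminator, just to exercise the losses
+generator = nn.Conv2d(1, 1, kernel_size=3, padding=1)
+critic = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, padding=1),
+                       nn.AdaptiveAvgPool2d(1), nn.Flatten())
+
+gen_loss = GeneratorLoss(_metric_name='gen_loss')
+disc_loss = WassersteinLoss(_metric_name='disc_loss')
+gp_loss = GradientPenaltyLoss(_metric_name='gp_loss', discriminator=critic, weight=10.0)
+gp_loss.trainer = types.SimpleNamespace(device='cpu')  # normally set by the trainer
+
+brightfield = torch.rand(4, 1, 64, 64)  # input channel
+target = torch.rand(4, 1, 64, 64)       # ground-truth stain channel
+generated = generator(brightfield)
+
+# Critic objective: Wasserstein loss on detached fakes plus gradient penalty
+d_real = critic(target)
+d_fake = critic(generated.detach())
+d_total = disc_loss(d_real, d_fake) + gp_loss(target, generated.detach())
+
+# Generator objective: adversarial term plus weighted reconstruction term
+g_total = gen_loss(critic(generated), target, generated, epoch=0)
+```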
\ No newline at end of file
diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py
new file mode 100644
index 0000000..24fc571
--- /dev/null
+++ b/metrics/AbstractMetrics.py
@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+class AbstractMetrics(nn.Module, ABC):
+ """Abstract class for metrics"""
+
+ def __init__(self, _metric_name: str):
+
+ super(AbstractMetrics, self).__init__()
+
+ self.__metric_name = _metric_name
+
+ self.__train_metric_values = []
+ self.__val_metric_values = []
+
+ @property
+ def metric_name(self):
+ """Defines the mertic name returned by the class."""
+ return self.__metric_name
+
+ @property
+ def train_metric_values(self):
+ """Returns the training metric values."""
+ return self.__train_metric_values
+
+ @property
+ def val_metric_values(self):
+ """Returns the validation metric values."""
+ return self.__val_metric_values
+
+ @abstractmethod
+ def forward(self,
+ _generated_outputs: torch.tensor,
+ _targets: torch.tensor
+ ) -> torch.tensor:
+ """Computes the metric given information about the data."""
+ pass
+
+ def update(self,
+ _generated_outputs: torch.tensor,
+ _targets: torch.tensor,
+ validation: bool=False
+ ) -> None:
+ """Updates the metric with the new data."""
+ if validation:
+ self.__val_metric_values.append(self.forward(_generated_outputs, _targets))
+ else:
+ self.__train_metric_values.append(self.forward(_generated_outputs, _targets))
+
+ def reset(self):
+ """Resets the metric."""
+ self.__train_metric_values = []
+ self.__val_metric_values = []
+
+ def compute(self, **kwargs):
+ """
+ Calls the aggregate_metrics method to compute the metric value for now
+ In future may be used for more complex computations
+ """
+ return self.aggregate_metrics(**kwargs)
+
+ def aggregate_metrics(self, aggregation: Optional[str] = 'mean'):
+ """
+ Aggregates the metric value over batches
+
+ :param aggregation: The aggregation method to use, by default 'mean'
+ :type aggregation: Optional[str]
+ :return: The aggregated metric value for training and validation
+ :rtype: Tuple[torch.tensor, torch.tensor]
+ """
+
+ if aggregation == 'mean':
+ return \
+ torch.mean(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \
+ torch.mean(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None
+
+ elif aggregation == 'sum':
+ return \
+ torch.sum(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \
+ torch.sum(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None
+
+ elif aggregation is None:
+ return \
+ torch.stack(self.__train_metric_values) if len(self.__train_metric_values) > 0 else None , \
+ torch.stack(self.__val_metric_values) if len(self.__val_metric_values) > 0 else None
+
+ else:
+ raise ValueError(f"Aggregation method {aggregation} is not supported.")
\ No newline at end of file
diff --git a/metrics/MetricsWrapper.py b/metrics/MetricsWrapper.py
new file mode 100644
index 0000000..d261ca2
--- /dev/null
+++ b/metrics/MetricsWrapper.py
@@ -0,0 +1,24 @@
+import torch
+
+from .AbstractMetrics import AbstractMetrics
+
+class MetricsWrapper(AbstractMetrics):
+ """Metrics wrapper class that wraps a pytorch module
+ and calls its forward function to accumulate the metric
+ values across batches
+ """
+
+ def __init__(self, _metric_name: str, module: torch.nn.Module):
+ """
+ Initialize the MetricsWrapper class with the metric name and the module.
+
+ :param _metric_name: The name of the metric.
+ :param module: The module to be wrapped. Needs to have a forward function.
+ :type module: torch.nn.Module
+ """
+
+ super(MetricsWrapper, self).__init__(_metric_name)
+ self._module = module
+
+ def forward(self,_generated_outputs: torch.Tensor, _targets: torch.Tensor):
+ return self._module(_generated_outputs, _targets).mean()
\ No newline at end of file
diff --git a/metrics/PSNR.py b/metrics/PSNR.py
new file mode 100644
index 0000000..d07bc4c
--- /dev/null
+++ b/metrics/PSNR.py
@@ -0,0 +1,43 @@
+import torch
+
+from .AbstractMetrics import AbstractMetrics
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class PSNR(AbstractMetrics):
+ """Computes and tracks the Peak Signal-to-Noise Ratio (PSNR)."""
+
+ def __init__(self, _metric_name: str, _max_pixel_value: int = 1):
+ """
+ Initializes the PSNR metric.
+
+ :param _metric_name: The name of the metric.
+ :param _max_pixel_value: The maximum possible pixel value of the images, by default 1.
+ :type _max_pixel_value: int, optional
+ """
+
+ super(PSNR, self).__init__(_metric_name)
+
+ self.__max_pixel_value = _max_pixel_value
+
+ def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor):
+ """
+ Computes the Peak Signal-to-Noise Ratio (PSNR) between the generated outputs and the target images.
+
+ :param _generated_outputs: The tensor containing the generated output images.
+ :type _generated_outputs: torch.Tensor
+ :param _targets: The tensor containing the target images.
+ :type _targets: torch.Tensor
+ :return: The computed PSNR value.
+ :rtype: torch.Tensor
+ """
+
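+ # MSE per image channel over the spatial dims, then PSNR = 10 * log10(max_pixel_value^2 / MSE);
+ # a PSNR of 0 is reported when MSE is exactly 0 to avoid division by zero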
+ mse = torch.mean((_generated_outputs - _targets) ** 2, dim=[2, 3])
+ psnr = torch.where(
+ mse == 0,
+ torch.tensor(0.0),
+ 10 * torch.log10((self.__max_pixel_value**2) / mse),
+ )
+
+ return psnr.mean()
\ No newline at end of file
diff --git a/metrics/README.md b/metrics/README.md
new file mode 100644
index 0000000..ac8b458
--- /dev/null
+++ b/metrics/README.md
@@ -0,0 +1,2 @@
+Here live the metric classes, which depend on an abstract metric class.
+Each metric needs to implement a forward function over the target and predicted images, while the inherited abstract-class methods handle accumulation across batches.
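+
+A minimal sketch of a custom metric following that pattern. The `MAE` class below is illustrative only; for wrapping an existing torch loss, the provided `MetricsWrapper` already does this.
+
+```python
+import torch
+
+from virtual_stain_flow.metrics.AbstractMetrics import AbstractMetrics
+
+class MAE(AbstractMetrics):
+    """Tracks the mean absolute error between generated and target images."""
+
+    def __init__(self, _metric_name: str):
+        super(MAE, self).__init__(_metric_name)
+
+    def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor):
+        # One value per batch; the inherited update()/compute() handle accumulation
+        return (_generated_outputs - _targets).abs().mean()
+
+mae = MAE(_metric_name='mae')
+generated, target = torch.rand(4, 1, 256, 256), torch.rand(4, 1, 256, 256)
+
+mae.update(generated, target)                    # accumulate a training batch
+mae.update(generated, target, validation=True)   # accumulate a validation batch
+train_mae, val_mae = mae.compute()               # means over accumulated batches
+```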
\ No newline at end of file
diff --git a/metrics/SSIM.py b/metrics/SSIM.py
new file mode 100644
index 0000000..a06deee
--- /dev/null
+++ b/metrics/SSIM.py
@@ -0,0 +1,52 @@
+import torch
+
+from .AbstractMetrics import AbstractMetrics
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class SSIM(AbstractMetrics):
+ """Computes and tracks the Structural Similarity Index Measure (SSIM)."""
+
+ def __init__(self, _metric_name: str, _max_pixel_value: int = 1):
+ """
+ Initializes the SSIM metric.
+
+ :param _metric_name: The name of the metric.
+ :param _max_pixel_value: The maximum possible pixel value of the images, by default 1.
+ :type _max_pixel_value: int, optional
+ """
+
+ super(SSIM, self).__init__(_metric_name)
+
+ self.__max_pixel_value = _max_pixel_value
+
+ def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor):
+ """
+ Computes the Structural Similarity Index Measure (SSIM) between the generated outputs and the target images.
+
+ :param _generated_outputs: The tensor containing the generated output images.
+ :type _generated_outputs: torch.Tensor
+ :param _targets: The tensor containing the target images.
+ :type _targets: torch.Tensor
+ :return: The computed SSIM value.
+ :rtype: torch.Tensor
+ """
+
+ mu1 = _generated_outputs.mean(dim=[2, 3], keepdim=True)
+ mu2 = _targets.mean(dim=[2, 3], keepdim=True)
+
+ sigma1_sq = ((_generated_outputs - mu1) ** 2).mean(dim=[2, 3], keepdim=True)
+ sigma2_sq = ((_targets - mu2) ** 2).mean(dim=[2, 3], keepdim=True)
+ sigma12 = ((_generated_outputs - mu1) * (_targets - mu2)).mean(
+ dim=[2, 3], keepdim=True
+ )
+
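+ # Stabilization constants from the standard SSIM formulation (K1 = 0.01, K2 = 0.03, L = max pixel value)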
+ c1 = (self.__max_pixel_value * 0.01) ** 2
+ c2 = (self.__max_pixel_value * 0.03) ** 2
+
+ ssim_value = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / (
+ (mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2)
+ )
+
+ return ssim_value.mean()
\ No newline at end of file
diff --git a/models/README.md b/models/README.md
new file mode 100644
index 0000000..2c0cd8b
--- /dev/null
+++ b/models/README.md
@@ -0,0 +1,3 @@
+Here live the torch models and parts for FNet, UNet and wGAN GP.
+
+Quite unclean in its current state.
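+
+As a quick sanity check, the global discriminator can be instantiated the way the example notebooks configure it (two stacked channels, depth 4); the batch size and 256x256 patch size below are illustrative assumptions.
+
+```python
+import torch
+
+from virtual_stain_flow.models.discriminator import GlobalDiscriminator
+
+disc = GlobalDiscriminator(
+    n_in_channels=2,      # e.g. input and target channel stacked together
+    n_in_filters=64,
+    _conv_depth=4,
+    _pool_before_fc=True,
+)
+
+pair = torch.rand(8, 2, 256, 256)  # batch of stacked image pairs
+scores = disc(pair)
+print(scores.shape)  # one scalar score per image, per the class docstring
+```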
\ No newline at end of file
diff --git a/models/discriminator.py b/models/discriminator.py
new file mode 100644
index 0000000..fa93b22
--- /dev/null
+++ b/models/discriminator.py
@@ -0,0 +1,154 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+"""
+Implementation of GAN discriminators to use alongside a UNet or FNet generator.
+"""
+
+class PatchBasedDiscriminator(nn.Module):
+
+ def __init__(
+ self,
+ n_in_channels: int,
+ n_in_filters: int,
+ _conv_depth: int=4,
+ _leaky_relu_alpha: float=0.2,
+ _batch_norm: bool=False
+ ):
+ """
+ A patch-based discriminator for pix2pix GANs that outputs a feature map
+ of probabilities
+
+ :param n_in_channels: (int) number of input channels
+ :type n_in_channels: int
+ :param n_in_filters: (int) number of filters in the first convolutional layer.
+ Every subsequent layer will double the number of filters
+ :type n_in_filters: int
+ :param _conv_depth: (int) depth of the convolutional network
+ :type _conv_depth: int
+ :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation.
+ Must be between 0 and 1
+ :type _leaky_relu_alpha: float
+ :param _batch_norm: (bool) whether to use batch normalization, defaults to False
+ :type _batch_norm: bool
+ """
+
+ super().__init__()
+
+ conv_layers = []
+
+ n_channels = n_in_filters
+ conv_layers.append(
+ nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1)
+ )
+ conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True))
+
+ # Sequentially add convolutional layers
+ for _ in range(_conv_depth - 2):
+ conv_layers.append(
+ nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1)
+ )
+ conv_layers.append(nn.BatchNorm2d(n_channels * 2))
+ conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True))
+ n_channels *= 2
+
+ # Another layer of conv without downscaling
+ ## TODO: figure out if this is needed
+ conv_layers.append(
+ nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=1, padding=1)
+ )
+
+ if _batch_norm:
+ conv_layers.append(nn.BatchNorm2d(n_channels * 2))
+
+ conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True))
+ n_channels *= 2
+ self._conv_layers = nn.Sequential(*conv_layers)
+
+ # Output layer to get the probability map
+ self.out = nn.Sequential(
+ *[nn.Conv2d(n_channels, 1, kernel_size=4, stride=1, padding=1),
+ nn.Sigmoid()]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self._conv_layers(x)
+ x = self.out(x)
+
+ return x
+
+class GlobalDiscriminator(nn.Module):
+
+ def __init__(
+ self,
+ n_in_channels: int,
+ n_in_filters: int,
+ _conv_depth: int=4,
+ _leaky_relu_alpha: float=0.2,
+ _batch_norm: bool=False,
+ _pool_before_fc: bool=False
+ ):
+ """
+        A global discriminator for pix2pix GANs that outputs a single scalar probability per image.
+
+ :param n_in_channels: (int) number of input channels
+ :type n_in_channels: int
+ :param n_in_filters: (int) number of filters in the first convolutional layer.
+ Every subsequent layer will double the number of filters
+ :type n_in_filters: int
+ :param _conv_depth: (int) depth of the convolutional network
+ :type _conv_depth: int
+ :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation.
+ Must be between 0 and 1
+ :type _leaky_relu_alpha: float
+ :param _batch_norm: (bool) whether to use batch normalization, defaults to False
+ :type _batch_norm: bool
+ :param _pool_before_fc: (bool) whether to pool before the fully connected network
+ Pooling before the fully connected network can reduce the number of parameters
+ :type _pool_before_fc: bool
+ """
+
+ super().__init__()
+
+ conv_layers = []
+
+ n_channels = n_in_filters
+ conv_layers.append(
+ nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1)
+ )
+ conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True))
+
+ # Sequentially add convolutional layers
+ for _ in range(_conv_depth - 1):
+ conv_layers.append(
+ nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1)
+ )
+
+ if _batch_norm:
+ conv_layers.append(nn.BatchNorm2d(n_channels * 2))
+
+ conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True))
+ n_channels *= 2
+
+ # Flattening
+ if _pool_before_fc:
+ conv_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
+ conv_layers.append(nn.Flatten())
+ self._conv_layers = nn.Sequential(*conv_layers)
+
+
+ # Fully connected network to output probability
+ self.fc = nn.Sequential(
+ nn.LazyLinear(512),
+ nn.LeakyReLU(_leaky_relu_alpha, inplace=True),
+ nn.Linear(512, 1),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self._conv_layers(x)
+ x = self.fc(x)
+
+ return x
\ No newline at end of file
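An illustrative shape check for the two discriminators above (not part of the diff), assuming the generator output and its single-channel input are stacked into a 2-channel, 256x256 image pair as done by the WGAN trainer.

import torch

pair = torch.rand(2, 2, 256, 256)  # (batch, stacked channels, height, width)

patch_d = PatchBasedDiscriminator(n_in_channels=2, n_in_filters=32)
print(patch_d(pair).shape)   # torch.Size([2, 1, 30, 30]) -- per-patch probabilities

global_d = GlobalDiscriminator(n_in_channels=2, n_in_filters=32, _pool_before_fc=True)
print(global_d(pair).shape)  # torch.Size([2, 1]) -- one probability per image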
diff --git a/models/fnet.py b/models/fnet.py
new file mode 100644
index 0000000..cb2e344
--- /dev/null
+++ b/models/fnet.py
@@ -0,0 +1,98 @@
+import torch
+
+class FNet(torch.nn.Module):
+ def __init__(self, depth=4, mult_chan=32, output_activation='sigmoid'):
+ """
+ Initialize the FNet model with a customizable depth.
+
+ :param depth: Depth of the F-Net model.
+ :type depth: int
+ :param mult_chan: Factor to determine number of output channels.
+ :type mult_chan: int
+ :param output_activation: Activation function for the output layer.
+ :type output_activation: str
+ """
+ super().__init__()
+ self._depth = depth
+ self._multi_chan = mult_chan
+ self.net_recurse = _Net_recurse(
+ n_in_channels=1,
+ mult_chan=self._multi_chan,
+ depth=self._depth)
+ self.conv_out = torch.nn.Conv2d(
+ self._multi_chan, 1, kernel_size=3, padding=1)
+
+ if output_activation == 'sigmoid':
+ self.output_activation = torch.nn.Sigmoid()
+ elif output_activation == 'relu':
+ self.output_activation = torch.nn.ReLU()
+ elif output_activation == 'linear':
+ self.output_activation = torch.nn.Identity()
+ else:
+ raise ValueError('Invalid output_activation')
+
+ def forward(self, x):
+ x_rec = self.net_recurse(x)
+ x_act = self.conv_out(x_rec)
+
+ return self.output_activation(x_act)
+
+class _Net_recurse(torch.nn.Module):
+ def __init__(self, n_in_channels, mult_chan=2, depth=0):
+        """Class for recursive definition of the U-network.
+
+        Parameters:
+            n_in_channels - (int) number of input channels.
+            mult_chan - (int) factor to determine the number of output channels
+            depth - (int) if 0, this subnet is only the convolutions that double the channel count.
+        """
+ super().__init__()
+ self.depth = depth
+ n_out_channels = n_in_channels * mult_chan
+ self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)
+
+ if depth > 0:
+ self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels)
+ self.conv_down = torch.nn.Conv2d(n_out_channels, n_out_channels, 2, stride=2)
+ self.bn0 = torch.nn.BatchNorm2d(n_out_channels)
+ self.relu0 = torch.nn.ReLU()
+
+ self.convt = torch.nn.ConvTranspose2d(2 * n_out_channels, n_out_channels, kernel_size=2, stride=2)
+ self.bn1 = torch.nn.BatchNorm2d(n_out_channels)
+ self.relu1 = torch.nn.ReLU()
+ self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1))
+
+ def forward(self, x):
+ if self.depth == 0:
+ return self.sub_2conv_more(x)
+ else: # depth > 0
+ x_2conv_more = self.sub_2conv_more(x)
+ x_conv_down = self.conv_down(x_2conv_more)
+ x_bn0 = self.bn0(x_conv_down)
+ x_relu0 = self.relu0(x_bn0)
+ x_sub_u = self.sub_u(x_relu0)
+ x_convt = self.convt(x_sub_u)
+ x_bn1 = self.bn1(x_convt)
+ x_relu1 = self.relu1(x_bn1)
+ x_cat = torch.cat((x_2conv_more, x_relu1), 1) # concatenate
+ x_2conv_less = self.sub_2conv_less(x_cat)
+ return x_2conv_less
+
+class SubNet2Conv(torch.nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv1 = torch.nn.Conv2d(n_in, n_out, kernel_size=3, padding=1)
+ self.bn1 = torch.nn.BatchNorm2d(n_out)
+ self.relu1 = torch.nn.ReLU()
+ self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1)
+ self.bn2 = torch.nn.BatchNorm2d(n_out)
+ self.relu2 = torch.nn.ReLU()
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu2(x)
+ return x
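A quick shape check for FNet (illustrative only); the model is hard-wired to single-channel inputs, and the spatial size is assumed to be divisible by 2**depth so the down/upsampling path lines up.

import torch

model = FNet(depth=4, mult_chan=32, output_activation='sigmoid')
x = torch.rand(2, 1, 128, 128)
y = model(x)
print(y.shape)                         # torch.Size([2, 1, 128, 128])
print(float(y.min()), float(y.max()))  # values stay in [0, 1] because of the sigmoid output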
diff --git a/models/unet.py b/models/unet.py
new file mode 100644
index 0000000..2e04f12
--- /dev/null
+++ b/models/unet.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+
+from .unet_utils import *
+
+class UNet(nn.Module):
+ def __init__(self,
+ n_channels,
+ n_classes,
+ base_channels=64,
+ depth=4,
+ bilinear=False,
+ output_activation='sigmoid'):
+ """
+ Initialize the U-Net model with a customizable depth.
+
+ :param n_channels: Number of input channels.
+ :type n_channels: int
+ :param n_classes: Number of output classes.
+ :type n_classes: int
+ :param base_channels: Number of base channels to start with.
+ :type base_channels: int
+ :param depth: Depth of the U-Net model.
+ :type depth: int
+        :param bilinear: Use bilinear interpolation for upsampling instead of transposed convolutions.
+ :type bilinear: bool
+ :param output_activation: Activation function for the output layer.
+ :type output_activation: str
+ """
+ super(UNet, self).__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.depth = depth
+ self.bilinear = bilinear
+
+        in_channels = n_channels  # Input channels to the initial convolution block equal the number of image channels
+        out_channels = base_channels  # Output channels of the initial convolution block equal the base number of channels
+
+        # Initial convolution block
+ self.inc = ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=None,
+ num_layers=2)
+
+ # Contracting path
+ contracting_path = []
+ for _ in range(self.depth):
+ # set the number of input channels to the output channels of the previous layer
+ in_channels = out_channels
+ # double the number of output channels for the next layer
+ out_channels *= 2
+ contracting_path.append(
+ Contract(
+ in_channels=in_channels,
+ out_channels=out_channels
+ )
+ )
+ self.down = nn.ModuleList(contracting_path)
+
+ # Bottleneck
+ factor = 2 if bilinear else 1
+ in_channels = out_channels # Input channel to the bottleneck layer is the output channel of the last downsampling layer
+ out_channels = in_channels // factor
+ self.bottleneck = ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=None,
+ num_layers=2
+ )
+
+ # Expanding path
+ expanding_path = []
+ for _ in range(self.depth):
+ # input to expanding path has the same dimension as the output of the bottleneck layer
+ in_channels = out_channels
+ # half the number of output channels for the next layer
+ out_channels = in_channels // 2
+ expanding_path.append(
+ Up(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ bilinear=bilinear
+ )
+ )
+ self.up = nn.ModuleList(expanding_path)
+
+ # Output layer
+ self.outc = OutConv(
+ base_channels,
+ n_classes,
+ output_activation=output_activation)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the U-Net model.
+
+ :param x: Input tensor of shape (batch_size, n_channels, height, width).
+ :type x: torch.Tensor
+ :return: Output tensor of shape (batch_size, n_classes, height, width).
+ :rtype: torch.Tensor
+ """
+ # Contracting path
+ x_contracted = []
+ x = self.inc(x)
+ for down in self.down:
+ x_contracted.append(x)
+ x = down(x)
+
+ # Bottleneck
+ x = self.bottleneck(x)
+
+ # Expanding path
+ for i, up in enumerate(self.up):
+ x = up(x, x_contracted[-(i + 1)])
+
+ # Final output
+ logits = self.outc(x)
+ return logits
\ No newline at end of file
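A shape check for the U-Net above with its default transposed-convolution upsampling (illustrative only); height and width are assumed to be divisible by 2**depth so the skip connections line up.

import torch

model = UNet(n_channels=1, n_classes=1, base_channels=64, depth=4, bilinear=False)
x = torch.rand(1, 1, 256, 256)
print(model(x).shape)  # torch.Size([1, 1, 256, 256]), values in [0, 1] from the sigmoid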
diff --git a/models/unet_utils.py b/models/unet_utils.py
new file mode 100644
index 0000000..e8db2e0
--- /dev/null
+++ b/models/unet_utils.py
@@ -0,0 +1,195 @@
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+"""
+Components of the U-Net model
+"""
+
+class ConvBnRelu(nn.Module):
+ """
+ A customizable convolutional block: (Convolution => [BN] => ReLU) * N.
+
+ Allows specifying the number of layers and intermediate channels.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ mid_channels: Optional[List[int]] = None,
+ num_layers: int = 2):
+ """
+        Initialize the customizable ConvBnRelu block for expanding or reducing the number of channels.
+
+ :param in_channels: Number of input channels.
+ :type in_channels: int
+ :param out_channels: Number of output channels.
+ :type out_channels: int
+ :param mid_channels: List of intermediate channel numbers for each convolutional layer.
+ If unspecified, defaults to [out_channels] * (num_layers - 1).
+ Order matters: mid_channels[0] corresponds to the first intermediate layer, etc.
+ :type mid_channels: Optional[List[int]]
+ :param num_layers: Number of convolutional layers in the block.
+ :type num_layers: int
+ """
+ super().__init__()
+
+ # Default intermediate channels if not specified
+ if mid_channels is None:
+ mid_channels = [out_channels] * (num_layers - 1)
+
+ if len(mid_channels) != num_layers - 1:
+ raise ValueError("Length of mid_channels must be equal to num_layers - 1.")
+
+ layers = []
+
+ # Add the first convolution layer
+ layers.append(
+ nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, padding=1, bias=False)
+ )
+ layers.append(nn.BatchNorm2d(mid_channels[0]))
+ layers.append(nn.ReLU(inplace=True))
+
+ # Add intermediate convolutional layers
+ for i in range(1, num_layers - 1):
+ layers.append(
+ nn.Conv2d(mid_channels[i - 1], mid_channels[i], kernel_size=3, padding=1, bias=False)
+ )
+ layers.append(nn.BatchNorm2d(mid_channels[i]))
+ layers.append(nn.ReLU(inplace=True))
+
+ # Add the final convolution layer
+ layers.append(
+ nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, padding=1, bias=False)
+ )
+ layers.append(nn.BatchNorm2d(out_channels))
+ layers.append(nn.ReLU(inplace=True))
+
+ # Combine layers into a sequential block
+ self.conv_block = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ConvBnRelu module.
+
+ :param x: Input tensor.
+ :type x: torch.Tensor
+ :return: Processed output tensor.
+ :rtype: torch.Tensor
+ """
+ return self.conv_block(x)
+
+class Contract(nn.Module):
+ """Downscaling with maxpool then 2 * ConvBnRelu"""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool2d(2), # Halves spatial dimensions
+ ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=None,
+ num_layers=2) # Refines features with 2 sequential convolutions
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+class Up(nn.Module):
+ """Upscaling then 2 * ConvBnRelu"""
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ bilinear: bool=True):
+ """
+        Upsampling module that combines the upsampled feature map with the skip connection.
+ Upsampling is done via bilinear interpolation or transposed convolution.
+
+ :param in_channels: Number of input channels.
+ :type in_channels: int
+ :param out_channels: Number of output channels.
+ :type out_channels: int
+ :param bilinear: If True, use bilinear upsampling
+ :type bilinear: bool
+ """
+ super().__init__()
+
+ # If bilinear, use the normal convolutions to reduce the number of channels
+ if bilinear:
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+ self.conv = ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=[in_channels // 2],
+ num_layers=2)
+ else:
+ self.up = nn.ConvTranspose2d(
+ in_channels, in_channels // 2, kernel_size=2, stride=2
+ )
+ self.conv = ConvBnRelu(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=None,
+ num_layers=2
+ )
+
+ def forward(self,
+ x1: torch.Tensor,
+ x2: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the Up module.
+ :param x1: Input tensor to be upsampled.
+ :type x1: torch.Tensor
+ :param x2: Skip connection tensor.
+ :type x2: torch.Tensor
+ :return: Processed output tensor.
+ """
+ x1 = self.up(x1) # Upsample x1
+
+ # Handle potential mismatches in spatial dimensions
+ diffY = x2.size()[2] - x1.size()[2]
+ diffX = x2.size()[3] - x1.size()[3]
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+ # Concatenate x1 (upsampled) with x2 (skip connection)
+ x = torch.cat([x2, x1], dim=1)
+ return self.conv(x)
+
+class OutConv(nn.Module):
+ """
+    Final output layer that applies a 1x1 convolution followed by a configurable output activation (sigmoid, ReLU, or identity).
+ """
+ def __init__(self, in_channels, out_channels, output_activation:str='sigmoid'):
+ """
+ Initialize the OutConv module.
+
+ :param in_channels: Number of input channels.
+ :type in_channels: int
+ :param out_channels: Number of output channels.
+ :type out_channels: int
+ :param output_activation: Activation function to apply to the output.
+ :type output_activation: str
+ """
+ super(OutConv, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+ if output_activation == 'sigmoid':
+ self.output_activation = torch.nn.Sigmoid()
+ elif output_activation == 'relu':
+ self.output_activation = torch.nn.ReLU()
+ elif output_activation == 'linear':
+ self.output_activation = torch.nn.Identity()
+ else:
+ raise ValueError('Invalid output_activation')
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the OutConv module.
+
+ :param x: Input tensor.
+ :type x: torch.Tensor
+ :return: Processed output tensor.
+ :rtype: torch.Tensor
+ """
+ return self.output_activation(self.conv(x))
\ No newline at end of file
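A small sketch of the building blocks above (illustrative only): a three-layer ConvBnRelu with explicit intermediate channels, and an Up block joining a feature map with a skip connection whose spatial size differs slightly, which the padding logic absorbs.

import torch

block = ConvBnRelu(in_channels=3, out_channels=64, mid_channels=[16, 32], num_layers=3)
print(block(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 64, 64])

up = Up(in_channels=64, out_channels=32, bilinear=False)
x1 = torch.rand(1, 64, 31, 31)  # decoder feature map to upsample
x2 = torch.rand(1, 32, 63, 63)  # encoder skip connection
print(up(x1, x2).shape)         # torch.Size([1, 32, 63, 63])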
diff --git a/trainers/AbstractTrainer.py b/trainers/AbstractTrainer.py
new file mode 100644
index 0000000..4fdb57d
--- /dev/null
+++ b/trainers/AbstractTrainer.py
@@ -0,0 +1,460 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import List, Callable, Dict, Optional
+
+import torch
+from torch.utils.data import DataLoader, random_split
+
+from ..metrics.AbstractMetrics import AbstractMetrics
+from ..callbacks.AbstractCallback import AbstractCallback
+
+
+class AbstractTrainer(ABC):
+ """
+ Abstract trainer class for img2img translation models.
+ Provides shared dataset handling and modular callbacks for logging and evaluation.
+ """
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ batch_size: int = 16,
+ epochs: int = 10,
+ patience: int = 5,
+ callbacks: List[AbstractCallback] = None,
+ metrics: Dict[str, AbstractMetrics] = None,
+ device: Optional[torch.device] = None,
+ early_termination_metric: str = None,
+ **kwargs,
+ ):
+ """
+ :param dataset: The dataset to be used for training.
+ :type dataset: torch.utils.data.Dataset
+ :param batch_size: The batch size for training.
+ :type batch_size: int
+ :param epochs: The number of epochs for training.
+ :type epochs: int
+ :param patience: The number of epochs with no improvement after which training will be stopped.
+ :type patience: int
+ :param callbacks: List of callback functions to be executed
+ at the end of each epoch.
+ :type callbacks: list of callable
+ :param metrics: Dictionary of metrics to be logged.
+ :type metrics: dict
+ :param device: (optional) The device to be used for training.
+ :type device: torch.device
+ :param early_termination_metric: (optional) The metric to be tracked and used to update early
+ termination count on the validation dataset. If None, early termination is disabled and the
+ training will run for the specified number of epochs.
+ :type early_termination_metric: str
+ """
+
+ self._batch_size = batch_size
+ self._epochs = epochs
+ self._patience = patience
+ self.initialize_callbacks(callbacks)
+ self._metrics = metrics if metrics else {}
+
+ if isinstance(device, torch.device):
+ self._device = device
+ else:
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ self._best_model = None
+ self._best_loss = float("inf")
+ self._early_stop_counter = 0
+ self._early_termination_metric = early_termination_metric
+ self._early_termination = True if early_termination_metric else False
+
+ # Customize data splits
+ self._train_ratio = kwargs.get("train", 0.7)
+ self._val_ratio = kwargs.get("val", 0.15)
+ self._test_ratio = kwargs.get("test", 1.0 - self._train_ratio - self._val_ratio)
+
+ if not (0 < self._train_ratio + self._val_ratio + self._test_ratio <= 1.0):
+ raise ValueError("Data split ratios must sum to 1.0 or less.")
+
+ train_size = int(self._train_ratio * len(dataset))
+ val_size = int(self._val_ratio * len(dataset))
+ test_size = len(dataset) - train_size - val_size
+ self._train_dataset, self._val_dataset, self._test_dataset = random_split(
+ dataset, [train_size, val_size, test_size]
+ )
+
+ # Create DataLoaders
+ self._train_loader = DataLoader(
+ self._train_dataset, batch_size=self._batch_size, shuffle=True
+ )
+ self._val_loader = DataLoader(
+ self._val_dataset, batch_size=self._batch_size, shuffle=False
+ )
+
+ # Epoch counter
+ self._epoch = 0
+
+ # Loss and metrics storage
+ self._train_losses = defaultdict(list)
+ self._val_losses = defaultdict(list)
+ self._train_metrics = defaultdict(list)
+ self._val_metrics = defaultdict(list)
+
+ @abstractmethod
+    def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Abstract method for training the model on one batch
+ Must be implemented by subclasses.
+ This should be where the losses and metrics are calculated.
+ Should return a dictionary with loss name as key and torch tensor loss as value.
+
+ :param inputs: The input data.
+ :type inputs: torch.Tensor
+ :param targets: The target data.
+ :type targets: torch.Tensor
+ :return: A dictionary containing the loss values for the batch.
+ :rtype: dict[str, torch.Tensor]
+ """
+ pass
+
+ @abstractmethod
+    def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Abstract method for evaluating the model on one batch
+ Must be implemented by subclasses.
+ This should be where the losses and metrics are calculated.
+ Should return a dictionary with loss name as key and torch tensor loss as value.
+
+ :param inputs: The input data.
+ :type inputs: torch.Tensor
+ :param targets: The target data.
+ :type targets: torch.Tensor
+ :return: A dictionary containing the loss values for the batch.
+ :rtype: dict[str, torch.Tensor]
+ """
+ pass
+
+ @abstractmethod
+    def train_epoch(self) -> Dict[str, torch.Tensor]:
+ """
+ Can be overridden by subclasses to implement custom training logic.
+ Make calls to the train_step method for each batch
+ in the training DataLoader.
+
+ Return a dictionary with loss name as key and
+ torch tensor loss as value. Multiple losses can be returned.
+
+ :return: A dictionary containing the loss values for the epoch.
+ :rtype: dict[str, torch.Tensor]
+ """
+
+ pass
+
+ @abstractmethod
+    def evaluate_epoch(self) -> Dict[str, torch.Tensor]:
+ """
+ Can be overridden by subclasses to implement custom evaluation logic.
+ Should make calls to the evaluate_step method for each batch
+ in the validation DataLoader.
+
+ Should return a dictionary with loss name as key and
+ torch tensor loss as value. Multiple losses can be returned.
+
+ :return: A dictionary containing the loss values for the epoch.
+ :rtype: dict[str, torch.Tensor]
+ """
+
+ pass
+
+ def train(self):
+ """
+ Train the model for the specified number of epochs.
+ Make calls to the train epoch and evaluate epoch methods.
+ """
+
+ self.model.to(self.device)
+
+ # callbacks
+ for callback in self.callbacks:
+ callback.on_train_start()
+
+ for epoch in range(self.epochs):
+
+ # Increment the epoch counter
+ self.epoch += 1
+
+ # callbacks
+ for callback in self.callbacks:
+ callback.on_epoch_start()
+
+ # Access all the metrics and reset them
+ for _, metric in self.metrics.items():
+ metric.reset()
+
+ # Train the model for one epoch
+ train_loss = self.train_epoch()
+ for loss_name, loss in train_loss.items():
+ self._train_losses[loss_name].append(loss)
+
+ # Evaluate the model for one epoch
+ val_loss = self.evaluate_epoch()
+ for loss_name, loss in val_loss.items():
+ self._val_losses[loss_name].append(loss)
+
+ # Access all the metrics and compute the final epoch metric value
+ for metric_name, metric in self.metrics.items():
+ train_metric, val_metric = metric.compute()
+ self._train_metrics[metric_name].append(train_metric.item())
+ self._val_metrics[metric_name].append(val_metric.item())
+
+ # Invoke callback on epoch_end
+ for callback in self.callbacks:
+ callback.on_epoch_end()
+
+ # Update early stopping
+ if self._early_termination_metric is None:
+ # Do not perform early stopping when no termination metric is specified
+ early_term_metric = None
+ else:
+ # First look for the metric in validation loss
+ if self._early_termination_metric in list(val_loss.keys()):
+ early_term_metric = val_loss[self._early_termination_metric]
+ # Then look for the metric in validation metrics
+ elif self._early_termination_metric in list(self._val_metrics.keys()):
+ early_term_metric = self._val_metrics[self._early_termination_metric][-1]
+ else:
+ raise ValueError("Invalid early termination metric")
+
+ self.update_early_stop(early_term_metric)
+
+ # Check if early stopping is needed
+ if self._early_termination and self.early_stop_counter >= self.patience:
+ print(f"Early termination at epoch {epoch + 1} with best validation metric {self._best_loss}")
+ break
+
+ for callback in self.callbacks:
+ callback.on_train_end()
+
+ def update_early_stop(self, val_loss: Optional[torch.Tensor]):
+ """
+ Method to update the early stopping criterion
+
+ :param val_loss: The loss value on the validation set
+ :type val_loss: torch.Tensor
+ """
+
+ # When early termination is disabled, the best model is updated with the current model
+ if not self._early_termination and val_loss is None:
+ self.best_model = self.model.state_dict().copy()
+ return
+
+ if val_loss < self.best_loss:
+ self.best_loss = val_loss
+ self.early_stop_counter = 0
+ self.best_model = self.model.state_dict().copy()
+ else:
+ self.early_stop_counter += 1
+
+ def initialize_callbacks(self, callbacks):
+ """
+ Helper to iterate over all callbacks and set trainer property
+
+ :param callbacks: List of callback objects that can be invoked
+            at epoch start, epoch end, train start, and train end
+ :type callbacks: Callback class or subclass or list of Callback class
+ """
+
+ if callbacks is None:
+ self._callbacks = []
+ return
+
+ if not isinstance(callbacks, List):
+ callbacks = [callbacks]
+ for callback in callbacks:
+ if not isinstance(callback, AbstractCallback):
+ raise TypeError("Invalid callback object type")
+ callback._set_trainer(self)
+
+ self._callbacks = callbacks
+
+ """
+ Log property
+ """
+ @property
+ def log(self):
+ """
+ Returns the training and validation losses and metrics.
+ """
+ log ={
+ **{'epoch': list(range(1, self.epoch + 1))},
+ **self._train_losses,
+ **{f'val_{key}': val for key, val in self._val_losses.items()},
+ **self._train_metrics,
+ **{f'val_{key}': val for key, val in self._val_metrics.items()}
+ }
+
+ return log
+
+ """
+ Properties for accessing various attributes of the trainer.
+ """
+ @property
+ def train_ratio(self):
+ return self._train_ratio
+
+ @property
+ def val_ratio(self):
+ return self._val_ratio
+
+ @property
+ def test_ratio(self):
+ return self._test_ratio
+
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def optimizer(self):
+ return self._optimizer
+
+ @property
+ def device(self):
+ return self._device
+
+ @property
+ def batch_size(self):
+ return self._batch_size
+
+ @property
+ def epochs(self):
+ return self._epochs
+
+ @property
+ def patience(self):
+ return self._patience
+
+ @property
+ def callbacks(self):
+ return self._callbacks
+
+ @property
+ def best_model(self):
+ return self._best_model
+
+ @property
+ def best_loss(self):
+ return self._best_loss
+
+ @property
+ def early_stop_counter(self):
+ return self._early_stop_counter
+
+ @property
+ def metrics(self):
+ return self._metrics
+
+ @property
+ def epoch(self):
+ return self._epoch
+
+ @property
+ def train_losses(self):
+ return self._train_losses
+
+ @property
+ def val_losses(self):
+ return self._val_losses
+
+ @property
+ def train_metrics(self):
+ return self._train_metrics
+
+ @property
+ def val_metrics(self):
+ return self._val_metrics
+
+ """
+ Setters for best model and best loss and early stop counter
+ Meant to be used by the subclasses to update the best model and loss
+ """
+
+ @best_model.setter
+ def best_model(self, value: torch.nn.Module):
+ self._best_model = value
+
+ @best_loss.setter
+ def best_loss(self, value):
+ self._best_loss = value
+
+ @early_stop_counter.setter
+ def early_stop_counter(self, value: int):
+ self._early_stop_counter = value
+
+ @epoch.setter
+ def epoch(self, value: int):
+ self._epoch = value
+
+ """
+ Update loss and metrics
+ """
+
+ def update_loss(self,
+ loss: torch.Tensor,
+ loss_name: str,
+ validation: bool = False):
+ if validation:
+ self._val_losses[loss_name].append(loss)
+ else:
+ self._train_losses[loss_name].append(loss)
+
+ def update_metrics(self,
+ metric: torch.tensor,
+ metric_name: str,
+ validation: bool = False):
+ if validation:
+ self._val_metrics[metric_name].append(metric)
+ else:
+ self._train_metrics[metric_name].append(metric)
+
+    """
+    Properties for accessing the split datasets and their DataLoaders.
+    Python properties cannot take arguments, so the datasets and the
+    DataLoaders are exposed as separate properties.
+    """
+    @property
+    def train_dataset(self):
+        """Returns the training dataset split."""
+        return self._train_dataset
+
+    @property
+    def val_dataset(self):
+        """Returns the validation dataset split."""
+        return self._val_dataset
+
+    @property
+    def test_dataset(self):
+        """Returns the test dataset split."""
+        return self._test_dataset
+
+    @property
+    def train_dataloader(self):
+        """Returns the DataLoader over the training split."""
+        return self._train_loader
+
+    @property
+    def val_dataloader(self):
+        """Returns the DataLoader over the validation split."""
+        return self._val_loader
+
+    @property
+    def test_dataloader(self):
+        """
+        Returns a DataLoader over the test split.
+        Generated on the fly because the test DataLoader is not
+        pre-built during object initialization.
+        """
+        return DataLoader(self._test_dataset, batch_size=self._batch_size, shuffle=False)
diff --git a/trainers/README.md b/trainers/README.md
new file mode 100644
index 0000000..06173b4
--- /dev/null
+++ b/trainers/README.md
@@ -0,0 +1 @@
+Here lives the trainer class for FNet/UNet and WGAN-GP. Shared components between trainers for different models are isolated in the abstract trainer class.
\ No newline at end of file
diff --git a/trainers/Trainer.py b/trainers/Trainer.py
new file mode 100644
index 0000000..5c6411c
--- /dev/null
+++ b/trainers/Trainer.py
@@ -0,0 +1,144 @@
+from collections import defaultdict
+from typing import Optional, List, Union
+
+import torch
+from torch.utils.data import DataLoader, random_split
+
+from .AbstractTrainer import AbstractTrainer
+
+class Trainer(AbstractTrainer):
+ """
+    Trainer class for training a generator with one or more backpropagated loss functions.
+ """
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ backprop_loss: Union[torch.nn.Module, List[torch.nn.Module]],
+ **kwargs
+ ):
+ """
+ Initialize the trainer with the model, optimizer and loss function.
+
+ :param model: The model to be trained.
+ :type model: torch.nn.Module
+ :param optimizer: The optimizer to be used for training.
+ :type optimizer: torch.optim.Optimizer
+ :param backprop_loss: The loss function to be used for training or a list of loss functions.
+ :type backprop_loss: torch.nn.Module
+ """
+
+ super().__init__(**kwargs)
+
+ self._model = model
+ self._optimizer = optimizer
+ self._backprop_loss = backprop_loss \
+ if isinstance(backprop_loss, list) else [backprop_loss]
+
+ """
+    Overridden methods from the parent abstract class
+ """
+ def train_step(self, inputs: torch.tensor, targets: torch.tensor):
+ """
+ Perform a single training step on batch.
+
+ :param inputs: The input image data batch
+ :type inputs: torch.tensor
+ :param targets: The target image data batch
+ :type targets: torch.tensor
+ """
+ # move the data to the device
+ inputs, targets = inputs.to(self.device), targets.to(self.device)
+
+ # set the model to train
+ self.model.train()
+ # set the optimizer gradients to zero
+ self.optimizer.zero_grad()
+
+ # Forward pass
+ outputs = self.model(inputs)
+
+ # Back propagate the loss
+ losses = {}
+ total_loss = torch.tensor(0.0, device=self.device)
+ for loss in self._backprop_loss:
+ losses[type(loss).__name__] = loss(outputs, targets)
+ total_loss += losses[type(loss).__name__]
+
+ total_loss.backward()
+ self.optimizer.step()
+
+ # Calculate the metrics outputs and update the metrics
+ for _, metric in self.metrics.items():
+ metric.update(outputs, targets, validation=False)
+
+ return {
+ key: value.item() for key, value in losses.items()
+ }
+
+ def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor):
+ """
+ Perform a single evaluation step on batch.
+
+ :param inputs: The input image data batch
+ :type inputs: torch.tensor
+ :param targets: The target image data batch
+ :type targets: torch.tensor
+ """
+ # move the data to the device
+ inputs, targets = inputs.to(self.device), targets.to(self.device)
+
+ # set the model to evaluation
+ self.model.eval()
+
+ with torch.no_grad():
+ # Forward pass
+ outputs = self.model(inputs)
+
+ # calculate the loss
+ losses = {}
+ for loss in self._backprop_loss:
+ losses[type(loss).__name__] = loss(outputs, targets)
+
+ # Calculate the metrics outputs and update the metrics
+ for _, metric in self.metrics.items():
+ metric.update(outputs, targets, validation=True)
+
+ return {
+ key: value.item() for key, value in losses.items()
+ }
+
+ def train_epoch(self):
+ """
+ Train the model for one epoch.
+ """
+ self._model.train()
+ losses = defaultdict(list)
+ # Iterate over the train_loader
+ for inputs, targets in self._train_loader:
+ batch_loss = self.train_step(inputs, targets)
+ for key, value in batch_loss.items():
+ losses[key].append(value)
+
+ # reduce loss
+ return {
+ key: sum(value) / len(value) for key, value in losses.items()
+ }
+
+ def evaluate_epoch(self):
+ """
+ Evaluate the model for one epoch.
+ """
+
+ self._model.eval()
+ losses = defaultdict(list)
+ # Iterate over the val_loader
+ for inputs, targets in self._val_loader:
+ batch_loss = self.evaluate_step(inputs, targets)
+ for key, value in batch_loss.items():
+ losses[key].append(value)
+
+ # reduce loss
+ return {
+ key: sum(value) / len(value) for key, value in losses.items()
+ }
\ No newline at end of file
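A minimal end-to-end wiring sketch for the Trainer above, using a synthetic TensorDataset (illustrative only). The top-level package name `img2img` is an assumed import path, the metric is driven through the AbstractMetrics interface the trainers already rely on, and the early-termination key matches the `type(loss).__name__` keys produced by train_step/evaluate_step.

import torch
from torch.utils.data import TensorDataset

from img2img.models.unet import UNet           # assumed package layout
from img2img.metrics.SSIM import SSIM          # assumed package layout
from img2img.trainers.Trainer import Trainer   # assumed package layout

# Tiny synthetic paired dataset of (input, target) image tensors
dataset = TensorDataset(torch.rand(20, 1, 64, 64), torch.rand(20, 1, 64, 64))

model = UNet(n_channels=1, n_classes=1, base_channels=16, depth=2)
trainer = Trainer(
    model=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    backprop_loss=torch.nn.L1Loss(),
    dataset=dataset,
    batch_size=4,
    epochs=2,
    patience=3,
    metrics={"ssim": SSIM("ssim")},
    early_termination_metric="L1Loss",  # key produced by train_step/evaluate_step
)
trainer.train()
print(trainer.log)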
diff --git a/trainers/WGANTrainer.py b/trainers/WGANTrainer.py
new file mode 100644
index 0000000..909617a
--- /dev/null
+++ b/trainers/WGANTrainer.py
@@ -0,0 +1,268 @@
+from typing import Optional
+from collections import defaultdict
+
+import torch
+import torch.autograd as autograd
+from torch.utils.data import DataLoader
+
+from .AbstractTrainer import AbstractTrainer
+
+class WGANTrainer(AbstractTrainer):
+ def __init__(self,
+ generator: torch.nn.Module,
+ discriminator: torch.nn.Module,
+ gen_optimizer: torch.optim.Optimizer,
+ disc_optimizer: torch.optim.Optimizer,
+ generator_loss_fn: torch.nn.Module,
+ discriminator_loss_fn: torch.nn.Module,
+ gradient_penalty_fn: Optional[torch.nn.Module]=None,
+ discriminator_update_freq: int=1,
+ generator_update_freq: int=5,
+ # rest of the arguments are passed to and handled by the parent class
+ # - dataset
+ # - batch_size
+ # - epochs
+ # - patience
+ # - callbacks
+ # - metrics
+ **kwargs):
+ """
+        Initializes the WGAN trainer class.
+
+ :param generator: The image2image generator model (e.g., UNet)
+ :type generator: torch.nn.Module
+ :param discriminator: The discriminator model
+ :type discriminator: torch.nn.Module
+ :param gen_optimizer: Generator optimizer
+ :type gen_optimizer: torch.optim.Optimizer
+ :param disc_optimizer: Discriminator optimizer
+ :type disc_optimizer: torch.optim.Optimizer
+ :param generator_loss_fn: Generator loss function
+ :type generator_loss_fn: torch.nn.Module
+        :param discriminator_loss_fn: Adversarial loss function
+ :type discriminator_loss_fn: torch.nn.Module
+ :param gradient_penalty_fn: (optional) Gradient penalty loss function
+ :type gradient_penalty_fn: torch.nn.Module
+ :param discriminator_update_freq: How frequently to update the discriminator
+ :type discriminator_update_freq: int
+ :param generator_update_freq: How frequently to update the generator
+ :type generator_update_freq: int
+ :param kwargs: Additional arguments passed to the AbstractTrainer
+ :type kwargs: dict
+ """
+ super().__init__(**kwargs)
+
+ # Validate update frequencies
+ if discriminator_update_freq > 1 and generator_update_freq > 1:
+ raise ValueError(
+ "Both discriminator_update_freq and generator_update_freq cannot be greater than 1. "
+ "At least one network must update every epoch."
+ )
+
+ self._generator = generator
+ self._discriminator = discriminator
+ self._gen_optimizer = gen_optimizer
+ self._disc_optimizer = disc_optimizer
+ self._generator_loss_fn = generator_loss_fn
+ self._generator_loss_fn.trainer = self
+ self._discriminator_loss_fn = discriminator_loss_fn
+ self._discriminator_loss_fn.trainer = self
+ self._gradient_penalty_fn = gradient_penalty_fn
+ if self._gradient_penalty_fn is not None:
+ self._gradient_penalty_fn.trainer = self
+
+ # Global step counter and update frequencies
+ self._discriminator_update_freq = discriminator_update_freq
+ self._generator_update_freq = generator_update_freq
+
+ # Initialize the last losses memory to zero
+ self._last_discriminator_loss = torch.tensor(0.0, device=self.device).detach()
+ self._last_gradient_penalty_loss = torch.tensor(0.0, device=self.device).detach()
+ self._last_generator_loss = torch.tensor(0.0, device=self.device).detach()
+
+ def train_step(self,
+ inputs: torch.tensor,
+ targets: torch.tensor
+ ):
+ """
+ Perform a single training step on batch.
+
+ :param inputs: The input image data batch
+ :type inputs: torch.tensor
+ :param targets: The target image data batch
+ :type targets: torch.tensor
+ """
+ inputs, targets = inputs.to(self.device), targets.to(self.device)
+
+        gp_loss = torch.tensor(0.0, device=self.device)
+
+        # forward pass to generate images (shared by the discriminator and generator updates)
+        generated_images = self._generator(inputs)
+        # real_images is used by both update branches, so define it unconditionally
+        real_images = targets
+
+        # Train Discriminator
+        if self.epoch % self._discriminator_update_freq == 0:
+            self._disc_optimizer.zero_grad()
+
+ # Concatenate input channel and real/generated image channels along the
+ # channel dimension to feed full stacked multi-channel images to the discriminator
+ real_input_pair = torch.cat((real_images, inputs), 1)
+ generated_input_pair = torch.cat((generated_images.detach(), inputs), 1)
+
+ discriminator_real_score = self._discriminator(real_input_pair).mean()
+ discriminator_fake_score = self._discriminator(generated_input_pair).mean()
+
+            # Adversarial loss
+ discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score)
+
+ # Compute Gradient penalty loss if fn is supplied
+ if self._gradient_penalty_fn is not None:
+ gp_loss = self._gradient_penalty_fn(real_input_pair, generated_input_pair)
+
+ total_discriminator_loss = discriminator_loss + gp_loss
+ total_discriminator_loss.backward()
+ self._disc_optimizer.step()
+
+ # memorize current discriminator loss until next discriminator update
+ self._last_discriminator_loss = discriminator_loss.detach()
+ self._last_gradient_penalty_loss = gp_loss
+ else:
+            # when not being updated, reuse the losses from the previous update
+ discriminator_loss = self._last_discriminator_loss
+ gp_loss = self._last_gradient_penalty_loss
+
+ # Train Generator
+ if self.epoch % self._generator_update_freq == 0:
+ self._gen_optimizer.zero_grad()
+
+ discriminator_fake_score = self._discriminator(torch.cat((generated_images, inputs), 1)).mean()
+ generator_loss = self._generator_loss_fn(discriminator_fake_score, real_images, generated_images, self.epoch)
+ generator_loss.backward()
+ self._gen_optimizer.step()
+
+ # memorize current generator loss until next generator update
+ self._last_generator_loss = generator_loss.detach()
+ else:
+ # when not being updated, set the loss to zero
+ generator_loss = self._last_generator_loss
+
+        ## TODO: centralize the update of metrics (currently after each batch; consider after each epoch)
+        # Recompute the generated images once (detached) so the metric updates use
+        # the generator state after this training step
+        generated_images = self._generator(inputs).detach()
+        for _, metric in self.metrics.items():
+            metric.update(generated_images, targets, validation=False)
+
+ loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(),
+ type(self._generator_loss_fn).__name__: generator_loss.item()}
+ if self._gradient_penalty_fn is not None:
+ loss = {
+ **loss,
+ **{type(self._gradient_penalty_fn).__name__: gp_loss.item()}
+ }
+
+ return loss
+
+ def evaluate_step(self,
+ inputs: torch.tensor,
+ targets: torch.tensor
+ ):
+ """
+ Perform a single evaluation step on batch.
+
+ :param inputs: The input image data batch
+ :type inputs: torch.tensor
+ :param targets: The target image data batch
+ :type targets: torch.tensor
+ """
+ inputs, targets = inputs.to(self.device), targets.to(self.device)
+
+ self._generator.eval()
+ self._discriminator.eval()
+ with torch.no_grad():
+
+ real_images = targets
+ generated_images = self._generator(inputs)
+
+ # Concatenate input channel and real/generated image channels along the
+ # channel dimension to feed full stacked multi-channel images to the discriminator
+ real_input_pair = torch.cat((real_images, inputs), 1)
+ generated_input_pair = torch.cat((generated_images, inputs), 1)
+
+ discriminator_real_score = self._discriminator(real_input_pair).mean()
+ discriminator_fake_score = self._discriminator(generated_input_pair).mean()
+
+ # Compute losses
+ discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score)
+
+            # The gradient penalty is not computed during evaluation, so report it as zero
+            gp_loss = torch.tensor(0.0, device=self.device)
+
+            # Keep the argument order consistent with train_step: (fake_score, real, generated, epoch)
+            generator_loss = self._generator_loss_fn(discriminator_fake_score, real_images, generated_images, self.epoch)
+
+ for _, metric in self.metrics.items():
+ metric.update(generated_images, targets, validation=True)
+
+ loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(),
+ type(self._generator_loss_fn).__name__: generator_loss.item()}
+ if self._gradient_penalty_fn is not None:
+ loss = {
+ **loss,
+ **{type(self._gradient_penalty_fn).__name__: gp_loss.item()}
+ }
+
+ return loss
+
+ def train_epoch(self):
+
+ self._generator.train()
+ self._discriminator.train()
+
+ epoch_losses = defaultdict(list)
+ for inputs, targets in self._train_loader:
+ losses = self.train_step(inputs, targets)
+ for key, value in losses.items():
+ epoch_losses[key].append(value)
+
+ for key, _ in epoch_losses.items():
+ epoch_losses[key] = sum(epoch_losses[key])/len(self._train_loader)
+
+ return epoch_losses
+
+ def evaluate_epoch(self):
+
+ self._generator.eval()
+ self._discriminator.eval()
+
+ epoch_losses = defaultdict(list)
+ for inputs, targets in self._val_loader:
+ losses = self.evaluate_step(inputs, targets)
+ for key, value in losses.items():
+ epoch_losses[key].append(value)
+
+ for key, _ in epoch_losses.items():
+ epoch_losses[key] = sum(epoch_losses[key])/len(self._val_loader)
+
+ return epoch_losses
+
+ def train(self):
+
+ self._discriminator.to(self.device)
+
+ super().train()
+
+ @property
+ def model(self) -> torch.nn.Module:
+ """
+ return the generator
+ """
+ return self._generator
+
+ @property
+ def discriminator(self) -> torch.nn.Module:
+ """
+ returns the discriminator
+ """
+ return self._discriminator
\ No newline at end of file
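An illustrative wiring sketch for WGANTrainer (not part of the diff). The two loss modules below are hypothetical stand-ins whose call signatures simply mirror how the trainer invokes them; in practice the repository's own WGAN loss classes should be used. The `img2img` package name is likewise an assumed import path.

import torch
from torch import nn
from torch.utils.data import TensorDataset

from img2img.models.unet import UNet                          # assumed package layout
from img2img.models.discriminator import GlobalDiscriminator  # assumed package layout
from img2img.trainers.WGANTrainer import WGANTrainer          # assumed package layout

class WassersteinDiscriminatorLoss(nn.Module):
    # called by the trainer as loss(real_score, fake_score)
    def forward(self, real_score, fake_score):
        return fake_score - real_score

class WassersteinGeneratorLoss(nn.Module):
    # called by the trainer as loss(fake_score, real_images, generated_images, epoch)
    def forward(self, fake_score, real_images, generated_images, epoch):
        return -fake_score + nn.functional.l1_loss(generated_images, real_images)

dataset = TensorDataset(torch.rand(20, 1, 64, 64), torch.rand(20, 1, 64, 64))
generator = UNet(n_channels=1, n_classes=1, base_channels=16, depth=2)
discriminator = GlobalDiscriminator(n_in_channels=2, n_in_filters=16, _pool_before_fc=True)

trainer = WGANTrainer(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=torch.optim.Adam(generator.parameters(), lr=1e-4),
    disc_optimizer=torch.optim.Adam(discriminator.parameters(), lr=1e-4),
    generator_loss_fn=WassersteinGeneratorLoss(),
    discriminator_loss_fn=WassersteinDiscriminatorLoss(),
    discriminator_update_freq=1,
    generator_update_freq=1,
    dataset=dataset,
    batch_size=4,
    epochs=2,
)
trainer.train()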
diff --git a/transforms/MinMaxNormalize.py b/transforms/MinMaxNormalize.py
new file mode 100644
index 0000000..254f9a2
--- /dev/null
+++ b/transforms/MinMaxNormalize.py
@@ -0,0 +1,60 @@
+from albumentations import ImageOnlyTransform
+import numpy as np
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class MinMaxNormalize(ImageOnlyTransform):
+ """Min-Max normalize each image"""
+
+ def __init__(self,
+ _normalization_factor: float,
+ _always_apply: bool=False,
+ _p: float=0.5):
+ """
+ Initializes the MinMaxNormalize transform.
+
+ :param _normalization_factor: The factor by which to normalize the image.
+ :type _normalization_factor: float
+ :param _always_apply: If True, always apply this transformation.
+ :type _always_apply: bool
+ :param _p: Probability of applying this transformation.
+ :type _p: float
+ """
+ super(MinMaxNormalize, self).__init__(_always_apply, _p)
+ self.__normalization_factor = _normalization_factor
+
+ @property
+ def normalization_factor(self):
+ return self.__normalization_factor
+
+ def apply(self, _img, **kwargs):
+ """
+ Apply min-max normalization to the image.
+
+ :param _img: Input image as a numpy array.
+ :type _img: np.ndarray
+ :return: Min-max normalized image.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ """
+ if isinstance(_img, np.ndarray):
+ return _img / self.normalization_factor
+
+ else:
+ raise TypeError("Unsupported image type for transform (Should be a numpy array)")
+
+ def invert(self, _img, **kwargs):
+ """
+ Invert the min-max normalization.
+
+ :param _img: Input image as a numpy array.
+ :type _img: np.ndarray
+ :return: Inverted image.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ """
+ if isinstance(_img, np.ndarray):
+ return _img * self.normalization_factor
+ else:
+ raise TypeError("Unsupported image type for transform (Should be a numpy array)")
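A round-trip sketch for MinMaxNormalize (illustrative only), assuming 16-bit input images so the normalization factor is 65535; apply() is called directly here rather than through an albumentations pipeline.

import numpy as np

norm = MinMaxNormalize(_normalization_factor=65535, _always_apply=True, _p=1.0)
img = np.array([[0.0, 32767.0, 65535.0]], dtype=np.float32)
scaled = norm.apply(img)        # values in [0, 1]
restored = norm.invert(scaled)  # back to the original intensity scale
print(scaled, restored)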
diff --git a/transforms/PixelDepthTransform.py b/transforms/PixelDepthTransform.py
new file mode 100644
index 0000000..29ef444
--- /dev/null
+++ b/transforms/PixelDepthTransform.py
@@ -0,0 +1,85 @@
+from albumentations import ImageOnlyTransform
+import numpy as np
+
+class PixelDepthTransform(ImageOnlyTransform):
+ """
+ Transform to convert images from a specified bit depth to another bit depth (e.g., 16-bit to 8-bit).
+ Automatically scales pixel values up or down to the target bit depth.
+ The only supported bit depths are 8, 16, and 32.
+ """
+
+ def __init__(self,
+ src_bit_depth: int = 16,
+ target_bit_depth: int = 8,
+ _always_apply: bool = True,
+ _p: float = 1.0):
+ """
+ Initializes the PixelDepthTransform.
+
+ :param src_bit_depth: Bit depth of the input image (e.g., 16 for 16-bit).
+ :type src_bit_depth: int
+ :param target_bit_depth: Bit depth to scale the image to (e.g., 8 for 8-bit).
+ :type target_bit_depth: int
+ :param _always_apply: Whether to always apply the transform.
+ :type _always_apply: bool
+ :param _p: Probability of applying the transform.
+ :type _p: float
+ :raises ValueError: If the source or target bit depth is not supported.
+ """
+        if src_bit_depth not in [8, 16, 32]:
+            raise ValueError("Unsupported source bit depth (should be 8, 16, or 32)")
+        if target_bit_depth not in [8, 16, 32]:
+            raise ValueError("Unsupported target bit depth (should be 8, 16, or 32)")
+
+ super(PixelDepthTransform, self).__init__(_always_apply, _p)
+ self.src_bit_depth = src_bit_depth
+ self.target_bit_depth = target_bit_depth
+
+ def apply(self, img, **kwargs):
+ """
+ Apply the bit depth transformation.
+
+ :param img: Input image as a numpy array.
+ :type img: np.ndarray
+ :return: Transformed image scaled to the target bit depth.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ """
+ if not isinstance(img, np.ndarray):
+ raise TypeError("Unsupported image type for transform (should be a numpy array)")
+
+ # Maximum pixel value based on source and target bit depth
+ src_max_val = (2 ** self.src_bit_depth) - 1
+ target_max_val = (2 ** self.target_bit_depth) - 1
+
+ if self.target_bit_depth == 32:
+ # Scale to the 32-bit integer range
+ return ((img / src_max_val) * target_max_val).astype(np.uint32)
+ else:
+ # Standard conversion for 8-bit or 16-bit integers
+ return ((img / src_max_val) * target_max_val).astype(
+ np.uint8 if self.target_bit_depth == 8 else np.uint16
+ )
+
+ def invert(self, img, **kwargs):
+ """
+ Optionally invert the bit depth transformation (useful for debugging or preprocessing).
+
+ :param img: Transformed image as a numpy array.
+ :type img: np.ndarray
+ :return: Image restored to the original bit depth.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ """
+ if not isinstance(img, np.ndarray):
+ raise TypeError("Unsupported image type for inversion (should be a numpy array)")
+
+ target_max_val = (2 ** self.target_bit_depth) - 1
+ src_max_val = (2 ** self.src_bit_depth) - 1
+
+        # Invert the scaling back to the original bit depth and restore the source integer dtype
+        img = (img / target_max_val) * src_max_val
+        src_dtype = {8: np.uint8, 16: np.uint16, 32: np.uint32}[self.src_bit_depth]
+        return img.astype(src_dtype)
+
+ def __repr__(self):
+ return f"PixelDepthTransform(src_bit_depth={self.src_bit_depth}, target_bit_depth={self.target_bit_depth})"
\ No newline at end of file
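A sketch of a 16-bit to 8-bit conversion with PixelDepthTransform (illustrative only); as above, apply() is called directly instead of going through an albumentations pipeline.

import numpy as np

tfm = PixelDepthTransform(src_bit_depth=16, target_bit_depth=8)
img16 = np.array([[0, 32768, 65535]], dtype=np.uint16)
img8 = tfm.apply(img16)
print(img8, img8.dtype)  # [[  0 127 255]] uint8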
diff --git a/transforms/README.md b/transforms/README.md
new file mode 100644
index 0000000..d3a5f1c
--- /dev/null
+++ b/transforms/README.md
@@ -0,0 +1,3 @@
+Here live the image transform classes.
+
+For now, all of these transforms are used only for input normalization. They all have an invert function to facilitate visualization.
\ No newline at end of file
diff --git a/transforms/ZScoreNormalize.py b/transforms/ZScoreNormalize.py
new file mode 100644
index 0000000..cf91acb
--- /dev/null
+++ b/transforms/ZScoreNormalize.py
@@ -0,0 +1,87 @@
+from albumentations import ImageOnlyTransform
+import numpy as np
+
+"""
+Wrote this to get z score normalizae to work with albumentations
+"""
+class ZScoreNormalize(ImageOnlyTransform):
+ """Z-score normalize each image"""
+
+ def __init__(self, _mean=None, _std=None, _always_apply=False, _p=0.5):
+ """
+ Initializes the ZScoreNormalize transform.
+
+ :param _mean: Precomputed mean for normalization (optional). If None, compute per-image mean.
+ :type _mean: float, optional
+ :param _std: Precomputed standard deviation for normalization (optional). If None, compute per-image std.
+ :type _std: float, optional
+ :param _always_apply: If True, always apply this transformation.
+ :type _always_apply: bool
+ :param _p: Probability of applying this transformation.
+ :type _p: float
+ """
+ super(ZScoreNormalize, self).__init__(_always_apply, _p)
+ self.__mean = _mean
+ self.__std = _std
+
+ @property
+ def mean(self):
+ return self.__mean
+
+ @property
+ def std(self):
+ return self.__std
+
+ def apply(self, _img, **params):
+ """
+ Apply z-score normalization to the image.
+
+ :param _img: Input image as a numpy array.
+ :type _img: np.ndarray
+ :return: Z-score normalized image.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ :raises ValueError: If the standard deviation is zero.
+ """
+ if not isinstance(_img, np.ndarray):
+ raise TypeError("Unsupported image type for transform (Should be a numpy array)")
+
+ mean = self.__mean if self.__mean is not None else _img.mean()
+ std = self.__std if self.__std is not None else _img.std()
+
+ if std == 0:
+ raise ValueError("Standard deviation is zero; cannot perform z-score normalization.")
+
+ return (_img - mean) / std
+
+ def invert(self, _img, **kwargs):
+ """
+ Invert the z-score normalization.
+        If this transform was applied per image (without a global mean and std)
+        and no mean/std are supplied via kwargs, the z-scored image is returned unchanged.
+
+ :param _img: Input image as a numpy array.
+ :type _img: np.ndarray
+ :return: Inverted image.
+ :rtype: np.ndarray
+ :raises TypeError: If the input image is not a numpy array.
+ :raises ValueError: If the standard deviation is zero.
+ """
+ if not isinstance(_img, np.ndarray):
+ raise TypeError("Unsupported image type for transform (Should be a numpy array)")
+
+ if self.__mean is None or self.__std is None:
+ mean = kwargs.get("mean", None)
+ std = kwargs.get("std", None)
+ if mean is None or std is None:
+ return _img
+ else:
+ return (_img * std) + mean
+        else:
+            # A global mean and std were provided, so the inversion is exact
+            if self.__std == 0:
+                raise ValueError("Standard deviation is zero; cannot invert the z-score normalization.")
+
+            return (_img * self.__std) + self.__mean
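A sketch of ZScoreNormalize with an explicit dataset-level mean and standard deviation (illustrative only), which makes the transform exactly invertible.

import numpy as np

zs = ZScoreNormalize(_mean=100.0, _std=25.0, _always_apply=True, _p=1.0)
img = np.array([[50.0, 100.0, 150.0]])
z = zs.apply(img)    # [[-2.  0.  2.]]
print(zs.invert(z))  # [[ 50. 100. 150.]]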