From d7a03b16795e73d462e1c386963744c3e0f38ab0 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:15:42 -0700 Subject: [PATCH 01/89] Added readme file for the model folders --- models/README.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 models/README.md diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000..2c0cd8b --- /dev/null +++ b/models/README.md @@ -0,0 +1,3 @@ +Here lives the torch model and parts for FNet, UNet and wGaN GP + +Quite unclean in its current state. \ No newline at end of file From ef442479278c8e7f72f6a3bea54345ff94dd3309 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:15:53 -0700 Subject: [PATCH 02/89] Added model files --- models/discriminator.py | 142 ++++++++++++++++++++++ models/fnet.py | 92 ++++++++++++++ models/unet.py | 108 +++++++++++++++++ models/unet_utils.py | 262 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 604 insertions(+) create mode 100644 models/discriminator.py create mode 100644 models/fnet.py create mode 100644 models/unet.py create mode 100644 models/unet_utils.py diff --git a/models/discriminator.py b/models/discriminator.py new file mode 100644 index 0000000..6f1b8e7 --- /dev/null +++ b/models/discriminator.py @@ -0,0 +1,142 @@ +import torch +from torch import nn +import torch.nn.functional as F + +""" +Implementation of GaN discriminators to use along with UNet or FNet generator. +""" + +class PatchBasedDiscriminator(nn.Module): + + def __init__( + self, + n_in_channels: int, + n_in_filters: int, + _conv_depth: int=4, + _leaky_relu_alpha: float=0.2 + ): + """ + A patch-based discriminator for pix2pix GANs that outputs a feature map + of probabilities + + :param n_in_channels: (int) number of input channels + :type n_in_channels: int + :param n_in_filters: (int) number of filters in the first convolutional layer. + Every subsequent layer will double the number of filters + :type n_in_filters: int + :param _conv_depth: (int) depth of the convolutional network + :type _conv_depth: int + :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. 
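            (The constructor default is 0.2, a slope commonly used in pix2pix-style discriminators.)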
+ Must be between 0 and 1 + :type _leaky_relu_alpha: float + """ + + super().__init__() + + conv_layers = [] + + n_channels = n_in_filters + conv_layers.append( + nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + + # Sequentially add convolutional layers + for _ in range(_conv_depth - 2): + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + + # Another layer of conv without downscaling + ## TODO: figure out if this is needed + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=1, padding=1) + ) + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + self._conv_layers = nn.Sequential(*conv_layers) + + # Output layer to get the probability map + self.out = nn.Sequential( + *[nn.Conv2d(n_channels, 1, kernel_size=4, stride=1, padding=1), + nn.Sigmoid()] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._conv_layers(x) + x = self.out(x) + + return x + +class GlobalDiscriminator(nn.Module): + + def __init__( + self, + n_in_channels: int, + n_in_filters: int, + _conv_depth: int=4, + _leaky_relu_alpha: float=0.2, + _pool_before_fc: bool=False + ): + """ + A global discriminator for pix2pix GANs that outputs a single scalar value as the global probability + + Parameters: + :param n_in_channels: (int) number of input channels + :type n_in_channels: int + :param n_in_filters: (int) number of filters in the first convolutional layer. + Every subsequent layer will double the number of filters + :type n_in_filters: int + :param _conv_depth: (int) depth of the convolutional network + :type _conv_depth: int + :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. 
+ Must be between 0 and 1 + :type _leaky_relu_alpha: float + :param _pool_before_fc: (bool) whether to pool before the fully connected network + Pooling before the fully connected network can reduce the number of parameters + :type _pool_before_fc: bool + """ + + super().__init__() + + conv_layers = [] + + n_channels = n_in_filters + conv_layers.append( + nn.Conv2d(n_in_channels, n_channels, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + + # Sequentially add convolutional layers + for _ in range(_conv_depth - 1): + conv_layers.append( + nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1) + ) + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) + n_channels *= 2 + + # Flattening + if _pool_before_fc: + conv_layers.append(nn.AdaptiveAvgPool2d((1, 1))) + conv_layers.append(nn.Flatten()) + self._conv_layers = nn.Sequential(*conv_layers) + + + # Fully connected network to output probability + self.fc = nn.Sequential( + nn.LazyLinear(512), + nn.LeakyReLU(_leaky_relu_alpha, inplace=True), + nn.Linear(512, 1), + nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._conv_layers(x) + x = self.fc(x) + + return x \ No newline at end of file diff --git a/models/fnet.py b/models/fnet.py new file mode 100644 index 0000000..c9a7af9 --- /dev/null +++ b/models/fnet.py @@ -0,0 +1,92 @@ +import torch + +""" +Adapted from https://github.com/AllenCellModeling/pytorch_fnet +""" +class FNet(torch.nn.Module): + def __init__(self, depth=4, mult_chan=32, output_activation='sigmoid'): + super().__init__() + self._depth = depth + self._multi_chan = mult_chan + self.net_recurse = _Net_recurse( + n_in_channels=1, + mult_chan=self._multi_chan, + depth=self._depth) + self.conv_out = torch.nn.Conv2d( + self._multi_chan, 1, kernel_size=3, padding=1) + + self.output_activation = output_activation + if self.output_activation == 'sigmoid': + self.output_activation = torch.nn.Sigmoid() + elif self.output_activation == 'relu': + self.output_activation = torch.nn.ReLU() + elif self.output_activation == 'linear': + self.output_activation = torch.nn.Identity() + else: + raise ValueError('Invalid output_activation') + + def forward(self, x): + x_rec = self.net_recurse(x) + x_act = self.conv_out(x_rec) + + return self.output_activation(x_act) + +class _Net_recurse(torch.nn.Module): + def __init__(self, n_in_channels, mult_chan=2, depth=0): + """Class for recursive definition of U-network.p + + Parameters: + in_channels - (int) number of channels for input. + mult_chan - (int) factor to determine number of output channels + depth - (int) if 0, this subnet will only be convolutions that double the channel count. 
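
        A depth-0 subnet is just the two-convolution block, so it multiplies the
        channel count by mult_chan and preserves spatial size. Illustrative
        doctest (a sketch; assumes torch and this module are imported):

            >>> sub = _Net_recurse(n_in_channels=8, mult_chan=2, depth=0)
            >>> sub(torch.rand(1, 8, 32, 32)).shape
            torch.Size([1, 16, 32, 32])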
+ """ + super().__init__() + self.depth = depth + n_out_channels = n_in_channels * mult_chan + self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels) + + if depth > 0: + self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels) + self.conv_down = torch.nn.Conv2d(n_out_channels, n_out_channels, 2, stride=2) + self.bn0 = torch.nn.BatchNorm2d(n_out_channels) + self.relu0 = torch.nn.ReLU() + + self.convt = torch.nn.ConvTranspose2d(2 * n_out_channels, n_out_channels, kernel_size=2, stride=2) + self.bn1 = torch.nn.BatchNorm2d(n_out_channels) + self.relu1 = torch.nn.ReLU() + self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1)) + + def forward(self, x): + if self.depth == 0: + return self.sub_2conv_more(x) + else: # depth > 0 + x_2conv_more = self.sub_2conv_more(x) + x_conv_down = self.conv_down(x_2conv_more) + x_bn0 = self.bn0(x_conv_down) + x_relu0 = self.relu0(x_bn0) + x_sub_u = self.sub_u(x_relu0) + x_convt = self.convt(x_sub_u) + x_bn1 = self.bn1(x_convt) + x_relu1 = self.relu1(x_bn1) + x_cat = torch.cat((x_2conv_more, x_relu1), 1) # concatenate + x_2conv_less = self.sub_2conv_less(x_cat) + return x_2conv_less + +class SubNet2Conv(torch.nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv1 = torch.nn.Conv2d(n_in, n_out, kernel_size=3, padding=1) + self.bn1 = torch.nn.BatchNorm2d(n_out) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1) + self.bn2 = torch.nn.BatchNorm2d(n_out) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + return x diff --git a/models/unet.py b/models/unet.py new file mode 100644 index 0000000..6cafce7 --- /dev/null +++ b/models/unet.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle + +from .unet_utils import * + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, base_channels=64, depth=4, bilinear=False): + """ + Initialize the U-Net model with a customizable depth. + + Args: + n_channels (int): Number of input channels. + n_classes (int): Number of output classes. + base_channels (int): Number of channels for the first layer. Subsequent layers double this value. + depth (int): Number of downsampling steps (controls depth). + bilinear (bool): If True, use bilinear upsampling; otherwise, use transposed convolutions. 
+ """ + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.base_channels = base_channels + self.depth = depth + self.bilinear = bilinear + + in_channels = n_channels # Input channel to the first upsampling layer is the number of input channels + out_channels = base_channels # Output channel of the first upsampling layer is the base number of channels + + # Initial upsampling layer + self.inc = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2) + + # Contracting path + contracting_path = [] + for _ in range(depth): + # set the number of input channels to the output channels of the previous layer + in_channels = out_channels + # double the number of output channels for the next layer + out_channels *= 2 + contracting_path.append( + Contract( + in_channels=in_channels, + out_channels=out_channels + ) + ) + self.down = nn.ModuleList(contracting_path) + + # Bottleneck + factor = 2 if bilinear else 1 + in_channels = out_channels # Input channel to the bottleneck layer is the output channel of the last downsampling layer + out_channels = in_channels // factor + self.bottleneck = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2 + ) + + # Expanding path + expanding_path = [] + for _ in range(depth): + # input to expanding path has the same dimension as the output of the bottleneck layer + in_channels = out_channels + # half the number of output channels for the next layer + out_channels = in_channels // 2 + expanding_path.append( + ## TODO: replace this with the Upsample and SkipConnection modules maybe + Up( + in_channels=in_channels, + out_channels=out_channels, + bilinear=bilinear + ) + ) + self.up = nn.ModuleList(expanding_path) + + # Output layer + self.outc = OutConv(base_channels, n_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the U-Net model. + + :param x: Input tensor of shape (batch_size, n_channels, height, width). + :type x: torch.Tensor + :return: Output tensor of shape (batch_size, n_classes, height, width). + :rtype: torch.Tensor + """ + # Contracting path + x_contracted = [] + x = self.inc(x) + for down in self.down: + x_contracted.append(x) + x = down(x) + + # Bottleneck + x = self.bottleneck(x) + + # Expanding path + for i, up in enumerate(self.up): + x = up(x, x_contracted[-(i + 1)]) + + # Final output + logits = self.outc(x) + return logits \ No newline at end of file diff --git a/models/unet_utils.py b/models/unet_utils.py new file mode 100644 index 0000000..0880f74 --- /dev/null +++ b/models/unet_utils.py @@ -0,0 +1,262 @@ +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" +Components of the U-Net model +""" + +class ConvBnRelu(nn.Module): + """ + A customizable convolutional block: (Convolution => [BN] => ReLU) * N. + + Allows specifying the number of layers and intermediate channels. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: Optional[List[int]] = None, + num_layers: int = 2): + """ + Initialize the customizable DoubleConv module for upsampling/downsampling the channels. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param mid_channels: List of intermediate channel numbers for each convolutional layer. + If unspecified, defaults to [out_channels] * (num_layers - 1). 
+ Order matters: mid_channels[0] corresponds to the first intermediate layer, etc. + :type mid_channels: Optional[List[int]] + :param num_layers: Number of convolutional layers in the block. + :type num_layers: int + """ + super().__init__() + + # Default intermediate channels if not specified + if mid_channels is None: + mid_channels = [out_channels] * (num_layers - 1) + + if len(mid_channels) != num_layers - 1: + raise ValueError("Length of mid_channels must be equal to num_layers - 1.") + + layers = [] + + # Add the first convolution layer + layers.append( + nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(mid_channels[0])) + layers.append(nn.ReLU(inplace=True)) + + # Add intermediate convolutional layers + for i in range(1, num_layers - 1): + layers.append( + nn.Conv2d(mid_channels[i - 1], mid_channels[i], kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(mid_channels[i])) + layers.append(nn.ReLU(inplace=True)) + + # Add the final convolution layer + layers.append( + nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, padding=1, bias=False) + ) + layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.ReLU(inplace=True)) + + # Combine layers into a sequential block + self.conv_block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ConvBnRelu module. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Processed output tensor. + :rtype: torch.Tensor + """ + return self.conv_block(x) + +class Contract(nn.Module): + """Downscaling with maxpool then 2 * ConvBnRelu""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), # Halves spatial dimensions + ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2) # Refines features with 2 sequential convolutions + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + """Upscaling then 2 * ConvBnRelu""" + ## TODO: replace this with the Upsample and SkipConnection modules + def __init__(self, + in_channels: int, + out_channels: int, + bilinear: bool=True): + """ + Up sampling module that combines the upsampled feature map with the skip connection. + Upsampling is done via bilinear interpolation or transposed convolution. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param bilinear: If True, use bilinear upsampling + :type bilinear: bool + """ + super().__init__() + + # If bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=[in_channels // 2], + num_layers=2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = ConvBnRelu( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2 + ) + + def forward(self, + x1: torch.Tensor, + x2: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the Up module. + :param x1: Input tensor to be upsampled. + :type x1: torch.Tensor + :param x2: Skip connection tensor. + :type x2: torch.Tensor + :return: Processed output tensor. 
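        :rtype: torch.Tensor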
+ """ + x1 = self.up(x1) # Upsample x1 + + # Handle potential mismatches in spatial dimensions + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # Concatenate x1 (upsampled) with x2 (skip connection) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class OutConv(nn.Module): + """ + Final output layer that applies a 1x1 convolution followed by a sigmoid activation. + """ + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the OutConv module. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Processed output tensor. + :rtype: torch.Tensor + """ + return torch.sigmoid(self.conv(x)) + +class Expand(nn.Module): + """Handles upscaling of feature maps.""" + def __init__(self, + in_channels: int, + out_channels: int, + bilinear: bool=True): + """ + Initialize the Upsample module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + :param bilinear: If True, use bilinear upscaling + :type bilinear: bool + """ + super().__init__() + if bilinear: + # Bilinear interpolation (non-trainable) + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + else: + # Transposed convolution (trainable) + self.up = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=2, stride=2 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the Upsample module. + + :param x: Input tensor to be upsampled. + :type x: torch.Tensor + :return: Upsampled tensor. + :rtype: torch.Tensor + """ + return self.up(x) + +class SkipConnection(nn.Module): + """Handles padding, concatenation, and DoubleConv refinement.""" + def __init__(self, + in_channels: int, + out_channels: int): + """ + Initialize the SkipConnection module. + Handles the dimension mismatch adjustment between a upsampled feature map and a skip connection, + plus refinement via a ConvBnRelu blocks. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + """ + super().__init__() + self.conv = ConvBnRelu( + in_channels=in_channels + out_channels, + out_channels=out_channels, + mid_channels=None, + num_layers=2 # Refines features with 2 sequential convolutions + ) + + def forward(self, + x1: torch.Tensor, + x2: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the SkipConnection module. + + :param x1: Upsampled feature map. + :type x1: torch.Tensor + :param x2: Skip connection feature map. + :type x2: torch.Tensor + :return: Refined concatenated feature map. 
+ :rtype: torch.Tensor + """ + # Align spatial dimensions via padding + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + + # Concatenate along channel dimension + x = torch.cat([x2, x1], dim=1) + + # Refine concatenated feature map + return self.conv(x) \ No newline at end of file From 6a2b4c386e545dcf00b46d4ceec37d42636ab0d9 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:18:36 -0700 Subject: [PATCH 03/89] Added readme for dataset folder --- datasets/README.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 datasets/README.md diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000..6dd6d9e --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,2 @@ +Here lives the dataset classes for interacting with cell painting images. +The datasets are currently completely dependent on the pe2loaddata generated csv files. \ No newline at end of file From bea6c02423294c50b89b07a55b804a36629827b4 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:18:42 -0700 Subject: [PATCH 04/89] Added dataset files --- datasets/CachedDataset.py | 231 +++++++++++++++ datasets/ImageDataset.py | 430 ++++++++++++++++++++++++++++ datasets/PatchDataset.py | 586 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1247 insertions(+) create mode 100644 datasets/CachedDataset.py create mode 100644 datasets/ImageDataset.py create mode 100644 datasets/PatchDataset.py diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py new file mode 100644 index 0000000..15a24f8 --- /dev/null +++ b/datasets/CachedDataset.py @@ -0,0 +1,231 @@ +from typing import Optional + +import torch +from torch.utils.data import Dataset +from collections import OrderedDict + +class CachedDataset(Dataset): + """ + A patched dataset that caches data from dataset objects that + dynamically loads the data to reduce memory overhead during training + """ + + def __init__( + self, + dataset: Dataset, + cache_size: Optional[int]=None, + prefill_cache: bool=False, + **kwargs + ): + """ + Initialize the CachedDataset from a dataset object + + :param dataset: Dataset object to cache data from + :type dataset: Dataset + :param cache_size: Size of the cache, if None, the cache + size is set to the length of the dataset. 
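            A smaller cache trades memory for repeated loads from the wrapped
            dataset; once full, entries are evicted first-in, first-out.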
+ :type cache_size: int + :param prefill_cache: Whether to prefill the cache + :type prefill_cache: bool + """ + + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + self.__dataset = dataset + + self.__cache_size = cache_size if cache_size is not None else len(dataset) + self.__cache = OrderedDict() + + # cache for metadata + self.__cache_input_names = OrderedDict() + self.__cache_target_names = OrderedDict() + + # pointer to the current patch index + self._current_idx = None + + if prefill_cache: + self.cache() + + """ + Overriden methods for Dataset class + """ + def __len__(self): + """ + Return the length of the dataset + """ + return len(self.__dataset) + + def __getitem__(self, _idx: int): + """ + Get the data from the dataset object at the given index + If the data is not in the cache, load it from the dataset object and update the cache + + :param _idx: Index of the data to get + :type _idx: int + """ + self._current_idx = _idx + + if _idx in self.__cache: + # cache hit + return self.__cache[_idx] + else: + # cache miss, load from parent class method dynamically + self._update_cache(_idx) + return self.__cache[_idx] + + """ + Setters + """ + def set_cache_size(self, cache_size: int): + """ + Set the cache size. Does not automatically repopulate the cache but + will pop the cache if the size is exceeded + + :param cache_size: Size of the cache + :type cache_size: int + """ + self.__cache_size = cache_size + # pop the cache if the size is exceeded + while len(self.__cache) > self.__cache_size: + self._pop_cache() + + """ + Properties to remain accessible + """ + @property + def input_names(self): + """ + Get the input names from the dataset object + """ + if self._current_idx is not None: + ## TODO: need to think over if this is at all necessary + if self._current_idx in self.__cache_input_names: + return self.__cache_input_names[self._current_idx] + else: + _ = self.__dataset[self._current_idx] + return self.__dataset.input_names + else: + raise ValueError("No current index set") + + @property + def target_names(self): + """ + Get the target names from the dataset object + """ + if self._current_idx is not None: + ## TODO: need to think over if this is at all necessary + if self._current_idx in self.__cache_target_names: + return self.__cache_target_names[self._current_idx] + else: + _ = self.__dataset[self._current_idx] + return self.__dataset.target_names + else: + raise ValueError("No current index set") + + @property + def input_channel_keys(self): + """ + Get the input channel keys from the dataset object + """ + try: + return self.__dataset.input_channel_keys + except AttributeError: + return None + + @property + def target_channel_keys(self): + """ + Get the target channel keys from the dataset object + """ + try: + return self.__dataset.target_channel_keys + except AttributeError: + return None + + @property + def input_transform(self): + """ + Get the input transform from the dataset object + """ + return self.__dataset.input_transform + + @property + def target_transform(self): + """ + Get the target transform from the dataset object + """ + return self.__dataset.target_transform + + @property + def dataset(self): + """ + Get the dataset object + """ + return self.__dataset + + """ + Cache method + """ + def cache(self): + """ + Clears the current cache and re-populate cache with data from the dataset object + Iteratively calls the update cache method on a sequence of indices to fill the cache + """ + self._clear_cache() + for _idx in 
range(min(self.__cache_size, len(self.__dataset))): + self._update_cache(_idx) + + """ + Internal helper methods + """ + + def _update_cache(self, _idx: int): + """ + Update the cache with data from the dataset object. + Calls the update cache metadata method as well to sync data and metadata + Pops the cache if the cache size is exceeded on a first in, first out basis + + :param _idx: Index of the data to cache + :type _idx: int + """ + self._current_idx = _idx + self.__cache[_idx] = self.__dataset[_idx] + if len(self.__cache) >= self.__cache_size: + self._pop_cache() + self._update_cache_metadata(_idx) + + def _pop_cache(self): + """ + Helper method to pop the cache on a first in, first out basis + """ + self.__cache.popitem(last=False) + + def _update_cache_metadata(self, _idx: int): + """ + Update the cache metadata with data from the dataset object + Meant to be called by _update_cache method + + :param _idx: Index of the data to cache + :type _idx: int + """ + self.__cache_input_names[_idx] = self.__dataset.input_names + self.__cache_target_names[_idx] = self.__dataset.target_names + + if len(self.__cache_input_names) >= self.__cache_size: + self._pop_cache_metadata() + + def _pop_cache_metadata(self): + """ + Helper method to pop the cache metadata on a first in, first out basis + """ + self.__cache_input_names.popitem(last=False) + self.__cache_target_names.popitem(last=False) + + def _clear_cache(self): + """ + Clear the cache and cache metadata + """ + self.__cache.clear() + self.__cache_input_names.clear() + self.__cache_target_names.clear() \ No newline at end of file diff --git a/datasets/ImageDataset.py b/datasets/ImageDataset.py new file mode 100644 index 0000000..a8aa204 --- /dev/null +++ b/datasets/ImageDataset.py @@ -0,0 +1,430 @@ +import logging +import math +import pathlib +import random +from random import randint +from typing import List, Optional, Pattern, Tuple + +import numpy as np +import pandas as pd +from PIL import Image +from albumentations import ImageOnlyTransform +from albumentations.core.composition import Compose +import torch +from torch.utils.data import Dataset + +class ImageDataset(Dataset): + """ + Image Dataset Class from pe2loaddata generated cellprofiler loaddata csv + """ + def __init__( + self, + _loaddata_csv, + _input_channel_keys: Optional[str | List[str]] = None, + _target_channel_keys: Optional[str | List[str]] = None, + _input_transform: Optional[Compose | ImageOnlyTransform] = None, + _target_transform: Optional[Compose | ImageOnlyTransform] = None, + _PIL_image_mode: str = 'I;16', + verbose: bool = False, + file_column_prefix: str = 'FileName_', + path_column_prefix: str = 'PathName_', + check_exists: bool = False, + **kwargs + ): + """ + Initialize the ImageDataset. + + :param _loaddata_csv: The dataframe or path to a csv file containing the image paths and labels. + :type _loaddata_csv: Union[pd.DataFrame, str] + :param _input_channel_keys: Keys for input channels. Can be a single key or a list of keys. + :type _input_channel_keys: Optional[Union[str, List[str]]] + :param _target_channel_keys: Keys for target channels. Can be a single key or a list of keys. + :type _target_channel_keys: Optional[Union[str, List[str]]] + :param _input_transform: Transformations to apply to the input images. + :type _input_transform: Optional[Union[Compose, ImageOnlyTransform]] + :param _target_transform: Transformations to apply to the target images. 
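            (Each transform is applied to the channel-first numpy stack as
            transform(image=stack)['image'].)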
+ :type _target_transform: Optional[Union[Compose, ImageOnlyTransform]] + :param _PIL_image_mode: Mode to use when loading images with PIL. Default is 'I;16'. + :type _PIL_image_mode: str + :param kwargs: Additional keyword arguments. + """ + + self._initialize_logger(verbose) + self._loaddata_df = self._load_loaddata(_loaddata_csv, **kwargs) + self._channel_keys = list(self.__infer_channel_keys(file_column_prefix, path_column_prefix)) + + # Initialize the cache for the input and target images + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Set input/target channels + self.logger.debug("Setting input channel(s) ...") + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug("Setting target channel(s) ...") + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + + self.set_input_transform(_input_transform) + self.set_target_transform(_target_transform) + + self._PIL_image_mode = _PIL_image_mode + + # Obtain image paths + self.__image_paths = self._get_image_paths( + file_column_prefix=file_column_prefix, + path_column_prefix=path_column_prefix, + check_exists=check_exists, + **kwargs + ) + # Index patches and images + self.__iter_image_id = list(range(len(self.__image_paths))) + + # Initialize the current input and target names + self.__current_input_names = None + self.__current_target_names = None + + """ + Overridden Iterator functions + """ + def __len__(self): + return len(self.__image_paths) + + def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]: + """ + Return the input and target images + :param _idx: The index of the image + :type _idx: int + :return: The input and target images, each with dimension [n_channels, height, width] + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + + if _idx >= len(self) or _idx < 0: + raise IndexError("Index out of bounds") + + if self._input_channel_keys is None or self._target_channel_keys is None: + raise ValueError("Input and target channel keys must be set to access data") + + image_id = self.__iter_image_id[_idx] + self._cache_image(image_id) + + ## Retrieve relevant channels as specified by input and target channel keys and stack + input_images = np.stack( + [self.__input_cache[key] for key in self._input_channel_keys], + axis=0) + target_images = np.stack( + [self.__target_cache[key] for key in self._target_channel_keys], + axis=0) + + ## Apply transform + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + ## Cast to torch tensor and return + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + """ + Properties + """ + + @property + def image_paths(self): + return self.__image_paths + + @property + def input_transform(self): + return self._input_transform + + @property + def target_transform(self): + return self._target_transform + + @property + def input_channel_keys(self): + return self._input_channel_keys + + @property + def target_channel_keys(self): + return self._target_channel_keys + @property + def input_names(self): + return self.__current_input_names + + @property + def target_names(self): + return self.__current_target_names + + """ + Setters + """ + + def set_input_transform(self, _input_transform: Compose | ImageOnlyTransform | None): + """ + Set the input transform + + :param _input_transform: The input transform + :type 
_input_transform: Compose or ImageOnlyTransform + """ + # Check and set input/target transforms + self.logger.debug("Setting input transform ...") + if self.__check_transforms(_input_transform): + self._input_transform = _input_transform + + + def set_target_transform(self, _target_transform: Compose | ImageOnlyTransform | None): + """ + Set the target transform + + :param _target_transform: The target transform + :type _target_transform: Compose or ImageOnlyTransform + """ + # Check and set input/target transforms + self.logger.debug("Setting target transform ...") + if self.__check_transforms(_target_transform): + self._target_transform = _target_transform + + def set_input_channel_keys(self, _input_channel_keys: str | List[str]): + """ + Set the input channel keys + + :param _input_channel_keys: The input channel keys + :type _input_channel_keys: str or list of str + """ + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + def set_target_channel_keys(self, _target_channel_keys: str | List[str]): + """ + Set the target channel keys + + :param _target_channel_keys: The target channel keys + :type _target_channel_keys: str or list of str + """ + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + """ + Internal Helper functions + """ + def _initialize_logger(self, verbose: bool): + """ + Initialize logger instance + """ + self.logger = logging.getLogger(f"{__name__}.{id(self)}") + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING) + + def _load_loaddata( + self, + _loaddata_csv: pd.DataFrame | pathlib.Path, + **kwargs + ) -> pd.DataFrame: + """ + Read loaddata csv file + """ + + if _loaddata_csv is None: + raise ValueError("No loaddata csv supplied") + elif isinstance(_loaddata_csv, pd.DataFrame): + self.logger.debug("Dataframe supplied for loaddata_csv, using as is") + return _loaddata_csv + else: + self.logger.debug("Loading loaddata csv from file") + ## Convert string to pathlib Path + if not isinstance(_loaddata_csv, pathlib.Path): + try: + _loaddata_csv = pathlib.Path(_loaddata_csv) + except e: + raise e + + ## Handle file not exist + if not _loaddata_csv.exists(): + raise FileNotFoundError(f"File {_loaddata_csv} not found") + + ## Determine file extension and load accordingly + if _loaddata_csv.suffix == '.csv': + return pd.read_csv(_loaddata_csv) + elif _loaddata_csv.suffix == '.parquet': + return pd.read_parquet(_loaddata_csv) + else: + raise ValueError(f"File type {_loaddata_csv.suffix} not supported") + + def __infer_channel_keys( + self, + file_column_prefix: str, + path_column_prefix: str + ) -> set[str]: + """ + Infer channel names from the columns of loaddata csv + """ + + self.logger.debug("Inferring channel keys from loaddata csv") + # Retrieve columns that indicate path and filename to image files + file_columns = [col for col in self._loaddata_df.columns if col.startswith(file_column_prefix)] + path_columns = [col for col in self._loaddata_df.columns if col.startswith(path_column_prefix)] + + if len(file_columns) == 0 or len(path_columns) == 0: + raise ValueError('No 
path or file columns found in loaddata csv.') + + # Anything following the prefix should be the channel names + file_channel_keys = [col.replace(file_column_prefix, '') for col in file_columns] + path_channel_keys = [col.replace(path_column_prefix, '') for col in path_columns] + channel_keys = set(file_channel_keys).intersection(set(path_channel_keys)) + + if len(channel_keys) == 0: + raise ValueError('No matching channel keys found between file and path columns.') + + self.logger.debug(f"Channel keys: {channel_keys} inferred from loaddata csv") + + return channel_keys + + def __check_channel_keys( + self, + channel_keys: Optional[str | List[str]] + ) -> List[str]: + """ + Checks user supplied channel key against the inferred ones from the file + """ + if channel_keys is None: + self.logger.debug("No channel keys specified, skip") + return None + elif isinstance(channel_keys, str): + channel_keys = [channel_keys] + elif isinstance(channel_keys, list): + if not all([isinstance(key, str) for key in channel_keys]): + raise ValueError('Channel keys must be a string or a list of strings.') + else: + raise ValueError('Channel keys must be a string or a list of strings.') + + ## Check supplied channel keys against inferred ones + filtered_channel_keys = [] + for key in channel_keys: + if not key in self._channel_keys: + self.logger.debug( + f"ignoring channel key {key} as it does not match loaddata csv file" + ) + else: + filtered_channel_keys.append(key) + + if len(filtered_channel_keys) == 0: + raise ValueError(f'None of the supplied channel keys match the loaddata csv file') + + return filtered_channel_keys + + def __check_transforms( + self, + transforms: Optional[Compose | ImageOnlyTransform] + ) -> bool: + """ + Checks if supplied iamge only transform is of valid type, if so, return True + """ + if transforms is None: + pass + elif isinstance(transforms, Compose): + pass + elif isinstance(transforms, ImageOnlyTransform): + pass + else: + raise TypeError('Invalid image transform type') + + return True + + def _get_image_paths(self, + file_column_prefix: str, + path_column_prefix: str, + check_exists: bool = False, + **kwargs, + ) -> List[dict]: + """ + From loaddata csv, extract the paths to all image channels cooresponding to each view/site + + :param check_exists: check if every individual image file exist and remove those that do not + :type check_exists: bool + :return: A list of dictionaries containing the paths to the image channels + :rtype: List[dict] + """ + + # Define helper function to get the image file paths from all channels + # in a single row of loaddata csv (single view/site), organized into a dict + def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]: + + missing = False + + multi_channel_paths = {} + for channel_key in self._channel_keys: + file_column = f"{file_column_prefix}{channel_key}" + path_column = f"{path_column_prefix}{channel_key}" + + if file_column in row and path_column in row: + file = pathlib.Path( + row[path_column] + ) / row[file_column] + if (not check_exists) or file.exists(): + multi_channel_paths[channel_key] = file + else: + missing = True + + return multi_channel_paths, missing + + image_paths = [] + self.logger.debug( + "Extracting image channel paths of site/view and associated"\ + "cell coordinates (if applicable) from loaddata csv") + + for _, row in self._loaddata_df.iterrows(): + multi_channel_paths, missing = get_channel_paths(row) + if not missing: + image_paths.append(multi_channel_paths) + + self.logger.debug(f"Extracted images of 
all input and target channels for {len(image_paths)} unique sites/view") + return image_paths + + def _read_convert_image(self, _image_path: pathlib.Path)->np.ndarray: + """ + Read and convert the image to a numpy array + :param _image_path: The path to the image + :type _image_path: pathlib.Path + :return: The image as a numpy array + :rtype: np.ndarray + """ + return np.array(Image.open(_image_path).convert(self._PIL_image_mode)) + + def _cache_image(self, _id: int)->None: + """ + Determines if cached images need to be updated and updates the self.__input_cache and self.__target_cache + Meant to be called by __getitem__ method in dynamic patch cropping + + :param _id: The index of the image + :type _id: int + :return: None + :rtype: None + """ + + if self.__cache_image_id is None or self.__cache_image_id != _id: + self.__cache_image_id = _id + self.__input_cache = {} + self.__target_cache = {} + + ## Update target and input names (which are just file path(s)) + self.__current_input_names = [self.__image_paths[_id][key] for key in self._input_channel_keys] + self.__current_target_names = [self.__image_paths[_id][key] for key in self._target_channel_keys] + + for key in self._input_channel_keys: + self.__input_cache[key] = self._read_convert_image(self.__image_paths[_id][key]) + for key in self._target_channel_keys: + self.__target_cache[key] = self._read_convert_image(self.__image_paths[_id][key]) + else: + # No need to update the cache + pass + + return None \ No newline at end of file diff --git a/datasets/PatchDataset.py b/datasets/PatchDataset.py new file mode 100644 index 0000000..0dd8656 --- /dev/null +++ b/datasets/PatchDataset.py @@ -0,0 +1,586 @@ +import logging +import math +import pathlib +import random +from random import randint +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from pyarrow import parquet as pq +import torch + +from .ImageDataset import ImageDataset + +class PatchDataset(ImageDataset): + """ + Patch Dataset Class from pe2loaddata generated cellprofiler loaddata csv and sc features + """ + def __init__( + self, + _sc_feature: Optional[pd.DataFrame | pathlib.Path] = None, + patch_size: int = 64, + patch_generation_method: str = 'random', + patch_generation_random_seed: Optional[int] = None, + patch_generation_max_attempts: int = 1_000, + n_expected_patches_per_img: int = 5, + candidate_x: str = 'Metadata_Cells_Location_Center_X', + candidate_y: str = 'Metadata_Cells_Location_Center_Y', + **kwargs + ): + + self._patch_size = patch_size + self._merge_fields = None + self._x_col = None + self._y_col = None + self.__cell_coords = [] + + # This intializes the channels keys, loaddata loading, image mode and + # the overriden methods further merge the loaddata with sc features + super().__init__(_sc_feature=_sc_feature, + candidate_x=candidate_x, + candidate_y=candidate_y, + **kwargs) + + + self.__patch_coords = self._generate_patches( + _patch_size=self._patch_size, + patch_generation_method=patch_generation_method, + patch_generation_random_seed=patch_generation_random_seed, + n_expected_patches_per_img=n_expected_patches_per_img, + max_attempts=patch_generation_max_attempts, + consistent_img_size=kwargs.get('consistent_img_size', True), + ) + + # Index patches and images + self.__iter_image_id = [] + self.__iter_patch_id = [] + for i, _patch_coords in enumerate(self.__patch_coords): + for j, _ in enumerate(_patch_coords): + self.__iter_image_id.append(i) + self.__iter_patch_id.append(j) + + # Initialize the cache for the input 
and target images + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Initialize the current input and target names and patch coordinates + self.__current_input_names = None + self.__current_target_names = None + self.__current_patch_coords = None + + """ + Overridden Iterator functions + """ + def __len__(self): + return len(self.__patch_coords) + + def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]: + """ + Return the input and target images + :param _idx: The index of the image + :type _idx: int + :return: The input and target images, each with dimension [n_channels, height, width] + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + + if _idx >= len(self) or _idx < 0: + raise IndexError("Index out of bounds") + + if self._input_channel_keys is None or self._target_channel_keys is None: + raise ValueError("Input and target channel keys must be set to access data") + + image_id = self.__iter_image_id[_idx] + patch_id = self.__iter_patch_id[_idx] + self.__current_patch_coords = self.__patch_coords[image_id][patch_id] + + self._cache_image(image_id) + + ## Retrieve relevant channels as specified by input and target channel keys and stack + ## And further crop the patches with __current_patch_coords + input_images = np.stack( + [self._ImageDataset__input_cache[key][ + self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size, + self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size + ] for key in self._input_channel_keys], + axis=0) + target_images = np.stack( + [self._ImageDataset__target_cache[key][ + self.__current_patch_coords[1]:self.__current_patch_coords[1] + self._patch_size, + self.__current_patch_coords[0]:self.__current_patch_coords[0] + self._patch_size + ] for key in self._target_channel_keys], + axis=0) + + ## Apply transform + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + ## Cast to torch tensor and return + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + """ + Properties + """ + + @property + def patch_size(self): + return self._patch_size + + @property + def cell_coords(self): + return self.__cell_coords + + @property + def all_patch_coords(self): + return self.__patch_coords + + @property + def patch_coords(self): + return self.__current_patch_coords + + """ + Internal Helper functions + """ + + def __preload_sc_feature(self, + _sc_feature: pd.DataFrame | pathlib.Path | None) -> List[str]: + """ + Preload the sc feature dataframe/parquet file limiting only to the column headers + If a dataframe is supplied, use as is + If a path to a csv file is supplied, load only the header row + If a path to a parquet file is supplied, load only the parquet schema name + :param _sc_feature: The path to a csv file containing the cell profiler sc features + :type _sc_feature: str or pathlib.Path + """ + + if _sc_feature is None: + # No sc feature supplied, cell coordinates not available, patch generation will fixed random + self.logger.debug("No sc feature supplied, patch generation will be random") + self._patch_generation_method = 'random' + return None + + elif isinstance(_sc_feature, pd.DataFrame): + self.logger.debug("Dataframe supplied for sc_feature, using as is") + return _sc_feature.columns.tolist() + + else: + self.logger.debug("Preloading sc feature from file") + if not 
isinstance(_sc_feature, pathlib.Path): + try: + _sc_feature = pathlib.Path(_sc_feature) + except e: + raise e + + if not _sc_feature.exists(): + raise FileNotFoundError(f"File {_sc_feature} not found") + + if _sc_feature.suffix == '.csv': + self.logger.debug("Preloading sc feature from csv file") + return pd.read_csv(_sc_feature, nrows=0).columns.tolist() + elif _sc_feature.suffix == '.parquet': + pq_file = pq.ParquetFile(_sc_feature) + return pq_file.schema.names + else: + raise ValueError(f"File type {_sc_feature.suffix} not supported") + + def __infer_merge_fields(self, + _loaddata_df, + _sc_col_names: List[str] + ) -> List[str] | None: + """ + Find the columns that are common to both dataframes to use in an inner join + Mean to be used to associate loaddata_csv with sc features + + :param loaddata_csv: The first dataframe + :type loaddata_csv: pd.DataFrame + :param sc_feature: The second dataframe + :type sc_feature: pd.DataFrame + :return: The columns that are common to both dataframes + :rtype: List[str] + """ + if _sc_col_names is None: + return None + + self.logger.debug("Both loaddata_csv and sc_feature supplied, " \ + "inferring merge fields to associate the two dataframes") + merge_fields = list(set(_loaddata_df.columns).intersection(set(_sc_col_names))) + if len(merge_fields) == 0: + raise ValueError("No common columns found between loaddata_csv and sc_feature") + self.logger.debug(f"Merge fields inferred: {merge_fields}") + + return merge_fields + + def __infer_x_y_columns(self, + _loaddata_df, + _sc_col_names: List[str], + candidate_x: str, + candidate_y: str) -> Tuple[str, str]: + """ + Infer the columns that contain the x and y coordinates of the patches + :param candidate_x: The candidate column name for the x coordinates + :type candidate_x: str + :param candidate_y: The candidate column name for the y coordinates + :type candidate_y: str + :return: The columns that contain the x and y coordinates of the patches + :rtype: Tuple[str, str] + """ + + if _loaddata_df is None: + return None, None + + if candidate_x not in _sc_col_names or candidate_y not in _sc_col_names: + self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} not detected in sc_features, attempting to infer from sc_feature dataframe") + + # infer the columns that contain the x and y coordinates + x_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_x')] + y_col_candidates = [col for col in _sc_col_names if col.lower().endswith('_y')] + + if len(x_col_candidates) == 0 or len(y_col_candidates) == 0: + raise ValueError("No columns found containing the x and y coordinates") + else: + # sort x_col and y_col candidates + x_col_candidates.sort() + y_col_candidates.sort() + x_col_detected = x_col_candidates[0] + y_col_detected = y_col_candidates[0] + self.logger.debug(f"X and Y columns {x_col_detected}, {y_col_detected} detected in sc_feature dataframe, using as the coordinates for cell centers") + return x_col_detected, y_col_detected + else: + self.logger.debug(f"X and Y columns {candidate_x}, {candidate_y} detected in sc_feature dataframe, using as the coordinates for cell centers") + return candidate_x, candidate_y + + def __load_sc_feature(self, + _sc_feature: pd.DataFrame | pathlib.Path | None, + _merge_fields: List[str], + _x_col: str, + _y_col: str + ) -> pd.DataFrame | None: + """ + Load the actual sc feature as a dataframe, limiting the columns to the merge fields and the x and y coordinates + :param _sc_feature: The path to a csv file containing the cell profiler sc 
features + :type _sc_feature: str or pathlib.Path + :return: The dataframe containing the cell profiler sc features + :rtype: pd.DataFrame + """ + + if _sc_feature is None: + return None + elif isinstance(_sc_feature, pd.DataFrame): + self.logger.debug("Dataframe supplied for sc_feature, using as is") + return _sc_feature + else: + self.logger.debug("Loading sc feature from file") + if not isinstance(_sc_feature, pathlib.Path): + try: + _sc_feature = pathlib.Path(_sc_feature) + except e: + raise e + + if not _sc_feature.exists(): + raise FileNotFoundError(f"File {_sc_feature} not found") + + if _sc_feature.suffix == '.csv': + return pd.read_csv(_sc_feature, + usecols=_merge_fields + [_x_col, _y_col]) + elif _sc_feature.suffix == '.parquet': + return pq.read_table(_sc_feature, columns=_merge_fields + [_x_col, _y_col]).to_pandas() + else: + raise ValueError(f"File type {_sc_feature.suffix} not supported") + + """ + Overriden parent class helper functions + """ + + def _load_loaddata(self, + _loaddata_csv: pd.DataFrame | pathlib.Path, + _sc_feature: Optional[pd.DataFrame | pathlib.Path], + candidate_x: str, + candidate_y: str, + ): + """ + Overridden function from parent class + Calls the parent class to get the loaddata df and then merges it with sc_feature + """ + + ## First calls the parent class to get the full loaddata df + loaddata_df = super()._load_loaddata(_loaddata_csv) + + ## Obtain column names of sc features first to avoid needing to read in the whole + ## Parquet file as only a very small number of columns are needed + sc_feature_col_names = self.__preload_sc_feature(_sc_feature) + + ## Infer columns corresponding to x and y coordinates to cells + self._x_col, self._y_col = self.__infer_x_y_columns( + loaddata_df, sc_feature_col_names, candidate_x, candidate_y) + + ## Infer merge fields between the sc features and loaddata + self._merge_fields = self.__infer_merge_fields(loaddata_df, sc_feature_col_names) + + ## Load sc features + sc_feature_df = self.__load_sc_feature( + _sc_feature, self._merge_fields, self._x_col, self._y_col) + + ## Perform the merge and return the merged dataframe (which is loaddata plus columns for x and y coordinates) + return loaddata_df.merge(sc_feature_df, on=self._merge_fields, how='inner') + + def _get_image_paths(self, + file_column_prefix: str, + path_column_prefix: str, + check_exists: bool = False, + **kwargs, + ) -> List[dict]: + """ + Overridden function + From loaddata csv, extract the paths to all image channels cooresponding to each view/site + + :param check_exists: check if every individual image file exist and remove those that do not + :type check_exists: bool + :return: A list of dictionaries containing the paths to the image channels + :rtype: List[dict] + """ + + # Define helper function to get the image file paths from all channels + # in a single row of loaddata csv (single view/site), organized into a dict + def get_channel_paths(row: pd.Series) -> Tuple[dict, bool]: + + missing = False + + multi_channel_paths = {} + for channel_key in self._channel_keys: + file_column = f"{file_column_prefix}{channel_key}" + path_column = f"{path_column_prefix}{channel_key}" + + if file_column in row and path_column in row: + file = pathlib.Path( + row[path_column] + ) / row[file_column] + if (not check_exists) or file.exists(): + multi_channel_paths[channel_key] = file + else: + missing = True + + return multi_channel_paths, missing + + # Define helper function to get the coordinates associated with a condition + def get_associated_coords(group): 
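            # Returns the (x, y) cell-center coordinates for one group as an
            # N x 2 array, or None if the coordinate columns cannot be read.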
+ + try: + return group.loc[:, [self._x_col, self._y_col]].values + except: + return None + + image_paths = [] + cell_coords = [] + self.logger.debug( + "Extracting image channel paths of site/view and associated"\ + "cell coordinates (if applicable) from loaddata csv") + + n_cells = 0 + grouped = self._loaddata_df.groupby(self._merge_fields) + for _, group in grouped: + + _, row = next(group.iterrows()) + multi_channel_paths, missing = get_channel_paths(row) + if not missing: + image_paths.append(multi_channel_paths) + coords = get_associated_coords(group) + n_cells += len(coords) + cell_coords.append(coords) + + self.logger.debug("Extracted images of all input and target channels for " \ + f"{len(image_paths)} unique sites/view and {n_cells} cells") + + self.__cell_coords = cell_coords + + return image_paths + + """ + Patch generation helper functions + """ + + def _generate_patches(self, + _patch_size: int, + patch_generation_method: str, + patch_generation_random_seed: int, + n_expected_patches_per_img=5, + max_attempts=1_000, + consistent_img_size=True, + )->None: + """ + Generate patches for each image in the dataset + :param patch_generation_method: The method to use for generating patches + :type patch_generation_method: str + :param patch_generation_random_seed: The random seed to use for patch generation + :type patch_generation_random_seed: int + :param consistent_img_size: Whether the images are consistent in size. + If True, the patch generation will be based on the size of the first input channel of first image + If False, the patch generation will be based on the size of each image + :type consistent_img_size: bool + :param n_expected_patches_per_img: The number of patches to generate per image + :type n_expected_patches_per_img: int + :param max_attempts: The maximum number of attempts to generate a patch + :type max_attempts: int + :return: The coordinates of the patches + :rtype: List[List[Tuple[int + """ + if patch_generation_method == 'random_cell': + if self.__cell_coords is None: + raise ValueError("Cell coordinates not available for generating cell containing patches") + else: + self.logger.debug("Generating patches that contain cells") + def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts): + return self.__generate_cell_containing_patches_unit( + image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts) + pass + elif patch_generation_method == 'random': + self.logger.debug("Generating random patches") + def patch_fn(image_size, patch_size, cell_coords, n_expected_patches_per_img, max_attempts): + # cell_coords is not used in this case + return self.__generate_random_patches_unit(image_size, patch_size, n_expected_patches_per_img, max_attempts) + pass + else: + raise ValueError("Patch generation method not supported") + + # Generate patches for each image + image_size = None + patch_count = 0 + patch_coords = [] + + # set random seed + if patch_generation_random_seed is not None: + random.seed(patch_generation_random_seed) + for channel_paths, cell_coords in zip(self._ImageDataset__image_paths, self.__cell_coords): + if consistent_img_size: + if image_size is not None: + pass + else: + try: + image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0] + self.logger.debug( + f"Image size inferred: {image_size} for all images " + "to force redetect image sizes for each view/site set consistent_img_size=False" + ) + except: + raise ValueError("Error reading image size") + pass + else: + 
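                # consistent_img_size=False: re-read the dimensions for every site/view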
try: + image_size = self._read_convert_image(channel_paths[self._channel_keys[0]]).shape[0] + except: + raise ValueError("Error reading image size") + + coords = patch_fn( + image_size=image_size, + patch_size=_patch_size, + cell_coords=cell_coords, + n_expected_patches_per_img=n_expected_patches_per_img, + max_attempts=max_attempts + ) + patch_coords.append(coords) + patch_count += len(coords) + + self.logger.debug(f"Generated {patch_count} patches for {len(self._ImageDataset__image_paths)} site/view") + return patch_coords + + @staticmethod + def __generate_cell_containing_patches_unit( + image_size, + patch_size, + cell_coords, + expected_n_patches=5, + max_attempts=1_000): + """ + Static helper function to generate patches that contain the cell coordinates + :param image_size: The size of the image (square) + :type image_size: int + :param patch_size: The size of the square patches to generate + :type patch_size: int + :param cell_coords: The coordinates of the cells + :type cell_coords: List[Tuple[int, int]] + :param expected_n_patches: The number of patches to generate + :type expected_n_patches: int + :return: The coordinates of the patches + """ + + unit_size = math.gcd(image_size, patch_size) + tile_size_units = patch_size // unit_size + grid_size_units = image_size // unit_size + + cell_containing_units = {(x // unit_size, y // unit_size) for x, y in cell_coords} + placed_tiles = set() + retained_tiles = [] + + attempts = 0 + n_tiles = 0 + while attempts < max_attempts: + top_left_x = randint(0, grid_size_units - tile_size_units) + top_left_y = randint(0, grid_size_units - tile_size_units) + + tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units) + for y in range(top_left_y, top_left_y + tile_size_units)} + + if any(tile_units & placed_tile for placed_tile in placed_tiles): + attempts += 1 + continue + + if tile_units & cell_containing_units: + retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size)) + placed_tiles.add(frozenset(tile_units)) + n_tiles += 1 + + attempts += 1 + if n_tiles >= expected_n_patches: + break + + return retained_tiles + + @staticmethod + def __generate_random_patches_unit( + image_size, + patch_size, + expected_n_patches=5, + max_attempts=1_000): + """ + Static helper function to generate random patches + :param image_size: The size of the image (square) + :type image_size: int + :param patch_size: The size of the square patches to generate + :type patch_size: int + :param expected_n_patches: The number of patches to generate + :type expected_n_patches: int + :return: The coordinates of the patches + """ + unit_size = math.gcd(image_size, patch_size) + tile_size_units = patch_size // unit_size + grid_size_units = image_size // unit_size + + placed_tiles = set() + retained_tiles = [] + + attempts = 0 + n_tiles = 0 + while attempts < max_attempts: + top_left_x = randint(0, grid_size_units - tile_size_units) + top_left_y = randint(0, grid_size_units - tile_size_units) + + # check for overlap with already placed tiles + tile_units = {(x, y) for x in range(top_left_x, top_left_x + tile_size_units) + for y in range(top_left_y, top_left_y + tile_size_units)} + + if any(tile_units & placed_tile for placed_tile in placed_tiles): + attempts += 1 + continue + + # no overlap, add the tile to the list of retained tiles + retained_tiles.append((top_left_x * unit_size, top_left_y * unit_size)) + placed_tiles.add(frozenset(tile_units)) + n_tiles += 1 + + attempts += 1 + if n_tiles >= expected_n_patches: + break + + return 
retained_tiles
\ No newline at end of file

From 5bfaad6b6e88980aacb843cb7d43a5358ec4be23 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:20:24 -0700
Subject: [PATCH 05/89] Added readme for trainers

---
 trainers/README.md | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 trainers/README.md

diff --git a/trainers/README.md b/trainers/README.md
new file mode 100644
index 0000000..06173b4
--- /dev/null
+++ b/trainers/README.md
@@ -0,0 +1 @@
+Here lives the trainer class for FNet/UNet and wGaN GP. Shared components between trainers for different models are isolated into the abstract trainer class
\ No newline at end of file

From 31f2f1b2637926c230fc25b74e9c3f81b05f8daa Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:20:36 -0700
Subject: [PATCH 06/89] Added trainer files

---
 trainers/AbstractTrainer.py | 413 ++++++++++++++++++++++++++++++++++++
 trainers/Trainer.py         | 170 +++++++++++++++
 trainers/WGaNTrainer.py     | 274 ++++++++++++++++++++++++
 3 files changed, 857 insertions(+)
 create mode 100644 trainers/AbstractTrainer.py
 create mode 100644 trainers/Trainer.py
 create mode 100644 trainers/WGaNTrainer.py

diff --git a/trainers/AbstractTrainer.py b/trainers/AbstractTrainer.py
new file mode 100644
index 0000000..ed9a9c0
--- /dev/null
+++ b/trainers/AbstractTrainer.py
@@ -0,0 +1,413 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import List, Callable, Dict, Optional
+
+import torch
+from torch.utils.data import DataLoader, random_split
+
+from ..metrics.AbstractMetrics import AbstractMetrics
+from ..callbacks.AbstractCallback import AbstractCallback
+
+
+class AbstractTrainer(ABC):
+    """
+    Abstract trainer class for img2img translation models.
+    Provides shared dataset handling and modular callbacks for logging and evaluation.
+    """
+
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        batch_size: int = 16,
+        epochs: int = 10,
+        patience: int = 5,
+        callbacks: List[AbstractCallback] = None,
+        metrics: Dict[str, AbstractMetrics] = None,
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        """
+        :param dataset: The dataset to be used for training.
+        :type dataset: torch.utils.data.Dataset
+        :param batch_size: The batch size for training.
+        :type batch_size: int
+        :param epochs: The number of epochs for training.
+        :type epochs: int
+        :param patience: The number of epochs with no improvement after which training will be stopped.
+        :type patience: int
+        :param callbacks: List of callback objects invoked at the start/end of training and of each epoch.
+        :type callbacks: list of AbstractCallback
+        :param metrics: Dictionary of metrics to be logged.
+        :type metrics: dict
+        :param device: (optional) The device to be used for training.
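+            If not provided, CUDA is used when available, otherwise the CPU.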
+ :type device: torch.device + """ + + self._batch_size = batch_size + self._epochs = epochs + self._patience = patience + self._callbacks = callbacks if callbacks else [] + self._metrics = metrics if metrics else {} + + if isinstance(device, torch.device): + self._device = device + else: + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self._best_model = None + self._best_loss = float("inf") + self._early_stop_counter = 0 + + # Customize data splits + self._train_ratio = kwargs.get("train", 0.7) + self._val_ratio = kwargs.get("val", 0.15) + self._test_ratio = kwargs.get("test", 1.0 - self._train_ratio - self._val_ratio) + + if not (0 < self._train_ratio + self._val_ratio + self._test_ratio <= 1.0): + raise ValueError("Data split ratios must sum to 1.0 or less.") + + train_size = int(self._train_ratio * len(dataset)) + val_size = int(self._val_ratio * len(dataset)) + test_size = len(dataset) - train_size - val_size + self._train_dataset, self._val_dataset, self._test_dataset = random_split( + dataset, [train_size, val_size, test_size] + ) + + # Create DataLoaders + self._train_loader = DataLoader( + self._train_dataset, batch_size=self._batch_size, shuffle=True + ) + self._val_loader = DataLoader( + self._val_dataset, batch_size=self._batch_size, shuffle=False + ) + + # Epoch counter + self._epoch = 0 + + # Loss and metrics storage + self._train_losses = defaultdict(list) + self._val_losses = defaultdict(list) + self._train_metrics = defaultdict(list) + self._val_metrics = defaultdict(list) + + @abstractmethod + def train_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + """ + Abstract method for training the model on one batch + Must be implemented by subclasses. + This should be where the losses and metrics are calculated. + Should return a dictionary with loss name as key and torch tensor loss as value. + + :param inputs: The input data. + :type inputs: torch.Tensor + :param targets: The target data. + :type targets: torch.Tensor + :return: A dictionary containing the loss values for the batch. + :rtype: dict[str, torch.Tensor] + """ + pass + + @abstractmethod + def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + """ + Abstract method for evaluating the model on one batch + Must be implemented by subclasses. + This should be where the losses and metrics are calculated. + Should return a dictionary with loss name as key and torch tensor loss as value. + + :param inputs: The input data. + :type inputs: torch.Tensor + :param targets: The target data. + :type targets: torch.Tensor + :return: A dictionary containing the loss values for the batch. + :rtype: dict[str, torch.Tensor] + """ + pass + + @abstractmethod + def train_epoch(self)->dict[str, torch.Tensor]: + """ + Can be overridden by subclasses to implement custom training logic. + Make calls to the train_step method for each batch + in the training DataLoader. + + Return a dictionary with loss name as key and + torch tensor loss as value. Multiple losses can be returned. + + :return: A dictionary containing the loss values for the epoch. + :rtype: dict[str, torch.Tensor] + """ + + pass + + @abstractmethod + def evaluate_epoch(self)->dict[str, torch.Tensor]: + """ + Can be overridden by subclasses to implement custom evaluation logic. + Should make calls to the evaluate_step method for each batch + in the validation DataLoader. + + Should return a dictionary with loss name as key and + torch tensor loss as value. 
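Note that the first loss in the returned dictionary is the one used for early stopping.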
Multiple losses can be returned. + + :return: A dictionary containing the loss values for the epoch. + :rtype: dict[str, torch.Tensor] + """ + + pass + + def train(self): + """ + Train the model for the specified number of epochs. + Make calls to the train epoch and evaluate epoch methods. + """ + + self.model.to(self.device) + + # callbacks + for callback in self.callbacks: + callback.on_train_start(self) + + for epoch in range(self.epochs): + + # Increment the epoch counter + self.epoch += 1 + + # callbacks + for callback in self.callbacks: + callback.on_epoch_start(self) + + # Access all the metrics and reset them + for _, metric in self.metrics.items(): + metric.reset() + + # Train the model for one epoch + train_loss = self.train_epoch() + for loss_name, loss in train_loss.items(): + self._train_losses[loss_name].append(loss) + + # Evaluate the model for one epoch + val_loss = self.evaluate_epoch() + for loss_name, loss in val_loss.items(): + self._val_losses[loss_name].append(loss) + + # Access all the metrics and compute the final epoch metric value + for metric_name, metric in self.metrics.items(): + train_metric, val_metric = metric.compute() + self._train_metrics[metric_name].append(train_metric.item()) + self._val_metrics[metric_name].append(val_metric.item()) + + # Invoke callback on epoch_end + for callback in self.callbacks: + callback.on_epoch_end(self) + + # Update early stopping + val_loss = next(iter(val_loss.values())) + self.update_early_stop(val_loss) + + # Check if early stopping is needed + if self.early_stop_counter >= self.patience: + break + + for callback in self.callbacks: + callback.on_train_end(self) + + def update_early_stop(self, val_loss: torch.Tensor): + """ + Method to update the early stopping criterion + + :param val_loss: The loss value on the validation set + :type val_loss: torch.Tensor + """ + + if val_loss < self.best_loss: + self.best_loss = val_loss + self.early_stop_counter = 0 + self.best_model = self.model.state_dict().copy() + else: + self.early_stop_counter += 1 + + """ + Log property + """ + @property + def log(self): + """ + Returns the training and validation losses and metrics. + """ + log ={ + **{'epoch': list(range(1, self.epoch + 1))}, + **self._train_losses, + **{f'val_{key}': val for key, val in self._val_losses.items()}, + **self._train_metrics, + **{f'val_{key}': val for key, val in self._val_metrics.items()} + } + + return log + + """ + Properties for accessing various attributes of the trainer. 
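+    These are read by subclasses and callbacks in place of the private attributes.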
+ """ + @property + def train_ratio(self): + return self._train_ratio + + @property + def val_ratio(self): + return self._val_ratio + + @property + def test_ratio(self): + return self._test_ratio + + @property + def model(self): + return self._model + + @property + def optimizer(self): + return self._optimizer + + @property + def device(self): + return self._device + + @property + def batch_size(self): + return self._batch_size + + @property + def epochs(self): + return self._epochs + + @property + def patience(self): + return self._patience + + @property + def callbacks(self): + return self._callbacks + + @property + def best_model(self): + return self._best_model + + @property + def best_loss(self): + return self._best_loss + + @property + def early_stop_counter(self): + return self._early_stop_counter + + @property + def metrics(self): + return self._metrics + + @property + def epoch(self): + return self._epoch + + @property + def train_losses(self): + return self._train_losses + + @property + def val_losses(self): + return self._val_losses + + @property + def train_metrics(self): + return self._train_metrics + + @property + def val_metrics(self): + return self._val_metrics + + """ + Setters for best model and best loss and early stop counter + Meant to be used by the subclasses to update the best model and loss + """ + + @best_model.setter + def best_model(self, value: torch.nn.Module): + self._best_model = value + + @best_loss.setter + def best_loss(self, value): + self._best_loss = value + + @early_stop_counter.setter + def early_stop_counter(self, value: int): + self._early_stop_counter = value + + @epoch.setter + def epoch(self, value: int): + self._epoch = value + + """ + Update loss and metrics + """ + + def update_loss(self, + loss: torch.Tensor, + loss_name: str, + validation: bool = False): + if validation: + self._val_losses[loss_name].append(loss) + else: + self._train_losses[loss_name].append(loss) + + def update_metrics(self, + metric: torch.tensor, + metric_name: str, + validation: bool = False): + if validation: + self._val_metrics[metric_name].append(metric) + else: + self._train_metrics[metric_name].append(metric) + + """ + Properties for accessing the split datasets. 
+ """ + @property + def train_dataset(self, loader=False): + """ + Returns the training dataset or DataLoader if loader=True + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return self._train_loader + else: + return self._train_dataset + + @property + def val_dataset(self, loader=False): + """ + Returns the validation dataset or DataLoader if loader=True + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return self._val_loader + else: + return self._val_dataset + + @property + def test_dataset(self, loader=False): + """ + Returns the test dataset or DataLoader if loader=True + Generates the DataLoader on the fly as the test data loader is not + pre-defined during object initialization + + :param loader: (bool) whether to return a DataLoader or the dataset + :type loader: bool + """ + if loader: + return DataLoader(self._test_dataset, batch_size=self._batch_size, shuffle=False) + else: + return self._test_dataset diff --git a/trainers/Trainer.py b/trainers/Trainer.py new file mode 100644 index 0000000..ce5e707 --- /dev/null +++ b/trainers/Trainer.py @@ -0,0 +1,170 @@ +from collections import defaultdict +from typing import Optional, List + +import torch +from torch.utils.data import DataLoader, random_split + +from .AbstractTrainer import AbstractTrainer + +class Trainer(AbstractTrainer): + """ + Trainer class for single img2img convolutional models backpropagating on single loss items + """ + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + backprop_loss: torch.nn.Module | List[torch.nn.Module], + # rest of the arguments are passed to and handled by the parent class + # - dataset + # - batch_size + # - epochs + # - patience + # - callbacks + # - metrics + **kwargs + ): + """ + Initialize the trainer with the model, optimizer and loss function. + + :param model: The model to be trained. + :type model: torch.nn.Module + :param optimizer: The optimizer to be used for training. + :type optimizer: torch.optim.Optimizer + :param backprop_loss: The loss function to be used for training or a list of loss functions. + :type backprop_loss: torch.nn.Module + """ + + super().__init__(**kwargs) + + self._model = model + self._optimizer = optimizer + self._backprop_loss = backprop_loss \ + if isinstance(backprop_loss, list) else [backprop_loss] + + # Make an initial copy of the model + self.best_model = self.model.state_dict().copy() + + """ + Overidden methods from the parent abstract class + """ + def train_step(self, inputs: torch.tensor, targets: torch.tensor): + """ + Perform a single training step on batch. 
+ + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + # move the data to the device + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # set the model to train + self.model.train() + # set the optimizer gradients to zero + self.optimizer.zero_grad() + + # Forward pass + outputs = self.model(inputs) + + # Back propagate the loss + losses = {} + total_loss = torch.tensor(0.0, device=self.device) + for loss in self._backprop_loss: + losses[type(loss).__name__] = loss(outputs, targets) + total_loss += losses[type(loss).__name__] + + total_loss.backward() + self.optimizer.step() + + # Calculate the metrics outputs and update the metrics + for _, metric in self.metrics.items(): + metric.update(outputs, targets, validation=False) + + return { + key: value.item() for key, value in losses.items() + } + + def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor): + """ + Perform a single evaluation step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + # move the data to the device + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # set the model to evaluation + self.model.eval() + + with torch.no_grad(): + # Forward pass + outputs = self.model(inputs) + + # calculate the loss + losses = {} + for loss in self._backprop_loss: + losses[type(loss).__name__] = loss(outputs, targets) + + # Calculate the metrics outputs and update the metrics + for _, metric in self.metrics.items(): + metric.update(outputs, targets, validation=True) + + return { + key: value.item() for key, value in losses.items() + } + + def train_epoch(self): + """ + Train the model for one epoch. + """ + + super().train_epoch() + + self._model.train() + losses = defaultdict(list) + # Iterate over the train_loader + for inputs, targets in self._train_loader: + batch_loss = self.train_step(inputs, targets) + for key, value in batch_loss.items(): + losses[key].append(value) + + # reduce loss + return { + key: sum(value) / len(value) for key, value in losses.items() + } + + def evaluate_epoch(self): + """ + Evaluate the model for one epoch. + """ + + self._model.eval() + losses = defaultdict(list) + # Iterate over the val_loader + for inputs, targets in self._val_loader: + batch_loss = self.evaluate_step(inputs, targets) + for key, value in batch_loss.items(): + losses[key].append(value) + + # reduce loss + return { + key: sum(value) / len(value) for key, value in losses.items() + } + + # @property + # def log(self): + # """ + # Returns the training and validation losses and metrics. 
+ # """ + # log ={ + # **{'epoch': list(range(1, self.epoch + 1))}, + # **self._train_metrics, + # **{f'val_{key}': val for key, val in self._val_metrics.items()} + # } + + # return log \ No newline at end of file diff --git a/trainers/WGaNTrainer.py b/trainers/WGaNTrainer.py new file mode 100644 index 0000000..0e83b7a --- /dev/null +++ b/trainers/WGaNTrainer.py @@ -0,0 +1,274 @@ +from typing import Optional +from collections import defaultdict + +import torch +import torch.autograd as autograd +from torch.utils.data import DataLoader + +from .AbstractTrainer import AbstractTrainer + +class WGaNTrainer(AbstractTrainer): + def __init__(self, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + gen_optimizer: torch.optim.Optimizer, + disc_optimizer: torch.optim.Optimizer, + generator_loss_fn: torch.nn.Module, + discriminator_loss_fn: torch.nn.Module, + gradient_penalty_fn: Optional[torch.nn.Module]=None, + discriminator_update_freq: int=1, + generator_update_freq: int=5, + # rest of the arguments are passed to and handled by the parent class + # - dataset + # - batch_size + # - epochs + # - patience + # - callbacks + # - metrics + **kwargs): + """ + Initializes the WGaN Trainer class. + + :param generator: The image2image generator model (e.g., UNet) + :type generator: torch.nn.Module + :param discriminator: The discriminator model + :type discriminator: torch.nn.Module + :param gen_optimizer: Generator optimizer + :type gen_optimizer: torch.optim.Optimizer + :param disc_optimizer: Discriminator optimizer + :type disc_optimizer: torch.optim.Optimizer + :param generator_loss_fn: Generator loss function + :type generator_loss_fn: torch.nn.Module + :param discriminator_loss_fn: Adverserial loss function + :type discriminator_loss_fn: torch.nn.Module + :param gradient_penalty_fn: (optional) Gradient penalty loss function + :type gradient_penalty_fn: torch.nn.Module + :param discriminator_update_freq: How frequently to update the discriminator + :type discriminator_update_freq: int + :param generator_update_freq: How frequently to update the generator + :type generator_update_freq: int + :param kwargs: Additional arguments passed to the AbstractTrainer + :type kwargs: dict + """ + super().__init__(**kwargs) + + # Validate update frequencies + if discriminator_update_freq > 1 and generator_update_freq > 1: + raise ValueError( + "Both discriminator_update_freq and generator_update_freq cannot be greater than 1. " + "At least one network must update every epoch." 
+ ) + + self._generator = generator + self._discriminator = discriminator + self._gen_optimizer = gen_optimizer + self._disc_optimizer = disc_optimizer + self._generator_loss_fn = generator_loss_fn + self._generator_loss_fn.trainer = self + self._discriminator_loss_fn = discriminator_loss_fn + self._discriminator_loss_fn.trainer = self + self._gradient_penalty_fn = gradient_penalty_fn + if self._gradient_penalty_fn is not None: + self._gradient_penalty_fn.trainer = self + + # Make an initial copy of the generator and discriminator models + self.best_generator = self._generator.state_dict().copy() + self.best_discriminator = self._discriminator.state_dict().copy() + + # Global step counter and update frequencies + self._discriminator_update_freq = discriminator_update_freq + self._generator_update_freq = generator_update_freq + + ## TODO: instead of memorizing the same loss, keep a running average of the losses + # Memory for discriminator and generator losses from the most recent update + self._last_discriminator_loss = torch.tensor(0.0, device=self.device).detach() + self._last_gradient_penalty_loss = torch.tensor(0.0, device=self.device).detach() + self._last_generator_loss = torch.tensor(0.0, device=self.device).detach() + + def train_step(self, + inputs: torch.tensor, + targets: torch.tensor + ): + """ + Perform a single training step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + inputs, targets = inputs.to(self.device), targets.to(self.device) + + gp_loss = torch.tensor(0.0, device=self.device) + + # foward pass to generate image (shared by both updates) + generated_images = self._generator(inputs) + + # Train Discriminator + if self.epoch % self._discriminator_update_freq == 0: + self._disc_optimizer.zero_grad() + + real_images = targets + + # Concatenate input channel and real/generated image channels along the + # channel dimension to feed full stacked multi-channel images to the discriminator + real_input_pair = torch.cat((real_images, inputs), 1) + generated_input_pair = torch.cat((generated_images.detach(), inputs), 1) + + discriminator_real_score = self._discriminator(real_input_pair).mean() + discriminator_fake_score = self._discriminator(generated_input_pair).mean() + + # Adverserial loss + discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score) + + # Compute Gradient penalty loss if fn is supplied + if self._gradient_penalty_fn is not None: + gp_loss = self._gradient_penalty_fn(real_input_pair, generated_input_pair) + + total_discriminator_loss = discriminator_loss + gp_loss + total_discriminator_loss.backward() + self._disc_optimizer.step() + + # memorize current discriminator loss until next discriminator update + self._last_discriminator_loss = discriminator_loss.detach() + self._last_gradient_penalty_loss = gp_loss + else: + # when not being updated, use the loss from previus update + discriminator_loss = self._last_discriminator_loss + gp_loss = self._last_gradient_penalty_loss + + # Train Generator + if self.epoch % self._generator_update_freq == 0: + self._gen_optimizer.zero_grad() + + discriminator_fake_score = self._discriminator(torch.cat((generated_images, inputs), 1)).mean() + generator_loss = self._generator_loss_fn(discriminator_fake_score, generated_images, real_images, self.epoch) + generator_loss.backward() + self._gen_optimizer.step() + + # memorize current generator loss until next generator 
update + self._last_generator_loss = generator_loss.detach() + else: + # when not being updated, set the loss to zero + generator_loss = self._last_generator_loss + + for _, metric in self.metrics.items(): + ## TODO: centralize the update of metrics + # compute the generated fake targets regardless for use with metrics + generated_images = self._generator(inputs).detach() + metric.update(generated_images, targets, validation=False) + ## After each batch -> after each epoch + + loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(), + type(self._generator_loss_fn).__name__: generator_loss.item()} + if self._gradient_penalty_fn is not None: + loss = { + **loss, + **{type(self._gradient_penalty_fn).__name__: gp_loss.item()} + } + + return loss + + def evaluate_step(self, + inputs: torch.tensor, + targets: torch.tensor + ): + """ + Perform a single evaluation step on batch. + + :param inputs: The input image data batch + :type inputs: torch.tensor + :param targets: The target image data batch + :type targets: torch.tensor + """ + inputs, targets = inputs.to(self.device), targets.to(self.device) + + self._generator.eval() + self._discriminator.eval() + with torch.no_grad(): + + real_images = targets + generated_images = self._generator(inputs) + + # Concatenate input channel and real/generated image channels along the + # channel dimension to feed full stacked multi-channel images to the discriminator + real_input_pair = torch.cat((real_images, inputs), 1) + generated_input_pair = torch.cat((generated_images, inputs), 1) + + discriminator_real_score = self._discriminator(real_input_pair).mean() + discriminator_fake_score = self._discriminator(generated_input_pair).mean() + + # Compute losses + discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score) + + ## TODO: decide if gradient loss computation during eval mode is meaningful + gp_loss = torch.tensor(0.0, device=self.device) + + generator_loss = self._generator_loss_fn(discriminator_fake_score, generated_images, real_images, self.epoch) + + for _, metric in self.metrics.items(): + metric.update(generated_images, targets, validation=True) + + loss = {type(self._discriminator_loss_fn).__name__: discriminator_loss.item(), + type(self._generator_loss_fn).__name__: generator_loss.item()} + if self._gradient_penalty_fn is not None: + loss = { + **loss, + **{type(self._gradient_penalty_fn).__name__: gp_loss.item()} + } + + return loss + + def train_epoch(self): + + super().train_epoch() + + self._generator.train() + self._discriminator.train() + + epoch_losses = defaultdict(list) + for inputs, targets in self._train_loader: + losses = self.train_step(inputs, targets) + for key, value in losses.items(): + epoch_losses[key].append(value) + + for key, _ in epoch_losses.items(): + epoch_losses[key] = sum(epoch_losses[key])/len(self._train_loader) + + return epoch_losses + + def evaluate_epoch(self): + + self._generator.eval() + self._discriminator.eval() + + epoch_losses = defaultdict(list) + for inputs, targets in self._val_loader: + losses = self.evaluate_step(inputs, targets) + for key, value in losses.items(): + epoch_losses[key].append(value) + + for key, _ in epoch_losses.items(): + epoch_losses[key] = sum(epoch_losses[key])/len(self._val_loader) + + return epoch_losses + + def train(self): + + self._discriminator.to(self.device) + + super().train() + + @property + def model(self) -> torch.nn.Module: + """ + return the generator + """ + return self._generator + + @property + def 
+    def discriminator(self) -> torch.nn.Module:
+        """
+        returns the discriminator
+        """
+        return self._discriminator
\ No newline at end of file

From 34969843254ba17c999f7c8ec8863717355b9816 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:23:52 -0700
Subject: [PATCH 07/89] Added loss files

---
 losses/AbstractLoss.py        | 33 +++++++++++++++++++++++
 losses/DiscriminatorLoss.py   | 17 ++++++++++++
 losses/GeneratorLoss.py       | 50 +++++++++++++++++++++++++++++++++++
 losses/GradientPenaltyLoss.py | 36 +++++++++++++++++++++++++
 losses/README.md              |  1 +
 5 files changed, 137 insertions(+)
 create mode 100644 losses/AbstractLoss.py
 create mode 100644 losses/DiscriminatorLoss.py
 create mode 100644 losses/GeneratorLoss.py
 create mode 100644 losses/GradientPenaltyLoss.py
 create mode 100644 losses/README.md

diff --git a/losses/AbstractLoss.py b/losses/AbstractLoss.py
new file mode 100644
index 0000000..a343f33
--- /dev/null
+++ b/losses/AbstractLoss.py
@@ -0,0 +1,33 @@
+from abc import ABC, abstractmethod
+import torch.nn as nn
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class AbstractLoss(nn.Module, ABC):
+    """Abstract base class for losses"""
+
+    def __init__(self, _metric_name: str):
+
+        super(AbstractLoss, self).__init__()
+
+        self._metric_name = _metric_name
+        self._trainer = None
+
+    @property
+    def trainer(self):
+        return self._trainer
+
+    @trainer.setter
+    def trainer(self, value):
+        self._trainer = value
+
+    @property
+    def metric_name(self):
+        """Defines the metric name returned by the class."""
+        return self._metric_name
+
+    @abstractmethod
+    def forward(self):
+        """Computes the loss given information about the data."""
+        pass
\ No newline at end of file
diff --git a/losses/DiscriminatorLoss.py b/losses/DiscriminatorLoss.py
new file mode 100644
index 0000000..681887e
--- /dev/null
+++ b/losses/DiscriminatorLoss.py
@@ -0,0 +1,17 @@
+import torch
+
+from .AbstractLoss import AbstractLoss
+
+class DiscriminatorLoss(AbstractLoss):
+    def __init__(self, _metric_name):
+        super().__init__(_metric_name)
+
+    def forward(self, real_output, fake_output):
+
+        # If the output is a map rather than a scalar (e.g. from a patch-based
+        # discriminator), average it over the spatial dimensions first
+        if real_output.dim() >= 3:
+            real_output = torch.mean(real_output, tuple(range(2, real_output.dim())))
+        if fake_output.dim() >= 3:
+            fake_output = torch.mean(fake_output, tuple(range(2, fake_output.dim())))
+
+        return (fake_output - real_output).mean()
\ No newline at end of file
diff --git a/losses/GeneratorLoss.py b/losses/GeneratorLoss.py
new file mode 100644
index 0000000..2bee83b
--- /dev/null
+++ b/losses/GeneratorLoss.py
@@ -0,0 +1,50 @@
+from typing import Optional
+
+import torch
+from torch.nn import L1Loss
+
+from .AbstractLoss import AbstractLoss
+
+class GeneratorLoss(AbstractLoss):
+    def __init__(self,
+                 _metric_name: str,
+                 reconstruction_loss: Optional[torch.nn.Module] = L1Loss()
+                 ):
+        """
+        :param reconstruction_loss: The image reconstruction loss,
+            defaults to L1Loss()
+        :type reconstruction_loss: torch.nn.Module
+        """
+
+        super().__init__(_metric_name)
+
+        self._reconstruction_loss = reconstruction_loss
+
+    def forward(self,
+                discriminator_probs: torch.tensor,
+                fake_images: torch.tensor,
+                real_images: torch.tensor,
+                epoch: Optional[int] = None
+                ):
+        """
+        Computes the loss for the GAN generator.
+
+        :param discriminator_probs: The probabilities of the discriminator for the fake images being real.
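+            For a WGAN critic these are unbounded scores rather than true probabilities.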
+        :type discriminator_probs: torch.tensor
+        :param fake_images: The fake images generated by the generator.
+        :type fake_images: torch.tensor
+        :param real_images: The real images.
+        :type real_images: torch.tensor
+        :param epoch: The current epoch number.
+            Used for a smoothing weight for the adversarial loss component
+        :type epoch: int
+        """
+
+        # Adversarial loss, annealed over epochs when an epoch number is supplied
+        adversarial_loss = -torch.mean(discriminator_probs)
+        if epoch is not None:
+            adversarial_loss = 0.01 * adversarial_loss/(epoch + 1)
+
+        image_loss = self._reconstruction_loss(fake_images, real_images)
+
+        return adversarial_loss + image_loss.mean()
\ No newline at end of file
diff --git a/losses/GradientPenaltyLoss.py b/losses/GradientPenaltyLoss.py
new file mode 100644
index 0000000..f8e3cb4
--- /dev/null
+++ b/losses/GradientPenaltyLoss.py
@@ -0,0 +1,36 @@
+
+import torch
+import torch.autograd as autograd
+
+from .AbstractLoss import AbstractLoss
+
+class GradientPenaltyLoss(AbstractLoss):
+    def __init__(self, _metric_name, discriminator, weight=10.0):
+        super().__init__(_metric_name)
+
+        ## TODO: add a wrapper class for GaN loss functions to
+        # dynamically access discriminator from the trainer class
+        self._discriminator = discriminator
+        self._weight = weight
+
+    def forward(self, real_imgs, fake_imgs):
+
+        device = self.trainer.device
+
+        batch_size = real_imgs.size(0)
+        # draw one interpolation factor per sample and broadcast it across the image
+        eta = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(real_imgs)
+        interpolated = (eta * real_imgs + (1 - eta) * fake_imgs).requires_grad_(True)
+        prob_interpolated = self._discriminator(interpolated)
+
+        gradients = autograd.grad(
+            outputs=prob_interpolated,
+            inputs=interpolated,
+            grad_outputs=torch.ones_like(prob_interpolated),
+            create_graph=True,
+            retain_graph=True,
+        )[0]
+
+        gradients = gradients.view(batch_size, -1)
+        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+        return self._weight * gradient_penalty
\ No newline at end of file
diff --git a/losses/README.md b/losses/README.md
new file mode 100644
index 0000000..2fe0a00
--- /dev/null
+++ b/losses/README.md
@@ -0,0 +1 @@
+Here live the loss functions used by wGaN GP
\ No newline at end of file

From c187a310767f8bb95a169ffa712541345a7eea2a Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:25:18 -0700
Subject: [PATCH 08/89] Added metrics files

---
 metrics/AbstractMetrics.py | 80 ++++++++++++++++++++++++++++++++++++
 metrics/PSNR.py            | 29 ++++++++++++++
 metrics/README.md          |  2 +
 metrics/SSIM.py            | 38 ++++++++++++++++++
 4 files changed, 149 insertions(+)
 create mode 100644 metrics/AbstractMetrics.py
 create mode 100644 metrics/PSNR.py
 create mode 100644 metrics/README.md
 create mode 100644 metrics/SSIM.py

diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py
new file mode 100644
index 0000000..54aeada
--- /dev/null
+++ b/metrics/AbstractMetrics.py
@@ -0,0 +1,80 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class AbstractMetrics(nn.Module, ABC):
+    """Abstract class for metrics"""
+
+    def __init__(self, _metric_name: str):
+
+        super(AbstractMetrics, self).__init__()
+
+        self.__metric_name = _metric_name
+
+        self.__train_metric_values = []
+        self.__val_metric_values = []
+
+    @property
+    def metric_name(self):
+        """Defines the metric name returned by the class."""
+        return self.__metric_name
+
+    @property
+    def train_metric_values(self):
"""Returns the training metric values.""" + return self.__train_metric_values + + @property + def val_metric_values(self): + """Returns the validation metric values.""" + return self.__val_metric_values + + @abstractmethod + def forward(self, + _generated_outputs: torch.tensor, + _targets: torch.tensor + ) -> torch.tensor: + """Computes the metric given information about the data.""" + pass + + def update(self, + _generated_outputs: torch.tensor, + _targets: torch.tensor, + validation: bool=False + ) -> None: + """Updates the metric with the new data.""" + if validation: + self.__val_metric_values.append(self.forward(_generated_outputs, _targets)) + else: + self.__train_metric_values.append(self.forward(_generated_outputs, _targets)) + + def reset(self): + """Resets the metric.""" + self.__train_metric_values = [] + self.__val_metric_values = [] + + def compute(self, aggregation: Optional[str] = 'mean'): + """Computes the final metric value.""" + + if aggregation == 'mean': + return \ + torch.mean(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \ + torch.mean(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None + + elif aggregation == 'sum': + return \ + torch.sum(torch.stack(self.__train_metric_values)) if len(self.__train_metric_values) > 0 else None , \ + torch.sum(torch.stack(self.__val_metric_values)) if len(self.__val_metric_values) > 0 else None + + elif aggregation is None: + return \ + torch.stack(self.__train_metric_values) if len(self.__train_metric_values) > 0 else None , \ + torch.stack(self.__val_metric_values) if len(self.__val_metric_values) > 0 else None + + else: + raise ValueError(f"Aggregation method {aggregation} is not supported.") \ No newline at end of file diff --git a/metrics/PSNR.py b/metrics/PSNR.py new file mode 100644 index 0000000..2f39a3a --- /dev/null +++ b/metrics/PSNR.py @@ -0,0 +1,29 @@ +import torch +from .AbstractMetrics import AbstractMetrics + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class PSNR(AbstractMetrics): + """Computes and tracks the Peak Signal-to-Noise Ratio (PSNR).""" + + def __init__(self, _metric_name: str, _max_pixel_value: int = 1): + + super(PSNR, self).__init__(_metric_name) + + self.__max_pixel_value = _max_pixel_value + + def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor): + + mse = torch.mean((_generated_outputs - _targets) ** 2, dim=[2, 3]) + psnr = torch.where( + mse == 0, + torch.tensor(0.0), + 10 * torch.log10((self.__max_pixel_value**2) / mse), + ) + + return psnr.mean() + + @property + def metric_name(self): + return self.__metric_name \ No newline at end of file diff --git a/metrics/README.md b/metrics/README.md new file mode 100644 index 0000000..ac8b458 --- /dev/null +++ b/metrics/README.md @@ -0,0 +1,2 @@ +Here lives the metric classes which is dependent on a abstract metric class +Each metric needs to have a foward function implemented over target and predict while the abstract class functions inhertied handles accumulation \ No newline at end of file diff --git a/metrics/SSIM.py b/metrics/SSIM.py new file mode 100644 index 0000000..6ba630f --- /dev/null +++ b/metrics/SSIM.py @@ -0,0 +1,38 @@ +import torch +from .AbstractMetrics import AbstractMetrics + +""" +Adapted from https://github.com/WayScience/nuclear_speckles_analysis +""" +class SSIM(AbstractMetrics): + """Computes and tracks the Structural Similarity Index Measure (SSIM).""" + + def __init__(self, _metric_name: str, 
_max_pixel_value: int = 1): + + super(SSIM, self).__init__(_metric_name) + + self.__max_pixel_value = _max_pixel_value + + def forward(self, _generated_outputs: torch.Tensor, _targets: torch.Tensor): + + mu1 = _generated_outputs.mean(dim=[2, 3], keepdim=True) + mu2 = _targets.mean(dim=[2, 3], keepdim=True) + + sigma1_sq = ((_generated_outputs - mu1) ** 2).mean(dim=[2, 3], keepdim=True) + sigma2_sq = ((_targets - mu2) ** 2).mean(dim=[2, 3], keepdim=True) + sigma12 = ((_generated_outputs - mu1) * (_targets - mu2)).mean( + dim=[2, 3], keepdim=True + ) + + c1 = (self.__max_pixel_value * 0.01) ** 2 + c2 = (self.__max_pixel_value * 0.03) ** 2 + + ssim_value = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2) + ) + + return ssim_value.mean() + + @property + def metric_name(self): + return self.__metric_name \ No newline at end of file From 3e921b955ac2140c688dd5dfe1250ca2269fd980 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:26:53 -0700 Subject: [PATCH 09/89] Added callback files --- callbacks/AbstractCallback.py | 40 ++++++++++++++++++++ callbacks/IntermediatePlot.py | 54 ++++++++++++++++++++++++++ callbacks/MlflowLogger.py | 71 +++++++++++++++++++++++++++++++++++ callbacks/README.md | 3 ++ 4 files changed, 168 insertions(+) create mode 100644 callbacks/AbstractCallback.py create mode 100644 callbacks/IntermediatePlot.py create mode 100644 callbacks/MlflowLogger.py create mode 100644 callbacks/README.md diff --git a/callbacks/AbstractCallback.py b/callbacks/AbstractCallback.py new file mode 100644 index 0000000..c1462a4 --- /dev/null +++ b/callbacks/AbstractCallback.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import List, Callable, Dict + +import torch + +class AbstractCallback(ABC): + """ + Abstract class for callbacks in the training process. + Callbacks can be used to plot intermediate metrics, log contents, save checkpoints, etc. + """ + + def __init__(self, name: str): + """ + :param name: Name of the callback. + """ + self._name = name + + def on_train_start(self, trainer): + """ + Called at the start of training. + """ + pass + + def on_epoch_start(self, trainer): + """ + Called at the start of each epoch. + """ + pass + + def on_epoch_end(self, trainer): + """ + Called at the end of each epoch. + """ + pass + + def on_train_end(self, trainer): + """ + Called at the end of training. + """ + pass \ No newline at end of file diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py new file mode 100644 index 0000000..30c084f --- /dev/null +++ b/callbacks/IntermediatePlot.py @@ -0,0 +1,54 @@ +from typing import List, Union + +import torch +import torch.nn as nn + +from .AbstractCallback import AbstractCallback +from ..datasets.PatchDataset import PatchDataset +from ..datasets.PatchDataset2 import PatchDataset as PDS2 + +from ..evaluation.visualization_utils import plot_patches + +class IntermediatePatchPlot(AbstractCallback): + """ + Callback to save the model weights at the end of each epoch. + """ + + def __init__(self, + name: str, + path: str, + dataset: PatchDataset | PDS2, + plot_n_patches: int=5, + plot_metrics: List[nn.Module]=None, + **kwargs): + """ + :param name: Name of the callback. + :param path: Path to save the model weights. 
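+        :param dataset: Patch dataset to draw input/target patches from.
+        :param plot_n_patches: Number of patches to plot at the end of each epoch.
+        :param plot_metrics: Optional list of metric modules reported alongside the plots.
+        :param kwargs: Additional keyword arguments passed through to plot_patches.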
+ """ + super().__init__(name) + self._path = path + if not isinstance(dataset, Union[PatchDataset, PDS2]): + raise TypeError(f"Expected PatchDataset, got {type(dataset)}") + self._dataset = dataset + + # Additional kwargs passed to plot_patches + self.plot_n_patches = plot_n_patches + self.plot_metrics = plot_metrics + self.plot_kwargs = kwargs + + def on_epoch_end(self, trainer): + """ + Plot dataset with model predictions at the end of each epoch. + """ + + original_device = next(trainer.model.parameters()).device + + plot_patches( + _dataset = self._dataset, + _n_patches = self.plot_n_patches, + _model = trainer.model, + _metrics = self.plot_metrics, + save_path = f"{self._path}/epoch_{trainer.epoch}.png", + device=original_device, + **self.plot_kwargs + ) \ No newline at end of file diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py new file mode 100644 index 0000000..e4675e7 --- /dev/null +++ b/callbacks/MlflowLogger.py @@ -0,0 +1,71 @@ +import os +import pathlib +import tempfile + +import mlflow +import torch + +from .AbstractCallback import AbstractCallback + +class MlflowLogger(AbstractCallback): + """ + Callback to log metrics to MLflow. + """ + + def __init__(self, + name: str, + artifact_name: str = 'best_model_weights.pth', + mlflow_uri: pathlib.Path | str = 'mlruns', + mlflow_experiment_name: str = 'default_experiment', + mlflow_start_run_args: dict = {}, + mlflow_log_params_args: dict = {}, + + ): + """ + :param name: Name of the callback. + """ + super().__init__(name) + + try: + mlflow.set_tracking_uri(mlflow_uri) + mlflow.set_experiment(mlflow_experiment_name) + except Exception as e: + print(f"Error setting MLflow tracking URI: {e}") + + self._artifact_name = artifact_name + self._mlflow_start_run_args = mlflow_start_run_args + self._mlflow_log_params_args = mlflow_log_params_args + + def on_train_start(self, trainer): + """ + Called at the start of training. + """ + mlflow.start_run( + **self._mlflow_start_run_args + ) + mlflow.log_params( + self._mlflow_log_params_args + ) + + def on_epoch_end(self, trainer): + """ + Called at the end of each epoch. + """ + for key, values in trainer.log.items(): + if values is not None and len(values) > 0: + value = values[-1] + else: + value = None + mlflow.log_metric(key, value, step=trainer.epoch) + + def on_train_end(self, trainer): + """ + Called at the end of training. + """ + # Save weights to a temporary directory and log artifacts + with tempfile.TemporaryDirectory() as tmpdirname: + weights_path = os.path.join(tmpdirname, self._artifact_name) + torch.save(trainer.best_model, weights_path) + mlflow.log_artifact(weights_path, artifact_path="models") + + mlflow.end_run() \ No newline at end of file diff --git a/callbacks/README.md b/callbacks/README.md new file mode 100644 index 0000000..d776cf7 --- /dev/null +++ b/callbacks/README.md @@ -0,0 +1,3 @@ +Here lives the callback classes that are meant to be fed into trainers to do stuff like saving images every epoch and logging. + +The callback classes must inherit the abstract class. 
\ No newline at end of file

From 8af1475d32d4a2367f422d3bfbdd8a46cf0cc413 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:28:04 -0700
Subject: [PATCH 10/89] Added transform files

---
 transforms/MinMaxNormalize.py     | 34 ++++++++++++++
 transforms/PixelDepthTransform.py | 70 +++++++++++++++++++++++++++++
 transforms/README.md              |  3 ++
 transforms/ZScoreNormalize.py     | 74 +++++++++++++++++++++++++++++++
 4 files changed, 181 insertions(+)
 create mode 100644 transforms/MinMaxNormalize.py
 create mode 100644 transforms/PixelDepthTransform.py
 create mode 100644 transforms/README.md
 create mode 100644 transforms/ZScoreNormalize.py

diff --git a/transforms/MinMaxNormalize.py b/transforms/MinMaxNormalize.py
new file mode 100644
index 0000000..c081bf6
--- /dev/null
+++ b/transforms/MinMaxNormalize.py
@@ -0,0 +1,34 @@
+from albumentations import ImageOnlyTransform
+import numpy as np
+
+"""
+Adapted from https://github.com/WayScience/nuclear_speckles_analysis
+"""
+class MinMaxNormalize(ImageOnlyTransform):
+    """Min-Max normalize each image"""
+
+    def __init__(self,
+                 _normalization_factor: float,
+                 _always_apply: bool=False,
+                 _p: float=0.5):
+        super(MinMaxNormalize, self).__init__(_always_apply, _p)
+        self.__normalization_factor = _normalization_factor
+
+    @property
+    def normalization_factor(self):
+        return self.__normalization_factor
+
+    def apply(self, _img, **kwargs):
+
+        if isinstance(_img, np.ndarray):
+            return _img / self.normalization_factor
+
+        else:
+            raise TypeError("Unsupported image type for transform (Should be a numpy array)")
+
+    def invert(self, _img, **kwargs):
+
+        if isinstance(_img, np.ndarray):
+            return _img * self.normalization_factor
+        else:
+            raise TypeError("Unsupported image type for transform (Should be a numpy array)")
diff --git a/transforms/PixelDepthTransform.py b/transforms/PixelDepthTransform.py
new file mode 100644
index 0000000..6142352
--- /dev/null
+++ b/transforms/PixelDepthTransform.py
@@ -0,0 +1,70 @@
+from albumentations import ImageOnlyTransform
+import numpy as np
+
+class PixelDepthTransform(ImageOnlyTransform):
+    """
+    Transform to convert images from a specified bit depth to another bit depth (e.g., 16-bit to 8-bit).
+    Automatically scales pixel values up or down to the target bit depth.
+    The only supported bit depths are 8, 16, and 32.
+    """
+
+    def __init__(self,
+                 src_bit_depth: int = 16,
+                 target_bit_depth: int = 8,
+                 _always_apply: bool = True,
+                 _p: float = 1.0):
+        """
+        :param src_bit_depth: Bit depth of the input image (e.g., 16 for 16-bit).
+        :param target_bit_depth: Bit depth to scale the image to (e.g., 8 for 8-bit).
+        :param _always_apply: Whether to always apply the transform.
+        :param _p: Probability of applying the transform.
+        """
+        if src_bit_depth not in [8, 16, 32]:
+            raise ValueError("Unsupported source bit depth (should be 8, 16, or 32)")
+        if target_bit_depth not in [8, 16, 32]:
+            raise ValueError("Unsupported target bit depth (should be 8, 16, or 32)")
+
+        super(PixelDepthTransform, self).__init__(_always_apply, _p)
+        self.src_bit_depth = src_bit_depth
+        self.target_bit_depth = target_bit_depth
+
+    def apply(self, img, **kwargs):
+        """
+        Apply the bit depth transformation.
+        :param img: Input image as a numpy array.
+        :return: Transformed image scaled to the target bit depth.
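+        :raises TypeError: If the input is not a numpy array.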
+ """ + if not isinstance(img, np.ndarray): + raise TypeError("Unsupported image type for transform (should be a numpy array)") + + # Maximum pixel value based on source and target bit depth + src_max_val = (2 ** self.src_bit_depth) - 1 + target_max_val = (2 ** self.target_bit_depth) - 1 + + if self.target_bit_depth == 32: + # Scale to the 32-bit integer range + return ((img / src_max_val) * target_max_val).astype(np.uint32) + else: + # Standard conversion for 8-bit or 16-bit integers + return ((img / src_max_val) * target_max_val).astype( + np.uint8 if self.target_bit_depth == 8 else np.uint16 + ) + + def invert(self, img, **kwargs): + """ + Optionally invert the bit depth transformation (useful for debugging or preprocessing). + :param img: Transformed image as a numpy array. + :return: Image restored to the original bit depth. + """ + if not isinstance(img, np.ndarray): + raise TypeError("Unsupported image type for inversion (should be a numpy array)") + + target_max_val = (2 ** self.target_bit_depth) - 1 + src_max_val = (2 ** self.src_bit_depth) - 1 + + # Invert scaling back to original bit depth + img = (img / target_max_val) * src_max_val + return img.astype(np.uint16) if self.src_bit_depth == 16 else img + + def __repr__(self): + return f"PixelDepthTransform(src_bit_depth={self.src_bit_depth}, target_bit_depth={self.target_bit_depth})" \ No newline at end of file diff --git a/transforms/README.md b/transforms/README.md new file mode 100644 index 0000000..d3a5f1c --- /dev/null +++ b/transforms/README.md @@ -0,0 +1,3 @@ +Here lives the image transform class. + +For now all of these transforms are just used as input normalization. They all have an invert function to faciliate visualization. \ No newline at end of file diff --git a/transforms/ZScoreNormalize.py b/transforms/ZScoreNormalize.py new file mode 100644 index 0000000..1b588d8 --- /dev/null +++ b/transforms/ZScoreNormalize.py @@ -0,0 +1,74 @@ +from albumentations import ImageOnlyTransform +import numpy as np + +""" +Wrote this to get z score normalizae to work with albumentations +""" +class ZScoreNormalize(ImageOnlyTransform): + """Z-score normalize each image""" + + def __init__(self, _mean=None, _std=None, _always_apply=False, _p=0.5): + """ + Args: + _mean (float): Precomputed mean for normalization (optional). If None, compute per-image mean. + _std (float): Precomputed standard deviation for normalization (optional). If None, compute per-image std. + _always_apply (bool): If True, always apply this transformation. + _p (float): Probability of applying this transformation. + """ + super(ZScoreNormalize, self).__init__(_always_apply, _p) + self.__mean = _mean + self.__std = _std + + @property + def mean(self): + return self.__mean + + @property + def std(self): + return self.__std + + def apply(self, _img, **params): + """ + Apply z-score normalization to the image. + Args: + _img (np.ndarray): Input image as a numpy array. + Returns: + np.ndarray: Z-score normalized image. + """ + if not isinstance(_img, np.ndarray): + raise TypeError("Unsupported image type for transform (Should be a numpy array)") + + mean = self.__mean if self.__mean is not None else _img.mean() + std = self.__std if self.__std is not None else _img.std() + + if std == 0: + raise ValueError("Standard deviation is zero; cannot perform z-score normalization.") + + return (_img - mean) / std + + def invert(self, _img, **kwargs): + """ + Invert the z-score normalization. + Args: + _img (np.ndarray): Input image as a numpy array. 
+ Returns: + np.ndarray: Inverted image. + """ + if not isinstance(_img, np.ndarray): + raise TypeError("Unsupported image type for transform (Should be a numpy array)") + + if self.__mean is None or self.__std is None: + mean = kwargs.get("mean", None) + std = kwargs.get("std", None) + if mean is None or std is None: + return _img + else: + return (_img * std) + mean + else: + mean = self.__mean if self.__mean is not None else _img.mean() + std = self.__std if self.__std is not None else _img.std() + + if std == 0: + raise ValueError("Standard deviation is zero; cannot perform z-score normalization.") + + return (_img * std) + mean From 744efcc90de6e2844a48e889f05291f19d0bc412 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 14 Feb 2025 14:30:36 -0700 Subject: [PATCH 11/89] Added evaluation files --- evaluation/README.md | 3 + evaluation/evaluation_utils.py | 28 ++++ evaluation/visualization_utils.py | 216 ++++++++++++++++++++++++++++++ 3 files changed, 247 insertions(+) create mode 100644 evaluation/README.md create mode 100644 evaluation/evaluation_utils.py create mode 100644 evaluation/visualization_utils.py diff --git a/evaluation/README.md b/evaluation/README.md new file mode 100644 index 0000000..a698a2a --- /dev/null +++ b/evaluation/README.md @@ -0,0 +1,3 @@ +Here lives some collection of functions useful for evaluating model performance and plotting + +These code are pretty messy and will need major revisions. \ No newline at end of file diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py new file mode 100644 index 0000000..e83ae8b --- /dev/null +++ b/evaluation/evaluation_utils.py @@ -0,0 +1,28 @@ +from collections import defaultdict +from typing import List, Callable + +import pandas as pd +import torch +from torch.utils.data import DataLoader + +def evaluate_metrics( + _model: torch.nn.Module, + _dataset: torch.utils.data.Dataset, + _metrics: List[Callable|torch.nn.Module], + _device:str='cpu' +): + metrics = defaultdict(list) + _model.to(_device) + _model.eval() + + data_loader = DataLoader(_dataset, batch_size=1, shuffle=False) + + with torch.no_grad(): + for input, target in data_loader: + input = input.to(_device) + target = target.to(_device) + output = _model(input) + for _metric in _metrics: + metrics[_metric.__class__.__name__].append(_metric(output, target).item()) + + return pd.DataFrame(metrics) \ No newline at end of file diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py new file mode 100644 index 0000000..5029e59 --- /dev/null +++ b/evaluation/visualization_utils.py @@ -0,0 +1,216 @@ +import pathlib +from typing import Tuple, List +import random + +import numpy as np +import torch +from torch.utils.data import Dataset +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from PIL import Image +from albumentations import ImageOnlyTransform +from albumentations.core.composition import Compose + +def invert_transforms( + numpy_img: np.ndarray, + transforms: ImageOnlyTransform | Compose = None + ) -> np.ndarray: + + if isinstance(transforms, ImageOnlyTransform): + return transforms.invert(numpy_img) + elif isinstance(transforms, Compose): + for transform in reversed(transforms.transforms): + numpy_img = transform.invert(numpy_img) + elif transforms is None: + return numpy_img + else: + raise ValueError(f"Invalid transforms type: {type(transforms)}") + + return numpy_img + +def format_img( + _tensor_img: torch.Tensor, + cast_to_type: torch.dtype = None + ) -> np.ndarray: + + if 
diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py
new file mode 100644
index 0000000..5029e59
--- /dev/null
+++ b/evaluation/visualization_utils.py
@@ -0,0 +1,216 @@
+import pathlib
+from typing import Tuple, List
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+from PIL import Image
+from albumentations import ImageOnlyTransform
+from albumentations.core.composition import Compose
+
+def invert_transforms(
+    numpy_img: np.ndarray,
+    transforms: ImageOnlyTransform | Compose = None
+    ) -> np.ndarray:
+
+    if isinstance(transforms, ImageOnlyTransform):
+        return transforms.invert(numpy_img)
+    elif isinstance(transforms, Compose):
+        # Undo the transforms in reverse order of application
+        for transform in reversed(transforms.transforms):
+            numpy_img = transform.invert(numpy_img)
+    elif transforms is None:
+        return numpy_img
+    else:
+        raise ValueError(f"Invalid transforms type: {type(transforms)}")
+
+    return numpy_img
+
+def format_img(
+    _tensor_img: torch.Tensor,
+    cast_to_type: torch.dtype = None
+    ) -> np.ndarray:
+
+    if cast_to_type is not None:
+        _tensor_img = _tensor_img.to(cast_to_type)
+
+    img = torch.squeeze(_tensor_img).cpu().numpy()
+
+    return img
+
+def evaluate_and_format_imgs(
+    _input: torch.Tensor,
+    _target: torch.Tensor,
+    model=None,
+    _input_transform: ImageOnlyTransform | Compose = None,
+    _target_transform: ImageOnlyTransform | Compose = None,
+    device: str = 'cpu'
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+
+    input_transform = invert_transforms(
+        format_img(_input),
+        _input_transform
+    )
+    target_transform = invert_transforms(
+        format_img(_target),
+        _target_transform
+    )
+
+    if model is not None:
+        model.to(device)
+        model.eval()
+        with torch.no_grad():
+            # Forward Pass
+            output = model(_input.unsqueeze(1).to(device))
+
+        output_transform = invert_transforms(
+            format_img(output),
+            _target_transform
+        )
+    else:
+        output_transform = None
+
+    return input_transform, target_transform, output_transform
+
+def plot_patch(
+    _raw_img: np.ndarray,
+    _patch_size: int,
+    _patch_coords: Tuple[int, int],
+    _input: torch.Tensor,
+    _target: torch.Tensor,
+    _output: torch.Tensor = None,
+    axes: List = None,
+    **kwargs
+):
+    ## Plot keyword arguments
+    cmap = kwargs.get("cmap", "gray")
+    vmin = kwargs.get("vmin", None)
+    vmax = kwargs.get("vmax", None)
+    figsize = kwargs.get("figsize", None)
+    if figsize is None:
+        panel_width = kwargs.get("panel_width", 5)
+        # (width, height): one row of 3 panels without output, 4 with
+        figsize = (panel_width * (3 if _output is None else 4), panel_width)
+    else:
+        panel_width = None
+
+    if axes is None:
+        fig, ax = plt.subplots(1, 3 if _output is None else 4, figsize=figsize)
+    else:
+        ax = axes
+
+    # plot image
+    ax[0].imshow(_raw_img, cmap=cmap)
+    ax[0].set_title("Raw Image")
+    ax[0].axis("off")
+
+    rect = Rectangle(
+        _patch_coords,
+        _patch_size,
+        _patch_size,
+        linewidth=1,
+        edgecolor="r",
+        facecolor="none"
+    )
+
+    # Share the intensity scale between target and (optional) output panels
+    if vmin is None:
+        vmin = _target.min() if _output is None else min(_output.min(), _target.min())
+    if vmax is None:
+        vmax = _target.max() if _output is None else max(_output.max(), _target.max())
+
+    ax[0].add_patch(rect)
+
+    # plot input
+    ax[1].imshow(_input, cmap=cmap)
+    ax[1].set_title("Input")
+    ax[1].axis("off")
+
+    # plot target
+    ax[2].imshow(_target, cmap=cmap, vmin=vmin, vmax=vmax)
+    ax[2].set_title("Target")
+    ax[2].axis("off")
+
+    if _output is not None:
+        ax[3].imshow(_output, cmap=cmap, vmin=vmin, vmax=vmax)
+        ax[3].set_title("Output")
+        ax[3].axis("off")
+
+def plot_patches(
+    _dataset: Dataset,
+    _n_patches: int = 5,
+    _model: torch.nn.Module = None,
+    _patch_index: List[int] = None,
+    _random_seed: int = 42,
+    _metrics: List[torch.nn.Module] = None,
+    device: str = 'cpu',
+    **kwargs
+):
+    ## Plot keyword arguments
+    cmap = kwargs.get("cmap", "gray")
+    vmin = kwargs.get("vmin", None)
+    vmax = kwargs.get("vmax", None)
+    panel_width = kwargs.get("panel_width", 5)
+    save_path = kwargs.get("save_path", None)
+    show_plot = kwargs.get("show_plot", True)
+
+    ## Generate random patch indices to visualize
+    if _patch_index is None:
+        random.seed(_random_seed)
+        _patch_index = random.sample(range(len(_dataset)), _n_patches)
+    else:
+        _patch_index = [i for i in _patch_index if i < len(_dataset)]
+        _n_patches = len(_patch_index)
+
+    figsize = kwargs.get("figsize", None)
+    if figsize is None:
+        # (width, height): 3 or 4 panels per row, one row per patch
+        figsize = (panel_width * (3 if _model is None else 4), panel_width * _n_patches)
+    fig, axes = plt.subplots(_n_patches, 3 if _model is None else 4, figsize=figsize)
+
+    for i, row_axes in zip(_patch_index, axes):
+        _input, _target = _dataset[i]
+        _raw_image = np.array(Image.open(
+            _dataset.input_names[0]
+        ))
+        _input, _target, _output = evaluate_and_format_imgs(
+            _input,
+            _target,
+            _model,
+            device=device
+        )
+
+        plot_patch(
+            _raw_img=_raw_image,
+            _patch_size=_input.shape[-1],
+            _patch_coords=_dataset.patch_coords,
+            _input=_input,
+            _target=_target,
+            _output=_output,
+            axes=row_axes,
+            cmap=cmap,
+            vmin=vmin,
+            vmax=vmax
+        )
+
+        ## Compute metrics for a single target/output pair and add to the subplot title
+        metric_str = ""
+        if _metrics is not None and _output is not None:
+            for _metric in _metrics:
+                metric_val = _metric(
+                    torch.tensor(_output).unsqueeze(0).unsqueeze(0),
+                    torch.tensor(_target).unsqueeze(0).unsqueeze(0)
+                ).item()
+                metric_str = f"{metric_str}\n{_metric.__class__.__name__}: {metric_val:.2f}"
+            row_axes[-1].set_title(
+                row_axes[-1].get_title() + metric_str
+            )
+
+    plt.tight_layout()
+
+    # Save before showing; show() consumes the current figure
+    if save_path is not None:
+        plt.savefig(save_path)
+
+    if show_plot:
+        plt.show()
+
+    plt.close()
\ No newline at end of file
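
invert_transforms above assumes every transform it receives implements invert, the convention noted in transforms/README.md. A small sketch of that contract (ZScoreNormalize is just a convenient example from this patch set):

import numpy as np
from albumentations import Compose

from evaluation.visualization_utils import invert_transforms
from transforms.ZScoreNormalize import ZScoreNormalize

t = ZScoreNormalize(_mean=100.0, _std=25.0, _p=1.0)
img = np.linspace(0.0, 200.0, num=64, dtype=np.float32).reshape(8, 8)
z = t.apply(img)

# A bare transform and a Compose both invert back to the original image;
# for a Compose, inversion walks transforms.transforms in reverse order.
assert np.allclose(invert_transforms(z, t), img)
assert np.allclose(invert_transforms(z, Compose([t])), img)

Reversing the order matters once more than one transform is chained: the last transform applied must be the first one undone.
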
From 81f3ccd92ffdfc703e9a33a84edc62ea40911181 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:31:34 -0700
Subject: [PATCH 12/89] Made some modifications to callback

---
 callbacks/IntermediatePlot.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py
index 30c084f..56f6224 100644
--- a/callbacks/IntermediatePlot.py
+++ b/callbacks/IntermediatePlot.py
@@ -5,7 +5,6 @@
 from .AbstractCallback import AbstractCallback
 from ..datasets.PatchDataset import PatchDataset
-from ..datasets.PatchDataset2 import PatchDataset as PDS2
 from ..evaluation.visualization_utils import plot_patches
 
@@ -17,7 +16,7 @@ class IntermediatePatchPlot(AbstractCallback):
     def __init__(self,
                  name: str,
                  path: str,
-                 dataset: PatchDataset | PDS2,
+                 dataset: PatchDataset,
                  plot_n_patches: int=5,
                  plot_metrics: List[nn.Module]=None,
                  **kwargs):
@@ -27,7 +26,7 @@ def __init__(self,
         """
         super().__init__(name)
         self._path = path
-        if not isinstance(dataset, Union[PatchDataset, PDS2]):
+        if not isinstance(dataset, PatchDataset):
             raise TypeError(f"Expected PatchDataset, got {type(dataset)}")
         self._dataset = dataset

From c8a27ea3bbce4bab3e88b7810dcd4457c79f8f82 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:38:52 -0700
Subject: [PATCH 13/89] Added gitignore

---
 .gitignore | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b293968
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+# images and anything under mlflow
+*.png
+examples/example_train/*
+
+# pycache
+*.pyc
\ No newline at end of file

From b7c0df2117dcb4826c0379f3ccf4bde45b44e387 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Fri, 14 Feb 2025 14:39:03 -0700
Subject: [PATCH 14/89] Added notebook that is a minimal example

---
 examples/minimal_example.ipynb | 733 +++++++++++++++++++++++++++++++++
 1 file changed, 733 insertions(+)
 create mode 100644 examples/minimal_example.ipynb

diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb
new file mode 100644
index 0000000..4cda17b
--- /dev/null
+++ b/examples/minimal_example.ipynb
@@ -0,0 +1,733 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# A minimal example demonstrating how the trainers for FNet and wGaN GP work with the callbacks and the patch dataset\n",
+    "\n",
+    "Depends on the files produced by 1.illumination_correction/0.create_loaddata_csvs in the ALSF pilot data repo: https://github.com/WayScience/pediatric_cancer_atlas_profiling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs":
[ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/weishanli/Waylab\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/weishanli/anaconda3/envs/speckle_analysis/lib/python3.11/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.4' (you have '2.0.1'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", + " check_for_updates()\n" + ] + } + ], + "source": [ + "import sys\n", + "import pathlib\n", + "\n", + "import pandas as pd\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import mlflow\n", + "\n", + "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n", + "print(str(pathlib.Path('.').absolute().parent.parent))\n", + "\n", + "## Dataset\n", + "from virtual_stain_flow.datasets.PatchDataset import PatchDataset\n", + "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n", + "\n", + "## FNet training\n", + "from virtual_stain_flow.models.fnet import FNet\n", + "from virtual_stain_flow.trainers.Trainer import Trainer\n", + "\n", + "## wGaN training\n", + "from virtual_stain_flow.models.unet import UNet\n", + "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n", + "from virtual_stain_flow.trainers.WGaNTrainer import WGaNTrainer\n", + "\n", + "## wGaN losses\n", + "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n", + "from virtual_stain_flow.losses.DiscriminatorLoss import DiscriminatorLoss\n", + "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n", + "\n", + "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n", + "\n", + "## Metrics\n", + "from virtual_stain_flow.metrics.PSNR import PSNR\n", + "from virtual_stain_flow.metrics.SSIM import SSIM\n", + "\n", + "## callback\n", + "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n", + "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify train output paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train'\n", + "EXAMPLE_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf example_train/*" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "PLOT_DIR = EXAMPLE_DIR / 'plot'\n", + "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "MLFLOW_DIR =EXAMPLE_DIR / 'mlflow'\n", + "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify paths to loaddata and read a single" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "## REPLACE WITH YOUR OWN PATHS\n", + "analysis_home_path = pathlib.Path('/home/weishanli/Waylab/ALSF_pilot/ALSF_img2img_prototyping')\n", + "sc_features_parquet_path = pathlib.Path(\n", + " '/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pilot_data/preprocessed_profiles_SN0313537/single_cell_profiles'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", 
+ "text": [ + " FileName_OrigBrightfield \\\n", + "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch1sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigBrightfield \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigER \\\n", + "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigER \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigAGP \\\n", + "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigAGP \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigMito \\\n", + "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigMito \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigDNA \\\n", + "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigDNA ... \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... 
\n", + "\n", + " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n", + "2079 0.134358 6 22 1 \n", + "668 0.134405 6 9 3 \n", + "2073 0.134366 6 22 4 \n", + "1113 0.134347 6 13 7 \n", + "788 0.134381 6 10 6 \n", + "\n", + " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n", + "2079 1 0.000000 0.000000 \n", + "668 1 0.000000 0.000646 \n", + "2073 1 0.000646 0.000646 \n", + "1113 1 -0.000646 -0.000646 \n", + "788 1 -0.000646 0.000000 \n", + "\n", + " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n", + "2079 -0.000006 6 False \n", + "668 -0.000006 5 False \n", + "2073 -0.000006 5 False \n", + "1113 -0.000006 6 False \n", + "788 -0.000006 6 False \n", + "\n", + "[5 rows x 25 columns]\n", + " Metadata_Plate Metadata_Well Metadata_Site \\\n", + "0 BR00143976 C03 2 \n", + "1 BR00143976 C03 6 \n", + "2 BR00143976 C03 9 \n", + "3 BR00143976 C03 5 \n", + "4 BR00143976 C03 7 \n", + "\n", + " Metadata_Cells_Location_Center_X Metadata_Cells_Location_Center_Y \n", + "0 629.552987 62.017799 \n", + "1 279.951864 56.588228 \n", + "2 876.508878 205.794360 \n", + "3 479.254866 45.496581 \n", + "4 866.557068 205.908787 \n" + ] + } + ], + "source": [ + "loaddata_csv_path = analysis_home_path \\\n", + " / '0.data_analysis_and_preprocessing' / 'loaddata_csvs'\n", + "\n", + "if loaddata_csv_path.exists():\n", + " try:\n", + " loaddata_csv = next(loaddata_csv_path.glob('*.csv'))\n", + " except:\n", + " raise FileNotFoundError(\"No loaddata csv found\")\n", + "else:\n", + " raise ValueError(\"Incorrect loaddata csv path\")\n", + "\n", + "loaddata_df = pd.read_csv(loaddata_csv)\n", + "# subsample to reduce runtime\n", + "loaddata_df = loaddata_df.sample(n=100, random_state=42)\n", + "\n", + "sc_features = pd.DataFrame()\n", + "for plate in loaddata_df['Metadata_Plate'].unique():\n", + " sc_features_parquet = sc_features_parquet_path / f'{plate}_sc_normalized.parquet'\n", + " if not sc_features_parquet.exists():\n", + " print(f'{sc_features_parquet} does not exist, skipping...')\n", + " continue \n", + " else:\n", + " sc_features = pd.concat([\n", + " sc_features, \n", + " pd.read_parquet(\n", + " sc_features_parquet,\n", + " columns=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site', 'Metadata_Cells_Location_Center_X', 'Metadata_Cells_Location_Center_Y']\n", + " )\n", + " ])\n", + "\n", + "print(loaddata_df.head())\n", + "print(sc_features.head())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Patch size and channels" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "PATCH_SIZE = 256\n", + "\n", + "channel_names = [\n", + " \"OrigBrightfield\",\n", + " \"OrigDNA\",\n", + " \"OrigER\",\n", + " \"OrigMito\",\n", + " \"OrigRNA\",\n", + " \"OrigAGP\",\n", + "]\n", + "input_channel_name = \"OrigBrightfield\"\n", + "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep Patch dataset and Cache" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-14 14:33:43,785 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-02-14 14:33:43,785 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-14 14:33:43,785 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature 
dataframe, using as the coordinates for cell centers\n", + "2025-02-14 14:33:43,786 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-02-14 14:33:43,786 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Plate', 'Metadata_Well']\n", + "2025-02-14 14:33:43,786 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-14 14:33:43,808 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-02-14 14:33:43,809 - DEBUG - Channel keys: {'OrigRNA', 'OrigMito', 'OrigER', 'OrigBrightfield', 'OrigAGP', 'OrigDNA'} inferred from loaddata csv\n", + "2025-02-14 14:33:43,809 - DEBUG - Setting input channel(s) ...\n", + "2025-02-14 14:33:43,809 - DEBUG - No channel keys specified, skip\n", + "2025-02-14 14:33:43,809 - DEBUG - Setting target channel(s) ...\n", + "2025-02-14 14:33:43,809 - DEBUG - No channel keys specified, skip\n", + "2025-02-14 14:33:43,809 - DEBUG - Setting input transform ...\n", + "2025-02-14 14:33:43,810 - DEBUG - Setting target transform ...\n", + "2025-02-14 14:33:43,810 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", + "2025-02-14 14:33:43,832 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", + "2025-02-14 14:33:43,832 - DEBUG - Generating patches that contain cells\n", + "2025-02-14 14:33:43,857 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-02-14 14:33:44,237 - DEBUG - Generated 461 patches for 93 site/view\n", + "2025-02-14 14:33:44,238 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-14 14:33:44,238 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + ] + } + ], + "source": [ + "pds = PatchDataset(\n", + " _loaddata_csv=loaddata_df,\n", + " _sc_feature=sc_features,\n", + " _input_channel_keys=None,\n", + " _target_channel_keys=None,\n", + " _input_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n", + " _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),\n", + " patch_size=PATCH_SIZE,\n", + " verbose=True,\n", + " patch_generation_method=\"random_cell\",\n", + " patch_generation_random_seed=42\n", + ")\n", + "\n", + "## Set input and target channels\n", + "pds.set_input_channel_keys([input_channel_name])\n", + "pds.set_target_channel_keys('OrigDNA')\n", + "\n", + "## Cache for faster training \n", + "cds = CachedDataset(\n", + " pds,\n", + " prefill_cache=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FNet trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model without callback and check logs" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": 
{}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochL1Lossval_L1Losspsnrssimval_psnrval_ssim
010.4788380.4644696.3280120.0198896.5855280.026380
120.4091230.4540157.6803250.0338956.7807150.026846
230.3583650.4258668.7902060.0423797.3293780.028593
340.3104460.37652210.0156050.0475748.3844400.031521
450.2746480.32629111.0842160.0509699.6142840.035671
560.2355580.28046312.3493740.06554710.9243710.040533
670.2075860.22384013.4213410.07273312.8221320.047623
780.1820400.19801514.3961230.07936013.8689110.055814
890.1547120.15825015.8210450.09873615.7265260.069852
9100.1382530.14088116.7469220.10889516.7409480.083398
\n", + "
" + ], + "text/plain": [ + " epoch L1Loss val_L1Loss psnr ssim val_psnr val_ssim\n", + "0 1 0.478838 0.464469 6.328012 0.019889 6.585528 0.026380\n", + "1 2 0.409123 0.454015 7.680325 0.033895 6.780715 0.026846\n", + "2 3 0.358365 0.425866 8.790206 0.042379 7.329378 0.028593\n", + "3 4 0.310446 0.376522 10.015605 0.047574 8.384440 0.031521\n", + "4 5 0.274648 0.326291 11.084216 0.050969 9.614284 0.035671\n", + "5 6 0.235558 0.280463 12.349374 0.065547 10.924371 0.040533\n", + "6 7 0.207586 0.223840 13.421341 0.072733 12.822132 0.047623\n", + "7 8 0.182040 0.198015 14.396123 0.079360 13.868911 0.055814\n", + "8 9 0.154712 0.158250 15.821045 0.098736 15.726526 0.069852\n", + "9 10 0.138253 0.140881 16.746922 0.108895 16.740948 0.083398" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(trainer.log)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train with mlflow logger callbacks" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'lr': 3e-4\n", + " },\n", + " )\n", + "\n", + "del trainer\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# wGaN GP example with mlflow logger callback and plot callback" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = DiscriminatorLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "plot_callback = IntermediatePatchPlot(\n", + " name='plotter',\n", + " path=PLOT_DIR,\n", + " dataset=pds, # give it the patch dataset as opposed to the cached dataset\n", + " plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],\n", + 
" figsize=(20, 25),\n", + " show_plot=False,\n", + ")\n", + "\n", + "wgan_trainer = WGaNTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=20,\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback, plot_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr')},\n", + ")\n", + "\n", + "wgan_trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speckle_analysis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From a311fd750af32a239c1a99955c0f0f6c9b4df8da Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sun, 16 Feb 2025 18:41:03 -0700 Subject: [PATCH 15/89] Modified the way trainers are accessed by callbacks, instead of trainer passing itself into each callback each invocation the trainers and initialized as internal attribute of each callback during trainer initialization. --- callbacks/AbstractCallback.py | 33 +++- callbacks/IntermediatePlot.py | 23 ++- callbacks/MlflowLogger.py | 28 ++- examples/minimal_example.ipynb | 339 ++++++++++++++++++++++++--------- trainers/AbstractTrainer.py | 32 +++- 5 files changed, 340 insertions(+), 115 deletions(-) diff --git a/callbacks/AbstractCallback.py b/callbacks/AbstractCallback.py index c1462a4..26d0b12 100644 --- a/callbacks/AbstractCallback.py +++ b/callbacks/AbstractCallback.py @@ -14,26 +14,51 @@ def __init__(self, name: str): :param name: Name of the callback. """ self._name = name + self._trainer = None + + @property + def name(self): + """ + Getter for callback name + """ + return self._name + + @property + def trainer(self): + """ + Allows for access of trainer + """ + return self._trainer + + def _set_trainer(self, trainer): + """ + Helper function called by trainer class to initialize trainer value field + + :param trainer: trainer object + :type trainer: AbstractTrainer or subclass + """ + + self._trainer = trainer - def on_train_start(self, trainer): + def on_train_start(self): """ Called at the start of training. """ pass - def on_epoch_start(self, trainer): + def on_epoch_start(self): """ Called at the start of each epoch. """ pass - def on_epoch_end(self, trainer): + def on_epoch_end(self): """ Called at the end of each epoch. """ pass - def on_train_end(self, trainer): + def on_train_end(self): """ Called at the end of training. """ diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py index 56f6224..92f2df5 100644 --- a/callbacks/IntermediatePlot.py +++ b/callbacks/IntermediatePlot.py @@ -13,7 +13,7 @@ class IntermediatePatchPlot(AbstractCallback): Callback to save the model weights at the end of each epoch. """ - def __init__(self, + def __init__(self, name: str, path: str, dataset: PatchDataset, @@ -21,8 +21,21 @@ def __init__(self, plot_metrics: List[nn.Module]=None, **kwargs): """ + Initialize the IntermediatePlot callback. 
+ :param name: Name of the callback. + :type name: str :param path: Path to save the model weights. + :type path: str + :param dataset: Dataset to be used for plotting intermediate results. + :type dataset: PatchDataset + :param plot_n_patches: Number of patches to plot, defaults to 5. + :type plot_n_patches: int, optional + :param plot_metrics: List of metrics to compute and display in plot title, defaults to None. + :type plot_metrics: List[nn.Module], optional + :param kwargs: Additional keyword arguments to be passed to plot_patches. + :type kwargs: dict + :raises TypeError: If the dataset is not an instance of PatchDataset. """ super().__init__(name) self._path = path @@ -35,19 +48,19 @@ def __init__(self, self.plot_metrics = plot_metrics self.plot_kwargs = kwargs - def on_epoch_end(self, trainer): + def on_epoch_end(self): """ Plot dataset with model predictions at the end of each epoch. """ - original_device = next(trainer.model.parameters()).device + original_device = next(self.trainer.model.parameters()).device plot_patches( _dataset = self._dataset, _n_patches = self.plot_n_patches, - _model = trainer.model, + _model = self.trainer.model, _metrics = self.plot_metrics, - save_path = f"{self._path}/epoch_{trainer.epoch}.png", + save_path = f"{self._path}/epoch_{self.trainer.epoch}.png", device=original_device, **self.plot_kwargs ) \ No newline at end of file diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py index e4675e7..bac3671 100644 --- a/callbacks/MlflowLogger.py +++ b/callbacks/MlflowLogger.py @@ -13,16 +13,30 @@ class MlflowLogger(AbstractCallback): """ def __init__(self, + name: str, artifact_name: str = 'best_model_weights.pth', mlflow_uri: pathlib.Path | str = 'mlruns', - mlflow_experiment_name: str = 'default_experiment', + mlflow_experiment_name: str = 'Default', mlflow_start_run_args: dict = {}, mlflow_log_params_args: dict = {}, ): """ + Initialize the MlflowLogger callback. + :param name: Name of the callback. + :type name: str + :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'. + :type artifact_name: str, optional + :param mlflow_uri: URI for the MLflow tracking server, defaults to 'mlruns' under current wd. + :type mlflow_uri: pathlib.Path or str, optional + :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'. + :type mlflow_experiment_name: str, optional + :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to {}. + :type mlflow_start_run_args: dict, optional + :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to {}. + :type mlflow_log_params_args: dict, optional """ super().__init__(name) @@ -36,7 +50,7 @@ def __init__(self, self._mlflow_start_run_args = mlflow_start_run_args self._mlflow_log_params_args = mlflow_log_params_args - def on_train_start(self, trainer): + def on_train_start(self): """ Called at the start of training. """ @@ -47,25 +61,25 @@ def on_train_start(self, trainer): self._mlflow_log_params_args ) - def on_epoch_end(self, trainer): + def on_epoch_end(self): """ Called at the end of each epoch. """ - for key, values in trainer.log.items(): + for key, values in self.trainer.log.items(): if values is not None and len(values) > 0: value = values[-1] else: value = None - mlflow.log_metric(key, value, step=trainer.epoch) + mlflow.log_metric(key, value, step=self.trainer.epoch) - def on_train_end(self, trainer): + def on_train_end(self): """ Called at the end of training. 
""" # Save weights to a temporary directory and log artifacts with tempfile.TemporaryDirectory() as tmpdirname: weights_path = os.path.join(tmpdirname, self._artifact_name) - torch.save(trainer.best_model, weights_path) + torch.save(self.trainer.best_model, weights_path) mlflow.log_artifact(weights_path, artifact_path="models") mlflow.end_run() \ No newline at end of file diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb index 4cda17b..3eccca4 100644 --- a/examples/minimal_example.ipynb +++ b/examples/minimal_example.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -37,7 +37,6 @@ "import pandas as pd\n", "import torch.nn as nn\n", "import torch.optim as optim\n", - "import mlflow\n", "\n", "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n", "print(str(pathlib.Path('.').absolute().parent.parent))\n", @@ -325,27 +324,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-02-14 14:33:43,785 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", - "2025-02-14 14:33:43,785 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-14 14:33:43,785 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", - "2025-02-14 14:33:43,786 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", - "2025-02-14 14:33:43,786 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Plate', 'Metadata_Well']\n", - "2025-02-14 14:33:43,786 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-14 14:33:43,808 - DEBUG - Inferring channel keys from loaddata csv\n", - "2025-02-14 14:33:43,809 - DEBUG - Channel keys: {'OrigRNA', 'OrigMito', 'OrigER', 'OrigBrightfield', 'OrigAGP', 'OrigDNA'} inferred from loaddata csv\n", - "2025-02-14 14:33:43,809 - DEBUG - Setting input channel(s) ...\n", - "2025-02-14 14:33:43,809 - DEBUG - No channel keys specified, skip\n", - "2025-02-14 14:33:43,809 - DEBUG - Setting target channel(s) ...\n", - "2025-02-14 14:33:43,809 - DEBUG - No channel keys specified, skip\n", - "2025-02-14 14:33:43,809 - DEBUG - Setting input transform ...\n", - "2025-02-14 14:33:43,810 - DEBUG - Setting target transform ...\n", - "2025-02-14 14:33:43,810 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", - "2025-02-14 14:33:43,832 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", - "2025-02-14 14:33:43,832 - DEBUG - Generating patches that contain cells\n", - "2025-02-14 14:33:43,857 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", - "2025-02-14 14:33:44,237 - DEBUG - Generated 461 patches for 93 site/view\n", - "2025-02-14 14:33:44,238 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", - "2025-02-14 14:33:44,238 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + "2025-02-16 18:38:29,485 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-02-16 18:38:29,485 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-16 18:38:29,486 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", + 
"2025-02-16 18:38:29,486 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-02-16 18:38:29,486 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Site', 'Metadata_Well']\n", + "2025-02-16 18:38:29,486 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-16 18:38:29,506 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-02-16 18:38:29,506 - DEBUG - Channel keys: {'OrigAGP', 'OrigDNA', 'OrigMito', 'OrigBrightfield', 'OrigRNA', 'OrigER'} inferred from loaddata csv\n", + "2025-02-16 18:38:29,507 - DEBUG - Setting input channel(s) ...\n", + "2025-02-16 18:38:29,507 - DEBUG - No channel keys specified, skip\n", + "2025-02-16 18:38:29,507 - DEBUG - Setting target channel(s) ...\n", + "2025-02-16 18:38:29,507 - DEBUG - No channel keys specified, skip\n", + "2025-02-16 18:38:29,507 - DEBUG - Setting input transform ...\n", + "2025-02-16 18:38:29,508 - DEBUG - Setting target transform ...\n", + "2025-02-16 18:38:29,508 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", + "2025-02-16 18:38:29,524 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", + "2025-02-16 18:38:29,524 - DEBUG - Generating patches that contain cells\n", + "2025-02-16 18:38:29,541 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-02-16 18:38:29,913 - DEBUG - Generated 461 patches for 93 site/view\n", + "2025-02-16 18:38:29,913 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-16 18:38:29,913 - DEBUG - Set target channel(s) as ['OrigDNA']\n" ] } ], @@ -421,6 +420,158 @@ "outputs": [ { "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "epoch", + "rawType": "int64", + "type": "integer" + }, + { + "name": "L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "ssim", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_ssim", + "rawType": "float64", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "147787f0-baab-44cf-a23b-5eddc27ca95f", + "rows": [ + [ + "0", + "1", + "0.48533629775047304", + "0.48593512177467346", + "6.1978302001953125", + "0.013830780982971191", + "6.226802349090576", + "0.027033504098653793" + ], + [ + "1", + "2", + "0.4069934129714966", + "0.4599619507789612", + "7.657495021820068", + "0.030011823400855064", + "6.700008869171143", + "0.028457120060920715" + ], + [ + "2", + "3", + "0.3600158095359802", + "0.4170601963996887", + "8.672477722167969", + "0.03410625830292702", + "7.542667388916016", + "0.03110448829829693" + ], + [ + "3", + "4", + "0.30498751997947693", + "0.36080312728881836", + "10.097426414489746", + "0.043731819838285446", + "8.787214279174805", + "0.03489777073264122" + ], + [ + "4", + "5", + "0.2654287576675415", + "0.308460533618927", + "11.309123992919922", + "0.05290067195892334", + "10.130162239074707", + "0.0397116057574749" + ], + [ + "5", + "6", + "0.21867721974849702", + "0.26461610198020935", + "12.826869010925293", + "0.06800838559865952", + "11.43663215637207", 
+ "0.04436549171805382" + ], + [ + "6", + "7", + "0.20276209115982055", + "0.22711718082427979", + "13.544842720031738", + "0.07081294804811478", + "12.7350492477417", + "0.05026976764202118" + ], + [ + "7", + "8", + "0.17667416632175445", + "0.1996302306652069", + "14.679831504821777", + "0.09041351824998856", + "13.814286231994629", + "0.054895032197237015" + ], + [ + "8", + "9", + "0.15714640021324158", + "0.1599843055009842", + "15.63988971710205", + "0.11157377064228058", + "15.694003105163574", + "0.0697111263871193" + ], + [ + "9", + "10", + "0.1425831973552704", + "0.1461305469274521", + "16.46839141845703", + "0.13909131288528442", + "16.479198455810547", + "0.07866603136062622" + ] + ], + "shape": { + "columns": 7, + "rows": 10 + } + }, "text/html": [ "
\n", "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FileName_OrigBrightfieldPathName_OrigBrightfieldFileName_OrigERPathName_OrigERFileName_OrigAGPPathName_OrigAGPFileName_OrigMitoPathName_OrigMitoFileName_OrigDNAPathName_OrigDNA...Metadata_AbsPositionZMetadata_ChannelIDMetadata_ColMetadata_FieldIDMetadata_PlaneIDMetadata_PositionXMetadata_PositionYMetadata_PositionZMetadata_RowMetadata_Reimaged
2079r06c22f01p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c22f01p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134358622110.0000000.000000-0.0000066False
668r05c09f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c09f03p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13440569310.0000000.000646-0.0000065False
2073r05c22f04p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r05c22f04p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134366622410.0006460.000646-0.0000065False
1113r06c13f07p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c13f07p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13434761371-0.000646-0.000646-0.0000066False
788r06c10f06p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r06c10f06p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13438161061-0.0006460.000000-0.0000066False
..................................................................
1730r03c19f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c19f03p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134366619310.0000000.000646-0.0000043True
196r12c04f08p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r12c04f08p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13449164810.000000-0.000646-0.00000412True
367r07c06f08p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r07c06f08p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13444766810.000000-0.000646-0.0000067False
650r03c09f03p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r03c09f03p01-ch5sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.13442869310.0000000.000646-0.0000063False
2064r04c22f04p01-ch1sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch2sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch4sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch3sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi...r04c22f04p01-ch6sk1fk1fl1.tiff/home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi......0.134379622410.0006460.000646-0.0000044True
\n", + "

100 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " FileName_OrigBrightfield \\\n", + "2079 r06c22f01p01-ch1sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch1sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch1sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch1sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch1sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch1sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch1sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch1sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch1sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch1sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigBrightfield \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigER \\\n", + "2079 r06c22f01p01-ch2sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch2sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch2sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch2sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch2sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch2sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch2sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch2sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch2sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch2sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigER \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigAGP \\\n", + "2079 r06c22f01p01-ch3sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch3sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch3sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch3sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch3sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch4sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch4sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch3sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch3sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch4sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigAGP \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... 
\n", + "\n", + " FileName_OrigMito \\\n", + "2079 r06c22f01p01-ch4sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch4sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch4sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch4sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch4sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch3sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch3sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch4sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch4sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch3sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigMito \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... \n", + "\n", + " FileName_OrigDNA \\\n", + "2079 r06c22f01p01-ch5sk1fk1fl1.tiff \n", + "668 r05c09f03p01-ch5sk1fk1fl1.tiff \n", + "2073 r05c22f04p01-ch5sk1fk1fl1.tiff \n", + "1113 r06c13f07p01-ch5sk1fk1fl1.tiff \n", + "788 r06c10f06p01-ch5sk1fk1fl1.tiff \n", + "... ... \n", + "1730 r03c19f03p01-ch6sk1fk1fl1.tiff \n", + "196 r12c04f08p01-ch6sk1fk1fl1.tiff \n", + "367 r07c06f08p01-ch5sk1fk1fl1.tiff \n", + "650 r03c09f03p01-ch5sk1fk1fl1.tiff \n", + "2064 r04c22f04p01-ch6sk1fk1fl1.tiff \n", + "\n", + " PathName_OrigDNA ... \\\n", + "2079 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "668 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2073 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "1113 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "788 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "... ... ... \n", + "1730 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "196 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "367 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "650 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "2064 /home/weishanli/Waylab/ALSF_pilot/data/ALSF_pi... ... \n", + "\n", + " Metadata_AbsPositionZ Metadata_ChannelID Metadata_Col Metadata_FieldID \\\n", + "2079 0.134358 6 22 1 \n", + "668 0.134405 6 9 3 \n", + "2073 0.134366 6 22 4 \n", + "1113 0.134347 6 13 7 \n", + "788 0.134381 6 10 6 \n", + "... ... ... ... ... \n", + "1730 0.134366 6 19 3 \n", + "196 0.134491 6 4 8 \n", + "367 0.134447 6 6 8 \n", + "650 0.134428 6 9 3 \n", + "2064 0.134379 6 22 4 \n", + "\n", + " Metadata_PlaneID Metadata_PositionX Metadata_PositionY \\\n", + "2079 1 0.000000 0.000000 \n", + "668 1 0.000000 0.000646 \n", + "2073 1 0.000646 0.000646 \n", + "1113 1 -0.000646 -0.000646 \n", + "788 1 -0.000646 0.000000 \n", + "... ... ... ... \n", + "1730 1 0.000000 0.000646 \n", + "196 1 0.000000 -0.000646 \n", + "367 1 0.000000 -0.000646 \n", + "650 1 0.000000 0.000646 \n", + "2064 1 0.000646 0.000646 \n", + "\n", + " Metadata_PositionZ Metadata_Row Metadata_Reimaged \n", + "2079 -0.000006 6 False \n", + "668 -0.000006 5 False \n", + "2073 -0.000006 5 False \n", + "1113 -0.000006 6 False \n", + "788 -0.000006 6 False \n", + "... ... ... ... 
\n", + "1730 -0.000004 3 True \n", + "196 -0.000004 12 True \n", + "367 -0.000006 7 False \n", + "650 -0.000006 3 False \n", + "2064 -0.000004 4 True \n", + "\n", + "[100 rows x 25 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaddata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLE_PATCH_DATA_EXPORT_PATH = pathlib.Path('.').absolute().parent.parent / 'example_patch_data'\n", + "EXAMPLE_PATCH_DATA_EXPORT_PATH.mkdir(exist_ok=True)\n", + "INPUT_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / input_channel_name\n", + "INPUT_EXPORT_PATH.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-20 00:15:47,252 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,252 - DEBUG - Set target channel(s) as ['OrigDNA']\n", + "2025-02-20 00:15:47,475 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,475 - DEBUG - Set target channel(s) as ['OrigER']\n", + "2025-02-20 00:15:47,676 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,677 - DEBUG - Set target channel(s) as ['OrigMito']\n", + "2025-02-20 00:15:47,850 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:47,850 - DEBUG - Set target channel(s) as ['OrigRNA']\n", + "2025-02-20 00:15:48,034 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 00:15:48,035 - DEBUG - Set target channel(s) as ['OrigAGP']\n" + ] + } + ], + "source": [ + "for j, channel_name in enumerate(target_channel_names):\n", + "\n", + " pds.set_input_channel_keys([input_channel_name])\n", + " pds.set_target_channel_keys([channel_name])\n", + "\n", + " CHANNEL_EXPORT_PATH = EXAMPLE_PATCH_DATA_EXPORT_PATH / channel_name\n", + " CHANNEL_EXPORT_PATH.mkdir(exist_ok=True)\n", + "\n", + " for i in range(len(pds)):\n", + " input, target = pds[i]\n", + " input_name = pds.input_names\n", + " target_name = pds.target_names\n", + " patch_coord = pds.patch_coords\n", + "\n", + " if j == 0:\n", + " imageio.imwrite(\n", + " INPUT_EXPORT_PATH / f'{input_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n", + " input[0].numpy().astype(np.uint16))\n", + "\n", + " imageio.imwrite(\n", + " CHANNEL_EXPORT_PATH / f'{target_name[0].stem}_{patch_coord[0]}_{patch_coord[1]}.tiff', \n", + " target[0].numpy().astype(np.uint16))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speckle_analysis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trainers/Trainer.py b/trainers/Trainer.py index ce5e707..5ebee41 100644 --- a/trainers/Trainer.py +++ b/trainers/Trainer.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Optional, List +from typing import Optional, List, Union import torch from torch.utils.data import DataLoader, random_split @@ -14,7 +14,7 @@ def __init__( self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, - backprop_loss: torch.nn.Module | List[torch.nn.Module], + backprop_loss: Union[torch.nn.Module, 
List[torch.nn.Module]], # rest of the arguments are passed to and handled by the parent class # - dataset # - batch_size From d57d80a477a41ba4cdde8314f3237d84077bc35f Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 20 Feb 2025 10:26:47 -0700 Subject: [PATCH 27/89] Re-ran example notebook --- examples/minimal_example.ipynb | 206 +++++++++++++++------------------ 1 file changed, 95 insertions(+), 111 deletions(-) diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb index b000f23..f968a70 100644 --- a/examples/minimal_example.ipynb +++ b/examples/minimal_example.ipynb @@ -20,14 +20,6 @@ "text": [ "/home/weishanli/Waylab\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/weishanli/anaconda3/envs/speckle_analysis/lib/python3.11/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.4' (you have '2.0.1'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", - " check_for_updates()\n" - ] } ], "source": [ @@ -325,27 +317,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-02-18 10:15:52,796 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", - "2025-02-18 10:15:52,797 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-18 10:15:52,797 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", - "2025-02-18 10:15:52,797 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", - "2025-02-18 10:15:52,797 - DEBUG - Merge fields inferred: ['Metadata_Plate', 'Metadata_Site', 'Metadata_Well']\n", - "2025-02-18 10:15:52,797 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-18 10:15:52,820 - DEBUG - Inferring channel keys from loaddata csv\n", - "2025-02-18 10:15:52,820 - DEBUG - Channel keys: {'OrigMito', 'OrigRNA', 'OrigDNA', 'OrigAGP', 'OrigER', 'OrigBrightfield'} inferred from loaddata csv\n", - "2025-02-18 10:15:52,820 - DEBUG - Setting input channel(s) ...\n", - "2025-02-18 10:15:52,821 - DEBUG - No channel keys specified, skip\n", - "2025-02-18 10:15:52,821 - DEBUG - Setting target channel(s) ...\n", - "2025-02-18 10:15:52,821 - DEBUG - No channel keys specified, skip\n", - "2025-02-18 10:15:52,821 - DEBUG - Setting input transform ...\n", - "2025-02-18 10:15:52,822 - DEBUG - Setting target transform ...\n", - "2025-02-18 10:15:52,822 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", - "2025-02-18 10:15:52,840 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", - "2025-02-18 10:15:52,840 - DEBUG - Generating patches that contain cells\n", - "2025-02-18 10:15:52,858 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", - "2025-02-18 10:15:53,246 - DEBUG - Generated 461 patches for 93 site/view\n", - "2025-02-18 10:15:53,246 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", - "2025-02-18 10:15:53,246 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-20 10:22:43,813 - 
DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", + "2025-02-20 10:22:43,813 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-02-20 10:22:43,813 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Well', 'Metadata_Plate']\n", + "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-02-20 10:22:43,850 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-02-20 10:22:43,851 - DEBUG - Channel keys: {'OrigER', 'OrigMito', 'OrigBrightfield', 'OrigRNA', 'OrigDNA', 'OrigAGP'} inferred from loaddata csv\n", + "2025-02-20 10:22:43,851 - DEBUG - Setting input channel(s) ...\n", + "2025-02-20 10:22:43,851 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 10:22:43,851 - DEBUG - Setting target channel(s) ...\n", + "2025-02-20 10:22:43,851 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 10:22:43,851 - DEBUG - Setting input transform ...\n", + "2025-02-20 10:22:43,851 - DEBUG - Setting target transform ...\n", + "2025-02-20 10:22:43,851 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", + "2025-02-20 10:22:43,875 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", + "2025-02-20 10:22:43,875 - DEBUG - Generating patches that contain cells\n", + "2025-02-20 10:22:43,899 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-02-20 10:22:44,318 - DEBUG - Generated 461 patches for 93 site/view\n", + "2025-02-20 10:22:44,319 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 10:22:44,319 - DEBUG - Set target channel(s) as ['OrigDNA']\n" ] } ], @@ -453,102 +445,102 @@ " \n", " 0\n", " 1\n", - " 0.456883\n", - " 0.474489\n", - " 6.632703\n", - " 0.018613\n", - " 6.418595\n", - " 0.026234\n", + " 0.478871\n", + " 0.458065\n", + " 6.329101\n", + " 0.028812\n", + " 6.742533\n", + " 0.032114\n", " \n", " \n", " 1\n", " 2\n", - " 0.400776\n", - " 0.453664\n", - " 7.765750\n", - " 0.025268\n", - " 6.804714\n", - " 0.027302\n", + " 0.428279\n", + " 0.438360\n", + " 7.297186\n", + " 0.043409\n", + " 7.122532\n", + " 0.033492\n", " \n", " \n", " 2\n", " 3\n", - " 0.358529\n", - " 0.414073\n", - " 8.702980\n", - " 0.024505\n", - " 7.590396\n", - " 0.029586\n", + " 0.382448\n", + " 0.412966\n", + " 8.268170\n", + " 0.043539\n", + " 7.640979\n", + " 0.035500\n", " \n", " \n", " 3\n", " 4\n", - " 0.304907\n", - " 0.343430\n", - " 10.004862\n", - " 0.034862\n", - " 9.206196\n", - " 0.034965\n", + " 0.343706\n", + " 0.396219\n", + " 9.204294\n", + " 0.046435\n", + " 8.010185\n", + " 0.038775\n", " \n", " \n", " 4\n", " 5\n", - " 0.272098\n", - " 0.276465\n", - " 10.939101\n", - " 0.042278\n", - " 11.154539\n", - " 0.041764\n", + " 0.294751\n", + " 0.368890\n", + " 10.388447\n", + " 0.056670\n", + " 8.629682\n", + " 0.039226\n", " \n", " \n", " 5\n", " 6\n", - " 0.232606\n", - " 0.226715\n", - " 12.247836\n", - " 0.047612\n", - " 12.796427\n", - " 0.051223\n", + " 0.259627\n", + " 0.291441\n", + " 11.531624\n", + " 0.062310\n", + " 10.704532\n", + " 0.037991\n", " \n", " \n", " 6\n", " 7\n", - " 0.192591\n", - " 0.203203\n", - " 13.729950\n", - " 0.087152\n", - " 13.635719\n", - " 0.051709\n", + " 0.227637\n", + " 0.246790\n", + " 
12.642867\n", + " 0.074122\n", + " 12.188992\n", + " 0.058526\n", " \n", " \n", " 7\n", " 8\n", - " 0.165829\n", - " 0.173625\n", - " 15.012578\n", - " 0.091547\n", - " 14.903670\n", - " 0.072398\n", + " 0.199925\n", + " 0.207696\n", + " 13.745816\n", + " 0.075425\n", + " 13.466536\n", + " 0.055031\n", " \n", " \n", " 8\n", " 9\n", - " 0.141518\n", - " 0.136609\n", - " 16.231110\n", - " 0.093675\n", - " 16.804295\n", - " 0.080310\n", + " 0.164005\n", + " 0.157645\n", + " 15.256923\n", + " 0.090092\n", + " 15.916401\n", + " 0.080869\n", " \n", " \n", " 9\n", " 10\n", - " 0.122126\n", - " 0.129176\n", - " 17.322950\n", - " 0.115414\n", - " 17.221750\n", - " 0.093323\n", + " 0.141493\n", + " 0.142309\n", + " 16.464945\n", + " 0.103618\n", + " 16.622252\n", + " 0.086847\n", " \n", " \n", "\n", @@ -556,16 +548,16 @@ ], "text/plain": [ " epoch L1Loss val_L1Loss psnr ssim val_psnr val_ssim\n", - "0 1 0.456883 0.474489 6.632703 0.018613 6.418595 0.026234\n", - "1 2 0.400776 0.453664 7.765750 0.025268 6.804714 0.027302\n", - "2 3 0.358529 0.414073 8.702980 0.024505 7.590396 0.029586\n", - "3 4 0.304907 0.343430 10.004862 0.034862 9.206196 0.034965\n", - "4 5 0.272098 0.276465 10.939101 0.042278 11.154539 0.041764\n", - "5 6 0.232606 0.226715 12.247836 0.047612 12.796427 0.051223\n", - "6 7 0.192591 0.203203 13.729950 0.087152 13.635719 0.051709\n", - "7 8 0.165829 0.173625 15.012578 0.091547 14.903670 0.072398\n", - "8 9 0.141518 0.136609 16.231110 0.093675 16.804295 0.080310\n", - "9 10 0.122126 0.129176 17.322950 0.115414 17.221750 0.093323" + "0 1 0.478871 0.458065 6.329101 0.028812 6.742533 0.032114\n", + "1 2 0.428279 0.438360 7.297186 0.043409 7.122532 0.033492\n", + "2 3 0.382448 0.412966 8.268170 0.043539 7.640979 0.035500\n", + "3 4 0.343706 0.396219 9.204294 0.046435 8.010185 0.038775\n", + "4 5 0.294751 0.368890 10.388447 0.056670 8.629682 0.039226\n", + "5 6 0.259627 0.291441 11.531624 0.062310 10.704532 0.037991\n", + "6 7 0.227637 0.246790 12.642867 0.074122 12.188992 0.058526\n", + "7 8 0.199925 0.207696 13.745816 0.075425 13.466536 0.055031\n", + "8 9 0.164005 0.157645 15.256923 0.090092 15.916401 0.080869\n", + "9 10 0.141493 0.142309 16.464945 0.103618 16.622252 0.086847" ] }, "execution_count": 10, @@ -593,7 +585,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Early termination at epoch 6 with best validation metric 6.414916038513184\n" + "Early termination at epoch 6 with best validation metric 6.6218791007995605\n" ] } ], @@ -764,15 +756,7 @@ "cell_type": "code", "execution_count": 14, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Early termination at epoch 9 with best validation metric 0.21918950974941254\n" - ] - } - ], + "outputs": [], "source": [ "generator = UNet(\n", " n_channels=1,\n", @@ -848,7 +832,7 @@ ], "metadata": { "kernelspec": { - "display_name": "speckle_analysis", + "display_name": "cp_gan_env", "language": "python", "name": "python3" }, @@ -862,7 +846,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.9.21" } }, "nbformat": 4, From 09b63d5624f9031e80eca954f1344bc930cfc4c9 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 20 Feb 2025 12:11:49 -0700 Subject: [PATCH 28/89] Added dataset class that does not rely on pe2loaddata generated file index and one example that make uses of it --- datasets/GenericImageDataset.py | 299 ++++++++ .../minimal_example_generic_dataset.ipynb | 669 ++++++++++++++++++ 2 files changed, 968 
insertions(+) create mode 100644 datasets/GenericImageDataset.py create mode 100644 examples/minimal_example_generic_dataset.ipynb diff --git a/datasets/GenericImageDataset.py b/datasets/GenericImageDataset.py new file mode 100644 index 0000000..71713f7 --- /dev/null +++ b/datasets/GenericImageDataset.py @@ -0,0 +1,299 @@ +import logging +import pathlib +import re +from collections import defaultdict +from typing import List, Optional, Union, Tuple, Dict + +import numpy as np +import torch +from PIL import Image +from albumentations import ImageOnlyTransform +from albumentations.core.composition import Compose +from torch.utils.data import Dataset + + +class GenericImageDataset(Dataset): + """ + A generic image dataset that automatically associates images under a supplied path + with sites and channels based on two separate regex patterns for site and channel detection. + """ + + def __init__( + self, + image_dir: Union[str, pathlib.Path], + site_pattern: str, + channel_pattern: str, + _input_channel_keys: Optional[Union[str, List[str]]] = None, + _target_channel_keys: Optional[Union[str, List[str]]] = None, + _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None, + _PIL_image_mode: str = 'I;16', + verbose: bool = False, + check_exists: bool = True, + **kwargs + ): + """ + Initialize the dataset. + + :param image_dir: Directory containing the images. + :param site_pattern: Regex pattern to extract site identifiers. + :param channel_pattern: Regex pattern to extract channel identifiers. + :param _input_channel_keys: List of channel names to use as inputs. + :param _target_channel_keys: List of channel names to use as targets. + :param _input_transform: Transformations to apply to input images. + :param _target_transform: Transformations to apply to target images. + :param _PIL_image_mode: Mode for loading images. + :param check_exists: Whether to check if all referenced image files exist. 
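+
+        Example (an illustrative sketch, not part of the class contract: the
+        directory layout and file name below are hypothetical assumptions,
+        while the two regex patterns mirror the ones used in
+        examples/minimal_example_generic_dataset.ipynb):
+
+            >>> # image_dir holds files such as 'r03c03_f01_fov1_OrigDNA.tiff',
+            >>> # one file per channel for each site
+            >>> ds = GenericImageDataset(
+            ...     image_dir='example_patch_data',
+            ...     site_pattern=r'^([^_]+_[^_]+_[^_]+)',
+            ...     channel_pattern=r'_([^_]+)\.tiff$',
+            ...     _input_channel_keys='OrigBrightfield',
+            ...     _target_channel_keys='OrigDNA',
+            ... )
+            >>> input_img, target_img = ds[0]  # float tensors of shape (C, H, W)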
+ """ + + self._initialize_logger(verbose) + self.image_dir = pathlib.Path(image_dir).resolve() + self.site_pattern = re.compile(site_pattern) + self.channel_pattern = re.compile(channel_pattern) + self._PIL_image_mode = _PIL_image_mode + + if not self.image_dir.exists(): + raise FileNotFoundError(f"Image directory {self.image_dir} not found") + + # Parse images and organize by site + self._channel_keys = [] + self.__image_paths = self._get_image_paths(check_exists) + + # Set input and target channel keys + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + + self.set_input_transform(_input_transform) + self.set_target_transform(_target_transform) + + # Index patches and images + self.__iter_image_id = list(range(len(self.__image_paths))) + + # Initialize cache + self.__input_cache = {} + self.__target_cache = {} + self.__cache_image_id = None + + # Initialize the current input and target names + self.__current_input_names = None + self.__current_target_names = None + + """ + Properties + """ + + @property + def image_paths(self): + return self.__image_paths + + @property + def input_transform(self): + return self._input_transform + + @property + def target_transform(self): + return self._target_transform + + @property + def input_channel_keys(self): + return self._input_channel_keys + + @property + def target_channel_keys(self): + return self._target_channel_keys + @property + def input_names(self): + return self.__current_input_names + + @property + def target_names(self): + return self.__current_target_names + + """ + Setters + """ + + def set_input_transform(self, _input_transform: Optional[Union[Compose, ImageOnlyTransform]] = None): + """Sets the input image transform.""" + self.logger.debug("Setting input transform ...") + self._input_transform = _input_transform + + def set_target_transform(self, _target_transform: Optional[Union[Compose, ImageOnlyTransform]] = None): + """Sets the target image transform.""" + self.logger.debug("Setting target transform ...") + self._target_transform = _target_transform + + def set_input_channel_keys(self, _input_channel_keys: Union[str, List[str]]): + """ + Set the input channel keys + + :param _input_channel_keys: The input channel keys + :type _input_channel_keys: str or list of str + """ + self._input_channel_keys = self.__check_channel_keys(_input_channel_keys) + self.logger.debug(f"Set input channel(s) as {self._input_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + def set_target_channel_keys(self, _target_channel_keys: Union[str, List[str]]): + """ + Set the target channel keys + + :param _target_channel_keys: The target channel keys + :type _target_channel_keys: str or list of str + """ + self._target_channel_keys = self.__check_channel_keys(_target_channel_keys) + self.logger.debug(f"Set target channel(s) as {self._target_channel_keys}") + + # clear the cache + self.__cache_image_id = None + + """ + Logging and Debugging + """ + + def _initialize_logger(self, verbose: bool): + """Initializes the logger.""" + self.logger = logging.getLogger(f"{__name__}.{id(self)}") + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if verbose else logging.WARNING) + + """ + Internal helper functions + """ + + def _get_image_paths(self, check_exists: bool): + + # sets 
for all unique sites and channels + sites = set() + channels = set() + image_files = list(self.image_dir.glob("*")) + + site_to_channels = defaultdict(dict) + for file in image_files: + site_match = self.site_pattern.search(file.name) + try: + site = site_match.group(1) + except: + continue + sites.add(site) + + channel_match = self.channel_pattern.search(file.name) + try: + channel = channel_match.group(1) + except: + continue + channels.add(channel) + + site_to_channels[site][channel] = file + + # format as list of dicts + image_paths = [] + for site, channel_to_file in site_to_channels.items(): + ## Keep only sites with all channels + if all([c in site_to_channels[site] for c in channels]): + if check_exists and not all(path.exists() for path in channel_to_file.values()): + continue + image_paths.append(channel_to_file) + + self.logger.debug(f"Channel keys: {channels} detected") + self._channel_keys = list(channels) + + return image_paths + + def __len__(self): + return len(self.__image_paths) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Retrieves the input and target images for a given index. + + :param idx: The index of the image. + :return: Tuple of input and target images as tensors. + """ + if idx >= len(self) or idx < 0: + raise IndexError("Index out of bounds") + + site_id = self.__iter_image_id[idx] + self._cache_image(site_id) + + # Stack input and target images + input_images = np.stack([self.__input_cache[key] for key in self._input_channel_keys], axis=0) + target_images = np.stack([self.__target_cache[key] for key in self._target_channel_keys], axis=0) + + # Apply transformations + if self._input_transform: + input_images = self._input_transform(image=input_images)['image'] + if self._target_transform: + target_images = self._target_transform(image=target_images)['image'] + + return torch.from_numpy(input_images).float(), torch.from_numpy(target_images).float() + + def _cache_image(self, site_id: str) -> None: + """ + Loads and caches images for a given site ID. + + :param site_id: The site ID. + """ + if self.__cache_image_id != site_id: + self.__cache_image_id = site_id + self.__input_cache = {} + self.__target_cache = {} + + ## Update target and input names (which are just file path(s)) + self.__current_input_names = [self.__image_paths[site_id][key] for key in self._input_channel_keys] + self.__current_target_names = [self.__image_paths[site_id][key] for key in self._target_channel_keys] + + for key in self._input_channel_keys: + self.__input_cache[key] = self._read_convert_image(self.__image_paths[site_id][key]) + for key in self._target_channel_keys: + self.__target_cache[key] = self._read_convert_image(self.__image_paths[site_id][key]) + + def _read_convert_image(self, image_path: pathlib.Path) -> np.ndarray: + """ + Reads and converts an image to a numpy array. + + :param image_path: The image file path. + :return: The image as a numpy array. 
+        """
+        return np.array(Image.open(image_path).convert(self._PIL_image_mode))
+
+    def __check_channel_keys(
+            self,
+            channel_keys: Optional[Union[str, List[str]]]
+    ) -> List[str]:
+        """
+        Checks the user-supplied channel key(s) against the ones inferred from the image files
+
+        :param channel_keys: user supplied list or single object of string channel keys
+        :type channel_keys: string or list of strings
+        """
+        if channel_keys is None:
+            self.logger.debug("No channel keys specified, skip")
+            return None
+        elif isinstance(channel_keys, str):
+            channel_keys = [channel_keys]
+        elif isinstance(channel_keys, list):
+            if not all([isinstance(key, str) for key in channel_keys]):
+                raise ValueError('Channel keys must be a string or a list of strings.')
+        else:
+            raise ValueError('Channel keys must be a string or a list of strings.')
+
+        ## Check supplied channel keys against inferred ones
+        filtered_channel_keys = []
+        for key in channel_keys:
+            if key not in self._channel_keys:
+                self.logger.debug(
+                    f"Ignoring channel key {key} as it does not match any channel inferred from the image files"
+                )
+            else:
+                filtered_channel_keys.append(key)
+
+        if len(filtered_channel_keys) == 0:
+            raise ValueError('None of the supplied channel keys match the channels inferred from the image files')
+
+        return filtered_channel_keys
\ No newline at end of file
diff --git a/examples/minimal_example_generic_dataset.ipynb b/examples/minimal_example_generic_dataset.ipynb
new file mode 100644
index 0000000..16a05f8
--- /dev/null
+++ b/examples/minimal_example_generic_dataset.ipynb
@@ -0,0 +1,669 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# A minimal example to demonstrate how the trainer for FNet and wGaN GP plus the callbacks works along with patched dataset\n",
+    "\n",
+    "Is dependent on the files produced by 1.illumination_correction/0.create_loaddata_csvs ALSF pilot data repo https://github.com/WayScience/pediatric_cancer_atlas_profiling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/weishanli/Waylab\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "import pathlib\n",
+    "\n",
+    "import pandas as pd\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "\n",
+    "sys.path.append(str(pathlib.Path('.').absolute().parent.parent))\n",
+    "print(str(pathlib.Path('.').absolute().parent.parent))\n",
+    "\n",
+    "## Dataset\n",
+    "from virtual_stain_flow.datasets.GenericImageDataset import GenericImageDataset\n",
+    "from virtual_stain_flow.datasets.CachedDataset import CachedDataset\n",
+    "\n",
+    "## FNet training\n",
+    "from virtual_stain_flow.models.fnet import FNet\n",
+    "from virtual_stain_flow.trainers.Trainer import Trainer\n",
+    "\n",
+    "## wGaN training\n",
+    "from virtual_stain_flow.models.unet import UNet\n",
+    "from virtual_stain_flow.models.discriminator import GlobalDiscriminator\n",
+    "from virtual_stain_flow.trainers.WGaNTrainer import WGaNTrainer\n",
+    "\n",
+    "## wGaN losses\n",
+    "from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss\n",
+    "from virtual_stain_flow.losses.DiscriminatorLoss import DiscriminatorLoss\n",
+    "from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss\n",
+    "\n",
+    "from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize\n",
+    "\n",
+    "## Metrics\n",
+    "from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper\n",
+    "from virtual_stain_flow.metrics.PSNR import PSNR\n",
+    "from virtual_stain_flow.metrics.SSIM 
import SSIM\n", + "\n", + "## callback\n", + "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n", + "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify train data and output paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "EXAMPLE_PATCH_DATA_EXPORT_PATH = '/REPLACE/WITH/PATH/TO/DATA'\n", + "\n", + "EXAMPLE_DIR = pathlib.Path('.').absolute() / 'example_train_generic_dataset'\n", + "EXAMPLE_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf example_train_generic_dataset/*" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "PLOT_DIR = EXAMPLE_DIR / 'plot'\n", + "PLOT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "MLFLOW_DIR =EXAMPLE_DIR / 'mlflow'\n", + "MLFLOW_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure channels" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "channel_names = [\n", + " \"OrigBrightfield\",\n", + " \"OrigDNA\",\n", + " \"OrigER\",\n", + " \"OrigMito\",\n", + " \"OrigRNA\",\n", + " \"OrigAGP\",\n", + "]\n", + "input_channel_name = \"OrigBrightfield\"\n", + "target_channel_names = [ch for ch in channel_names if ch != input_channel_name]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep Patch dataset and Cache" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-20 12:06:27,871 - DEBUG - Channel keys: {'OrigRNA', 'OrigAGP', 'OrigER', 'OrigBrightfield', 'OrigDNA', 'OrigMito'} detected\n", + "2025-02-20 12:06:27,871 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 12:06:27,871 - DEBUG - No channel keys specified, skip\n", + "2025-02-20 12:06:27,872 - DEBUG - Setting input transform ...\n", + "2025-02-20 12:06:27,872 - DEBUG - Setting target transform ...\n", + "2025-02-20 12:06:27,872 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-02-20 12:06:27,872 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + ] + } + ], + "source": [ + "pds = GenericImageDataset(\n", + " image_dir=EXAMPLE_PATCH_DATA_EXPORT_PATH,\n", + " site_pattern=r\"^([^_]+_[^_]+_[^_]+)\",\n", + " channel_pattern=r\"_([^_]+)\\.tiff$\",\n", + " verbose=True\n", + ")\n", + "\n", + "## Set input and target channels\n", + "pds.set_input_channel_keys([input_channel_name])\n", + "pds.set_target_channel_keys('OrigDNA')\n", + "\n", + "## Cache for faster training \n", + "cds = CachedDataset(\n", + " pds,\n", + " prefill_cache=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FNet trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model without callback and check logs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " 
dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochL1Lossval_L1Losspsnrssimval_psnrval_ssim
011492.6101751521.866577-69.927834-5.388215e-09-70.2225659.050420e-11
121480.9667151521.779175-69.699051-8.186153e-09-70.222458-5.678337e-10
231761.0650091521.689453-70.673164-5.811710e-09-70.222351-1.330130e-09
341517.1537811521.613159-70.176323-7.637948e-09-70.222244-1.183239e-09
451555.3652071521.635498-70.198990-7.835469e-09-70.222260-7.694487e-09
561719.5259741521.537537-70.295540-5.934725e-09-70.222168-8.564919e-09
671489.4727381521.481323-69.753349-8.170694e-09-70.222076-1.564081e-09
781703.7644181521.477356-70.690598-4.165215e-09-70.222092-1.391449e-08
891656.3194851521.448120-70.481544-4.851262e-09-70.222031-1.076659e-09
9101514.4575061521.436646-70.083336-2.934100e-09-70.222023-1.150742e-09
\n", + "
" + ], + "text/plain": [ + " epoch L1Loss val_L1Loss psnr ssim val_psnr \\\n", + "0 1 1492.610175 1521.866577 -69.927834 -5.388215e-09 -70.222565 \n", + "1 2 1480.966715 1521.779175 -69.699051 -8.186153e-09 -70.222458 \n", + "2 3 1761.065009 1521.689453 -70.673164 -5.811710e-09 -70.222351 \n", + "3 4 1517.153781 1521.613159 -70.176323 -7.637948e-09 -70.222244 \n", + "4 5 1555.365207 1521.635498 -70.198990 -7.835469e-09 -70.222260 \n", + "5 6 1719.525974 1521.537537 -70.295540 -5.934725e-09 -70.222168 \n", + "6 7 1489.472738 1521.481323 -69.753349 -8.170694e-09 -70.222076 \n", + "7 8 1703.764418 1521.477356 -70.690598 -4.165215e-09 -70.222092 \n", + "8 9 1656.319485 1521.448120 -70.481544 -4.851262e-09 -70.222031 \n", + "9 10 1514.457506 1521.436646 -70.083336 -2.934100e-09 -70.222023 \n", + "\n", + " val_ssim \n", + "0 9.050420e-11 \n", + "1 -5.678337e-10 \n", + "2 -1.330130e-09 \n", + "3 -1.183239e-09 \n", + "4 -7.694487e-09 \n", + "5 -8.564919e-09 \n", + "6 -1.564081e-09 \n", + "7 -1.391449e-08 \n", + "8 -1.076659e-09 \n", + "9 -1.150742e-09 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(trainer.log)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model with alternative early termination metric" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Early termination at epoch 6 with best validation metric -70.41797637939453\n" + ] + } + ], + "source": [ + "model = FNet(depth=4)\n", + "lr = 3e-4\n", + "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=None,\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda',\n", + " early_termination_metric = 'psnr' # set early termination metric as psnr for the sake of demonstration\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train with mlflow logger callbacks" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'lr': 3e-4\n", + " },\n", + " )\n", + "\n", + "del trainer\n", + "\n", + "trainer = Trainer(\n", + " model = model,\n", + " optimizer = optimizer,\n", + " backprop_loss = nn.L1Loss(),\n", + " dataset = cds,\n", + " batch_size = 16,\n", + " epochs = 10,\n", + " patience = 5,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", + " device = 'cuda'\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# wGaN GP example with mlflow logger callback and plot callback" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = 
GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = DiscriminatorLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "wgan_trainer = WGaNTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=20, # setting this to prevent unwanted early termination here\n", + " device='cuda',\n", + " generator=generator,\n", + " discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr')},\n", + ")\n", + "\n", + "wgan_trainer.train()\n", + "\n", + "del generator\n", + "del wgan_trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## # wGaN GP example with mlflow logger callback and alternative early termination loss" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "generator = UNet(\n", + " n_channels=1,\n", + " n_classes=1\n", + ")\n", + "\n", + "discriminator = GlobalDiscriminator(\n", + " n_in_channels = 2,\n", + " n_in_filters = 64,\n", + " _conv_depth = 4,\n", + " _pool_before_fc = True\n", + ")\n", + "\n", + "generator_optimizer = optim.Adam(generator.parameters(), \n", + " lr=0.0002, \n", + " betas=(0., 0.9))\n", + "discriminator_optimizer = optim.Adam(discriminator.parameters(), \n", + " lr=0.00002, \n", + " betas=(0., 0.9),\n", + " weight_decay=0.001)\n", + "\n", + "gp_loss = GradientPenaltyLoss(\n", + " _metric_name='gp_loss',\n", + " discriminator=discriminator,\n", + " weight=10.0,\n", + ")\n", + "\n", + "gen_loss = GeneratorLoss(\n", + " _metric_name='gen_loss'\n", + ")\n", + "\n", + "disc_loss = DiscriminatorLoss(\n", + " _metric_name='disc_loss'\n", + ")\n", + "\n", + "mlflow_logger_callback = MlflowLogger(\n", + " name='mlflow_logger',\n", + " mlflow_uri=MLFLOW_DIR / 'mlruns',\n", + " mlflow_experiment_name='Default',\n", + " mlflow_start_run_args={'run_name': 'example_train_wgan_mae_early_term', 'nested': True},\n", + " mlflow_log_params_args={\n", + " 'gen_lr': 0.0002,\n", + " 'disc_lr': 0.00002\n", + " },\n", + " )\n", + "\n", + "wgan_trainer = WGaNTrainer(\n", + " dataset=cds,\n", + " batch_size=16,\n", + " epochs=20,\n", + " patience=5, # lower patience here\n", + " device='cuda',\n", + " generator=generator,\n", + " 
discriminator=discriminator,\n", + " gen_optimizer=generator_optimizer,\n", + " disc_optimizer=discriminator_optimizer,\n", + " generator_loss_fn=gen_loss,\n", + " discriminator_loss_fn=disc_loss,\n", + " gradient_penalty_fn=gp_loss,\n", + " discriminator_update_freq=1,\n", + " generator_update_freq=2,\n", + " callbacks=[mlflow_logger_callback],\n", + " metrics={'ssim': SSIM(_metric_name='ssim'), \n", + " 'psnr': PSNR(_metric_name='psnr'),\n", + " 'mae': MetricsWrapper(_metric_name='mae', module=nn.L1Loss()) # use a wrapper for torch nn L1Loss\n", + " },\n", + " early_termination_metric = 'mae' # update early temrination loss with the supplied L1Loss/mae metric instead of the default GaN generator loss\n", + ")\n", + "\n", + "wgan_trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cp_gan_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 1d63fadc6e52a63228faa8ea1d06e0baed6e7105 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 20 Feb 2025 12:12:15 -0700 Subject: [PATCH 29/89] Modified gitignore to ignore produced files under additional example training --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b293968..2201fd5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # images and anything under mlflow *.png -examples/example_train/* +examples/example_train*/* # pycache *.pyc \ No newline at end of file From 6568d01d67633fa109105c815a746ecf2f0abd07 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 20 Feb 2025 12:16:50 -0700 Subject: [PATCH 30/89] Modified notebook description --- examples/minimal_example_generic_dataset.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/minimal_example_generic_dataset.ipynb b/examples/minimal_example_generic_dataset.ipynb index 16a05f8..ba6da8e 100644 --- a/examples/minimal_example_generic_dataset.ipynb +++ b/examples/minimal_example_generic_dataset.ipynb @@ -6,7 +6,7 @@ "source": [ "# A minimal example to demonstrate how the trainer for FNet and wGaN GP plus the callbacks works along with patched dataset\n", "\n", - "Is dependent on the files produced by 1.illumination_correction/0.create_loaddata_csvs ALSF pilot data repo https://github.com/WayScience/pediatric_cancer_atlas_profiling" + "Is will not be dependent on the pe2loaddata generated index file from the ALSF pilot data repo unlike the other example notebook" ] }, { From 755d162576b0eee54af81bfe4da6cd8dce4af8dd Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Tue, 25 Feb 2025 21:42:27 -0700 Subject: [PATCH 31/89] Added helper function that computes metrics on a per image basis given predicted and target tensors of images. 
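
A minimal usage sketch (illustrative only: the random tensors stand in for
real model outputs, the import path assumes the same virtual_stain_flow
package layout used by the example notebooks, and PSNR/SSIM are the metric
classes already used there):

    import torch
    from virtual_stain_flow.metrics.PSNR import PSNR
    from virtual_stain_flow.metrics.SSIM import SSIM
    from virtual_stain_flow.evaluation.evaluation_utils import evaluate_per_image_metric

    predictions = torch.rand(4, 1, 64, 64)  # (N, C, H, W) stack of predicted images
    targets = torch.rand(4, 1, 64, 64)      # matching stack of ground truth images
    df = evaluate_per_image_metric(
        predictions,
        targets,
        metrics=[PSNR(_metric_name='psnr'), SSIM(_metric_name='ssim')],
    )
    # df has one row per image and one column per metric class name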
--- evaluation/evaluation_utils.py | 35 ++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py index dba1e71..accb77e 100644 --- a/evaluation/evaluation_utils.py +++ b/evaluation/evaluation_utils.py @@ -1,8 +1,9 @@ from collections import defaultdict -from typing import List, Callable, Union +from typing import List, Dict, Callable, Union import pandas as pd import torch +from torch.nn import Module from torch.utils.data import DataLoader def evaluate_metrics( @@ -25,4 +26,34 @@ def evaluate_metrics( for _metric in _metrics: metrics[_metric.__class__.__name__].append(_metric(output, target).item()) - return pd.DataFrame(metrics) \ No newline at end of file + return pd.DataFrame(metrics) + +def evaluate_per_image_metric( + predictions: torch.Tensor, + targets: torch.Tensor, + metrics: List[Module] +) -> pd.DataFrame: + """ + Computes a set of metrics on a per-image basis and returns the results as a pandas DataFrame. + + :param predictions: Predicted images, shape (N, C, H, W). + :type predictions: torch.Tensor + :param targets: Target images, shape (N, C, H, W). + :type targets: torch.Tensor + :param metrics: List of metric functions to evaluate. + :type metrics: List[torch.nn.Module] + + :return: A DataFrame where each row corresponds to an image and each column corresponds to a metric. + :rtype: pd.DataFrame + """ + if predictions.shape != targets.shape: + raise ValueError(f"Shape mismatch: predictions {predictions.shape} vs targets {targets.shape}") + + results = [] + + for i in range(predictions.shape[0]): # Iterate over images + pred, target = predictions[i].unsqueeze(0), targets[i].unsqueeze(0) # Keep batch dimension + metric_scores = {metric.__class__.__name__: metric.forward(target, pred).item() for metric in metrics} + results.append(metric_scores) + + return pd.DataFrame(results) \ No newline at end of file From b537c8355474c9009794e3a56d6b4680297abeb1 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Tue, 25 Feb 2025 21:44:13 -0700 Subject: [PATCH 32/89] Added a new file which is a collection of helper functions for predicting images given model and dataset as well as formatting tensor predictions to numpy --- evaluation/predict_utils.py | 102 ++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 evaluation/predict_utils.py diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py new file mode 100644 index 0000000..8e24fa5 --- /dev/null +++ b/evaluation/predict_utils.py @@ -0,0 +1,102 @@ +from typing import Optional, List, Union, Callable + +import torch +import numpy as np +from torch.utils.data import DataLoader, Dataset, Subset +from albumentations import ImageOnlyTransform, Compose + +def predict_image( + dataset: Dataset, + model: torch.nn.Module, + batch_size: int = 1, + device: str = "cpu", + num_workers: int = 0, + indices: Optional[List[int]] = None +) -> torch.Tensor: + """ + Runs a model on a dataset, performing a forward pass on all (or a subset of) input images + in evaluation mode and returning a stacked tensor of predictions. + DOES NOT check if the dataset dimensions are compatible with the model. + + :param dataset: A dataset that returns (input_tensor, target_tensor) tuples, + where input_tensor has shape (C, H, W). + :type dataset: torch.utils.data.Dataset + :param model: A PyTorch model that is compatible with the dataset inputs. 
+ :type model: torch.nn.Module + :param batch_size: The number of samples per batch (default is 1). + :type batch_size: int, optional + :param device: The device to run inference on, e.g., "cpu" or "cuda". + :type device: str, optional + :param num_workers: Number of workers for the DataLoader (default is 0). + :type num_workers: int, optional + :param indices: Optional list of dataset indices to subset the dataset before inference. + :type indices: Optional[List[int]], optional + + :return: A stacked tensor of model predictions with shape (N, C, H, W), where N is the dataset size or subset size. + :rtype: torch.Tensor + """ + # Subset the dataset if indices are provided + if indices is not None: + dataset = Subset(dataset, indices) + + # Create DataLoader for efficient batch processing + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + model.to(device) + model.eval() + + predictions = [] # List to store predictions + + with torch.no_grad(): + for inputs, _ in dataloader: # Unpacking (input_tensor, target_tensor) + inputs = inputs.to(device) # Move input data to the specified device + + # Forward pass + outputs = model(inputs) + + # Store predictions + predictions.append(outputs.cpu()) # Move to CPU for stacking + + # Stack all predictions into a single tensor + return torch.cat(predictions, dim=0) if predictions else torch.empty(0) + +def process_tensor_image( + img_tensor: torch.Tensor, + dtype: Optional[np.dtype] = None, + dataset: Optional[Dataset] = None, + invert_function: Optional[Callable] = None +) -> np.ndarray: + """ + Processes model output/other image tensor by casting to numpy, applying an optional dtype casting, + and inverting target transformations if a dataset with `target_transform` is provided. + + :param img_tensor: Tensor stack of model-predicted images with shape (N, C, H, W). + :type img_tensor: torch.Tensor + :param dtype: Optional numpy dtype to cast the output array (default: None). + :type dtype: Optional[np.dtype], optional + :param dataset: Optional dataset object with `target_transform` to invert transformations. + :type dataset: Optional[torch.utils.data.Dataset], optional + :param invert_function: Optional function to invert transformations applied to the images. + If provided, overrides the invert function call from dataset transform. + :type invert_function: Optional[Callable], optional + + :return: Processed numpy array of images with shape (N, C, H, W). 
+ :rtype: np.ndarray + """ + # Convert img_tensor to CPU and NumPy + output_images = img_tensor.cpu().numpy() + + # Optionally cast to specified dtype + if dtype is not None: + output_images = output_images.astype(dtype) + + # Apply inverse invert function when supplied or transformation if dataset supplied and target_transform is valid + if invert_function is not None and isinstance(invert_function, Callable): + output_images = np.array([invert_function(img) for img in output_images]) + elif dataset is not None and hasattr(dataset, "target_transform"): + target_transform = dataset.target_transform + if isinstance(target_transform, (ImageOnlyTransform, Compose)): + # Apply the transformation on each image + output_images = np.array([target_transform.invert(img) for img in output_images]) + + return output_images \ No newline at end of file From 1f7db29ee8d86d4cd022813130d240d59975677e Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Tue, 25 Feb 2025 21:45:19 -0700 Subject: [PATCH 33/89] Added new file which is a collection of helper functions to visualize patch dataset --- evaluation/plot_utils.py | 291 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 291 insertions(+) create mode 100644 evaluation/plot_utils.py diff --git a/evaluation/plot_utils.py b/evaluation/plot_utils.py new file mode 100644 index 0000000..a1bc79f --- /dev/null +++ b/evaluation/plot_utils.py @@ -0,0 +1,291 @@ +import random + +import torch +import numpy as np +import matplotlib.pyplot as plt +from typing import Union, Optional, List +from PIL import Image +from torch.utils.data import Dataset +from matplotlib.gridspec import GridSpec +from matplotlib.patches import Rectangle + +from ..datasets.PatchDataset import PatchDataset +from ..evaluation.predict_utils import predict_image, process_tensor_image +from ..evaluation.evaluation_utils import evaluate_per_image_metric + +def plot_single_image( + image: Union[np.ndarray, torch.Tensor], + ax: Optional[plt.Axes] = None, + cmap: str = "gray", + vmin: Optional[float] = None, + vmax: Optional[float] = None, + title: Optional[str] = None, + title_fontsize: int = 10 +): + """ + Plots a single image on the given matplotlib axis or creates a new figure if no axis is provided. + + :param image: The image to plot, either as a NumPy array or a PyTorch tensor. + :type image: Union[np.ndarray, torch.Tensor] + :param ax: Optional existing axis to plot on. If None, a new figure is created. + :type ax: Optional[plt.Axes], default is None + :param cmap: Colormap for visualization (default: "gray"). + :type cmap: str, optional + :param vmin: Minimum value for image scaling. If None, defaults to image min. + :type vmin: Optional[float], optional + :param vmax: Maximum value for image scaling. If None, defaults to image max. + :type vmax: Optional[float], optional + :param title: Optional title for the image. + :type title: Optional[str], optional + :param title_fontsize: Font size of the title. 
+    :type title_fontsize: int, optional
+    """
+
+    # Convert tensor to NumPy
+    if isinstance(image, torch.Tensor):
+        image = image.cpu().numpy()
+
+    # Handle grayscale vs multi-channel
+    if image.ndim == 3 and image.shape[0] in [1, 3]:  # (C, H, W) format
+        image = np.transpose(image, (1, 2, 0))  # Convert to (H, W, C)
+
+    # Create a new figure if no axis is provided
+    if ax is None:
+        fig, ax = plt.subplots(figsize=(5, 5))
+
+    # Plot the image; compare against None so that explicit vmin=0 or vmax=0 are respected
+    im = ax.imshow(image, cmap=cmap,
+                   vmin=vmin if vmin is not None else image.min(),
+                   vmax=vmax if vmax is not None else image.max())
+
+    # Hide axis
+    ax.axis("off")
+
+    # Add title if provided
+    if title:
+        ax.set_title(title, fontsize=title_fontsize, fontweight="bold")
+
+    return im  # Return image object for further customization if needed
+
+def visualize_images_with_stats(
+    images: Union[torch.Tensor, np.ndarray],
+    cmap: str = "gray",
+    figsize: Optional[tuple] = None,
+    panel_width: int = 3,
+    show_stats: bool = True,
+    channel_names: Optional[list] = None,
+    title_fontsize: int = 10,
+    axes: Optional[np.ndarray] = None
+):
+    """
+    Visualizes images using matplotlib, handling various shapes:
+    - (H, W) → Single grayscale image.
+    - (C, H, W) → Multi-channel image.
+    - (N, C, H, W) → Multiple images, multiple channels.
+
+    Supports external axes input for easier integration.
+
+    :param images: Input images as PyTorch tensor or NumPy array.
+    :param cmap: Colormap for visualization.
+    :param figsize: Optional figure size.
+    :param panel_width: Width of each panel.
+    :param show_stats: Whether to display statistics (μ, σ, ⊥, ⊤) in titles.
+    :param channel_names: List of channel names for first row.
+    :param title_fontsize: Font size for titles.
+    :param axes: Optional pre-existing matplotlib Axes.
+    """
+    if isinstance(images, torch.Tensor):
+        images = images.cpu().numpy()
+
+    ndim = images.ndim
+    if ndim == 2:
+        images = images[np.newaxis, np.newaxis, ...]  # Convert to (1, 1, H, W)
+    elif ndim == 3:
+        images = images[np.newaxis, ...]  # Convert to (1, C, H, W)
+    elif ndim != 4:
+        raise ValueError(f"Unsupported shape {images.shape}. Expected (H, W), (C, H, W), or (N, C, H, W).")
+
+    n_images, n_channels, _, _ = images.shape
+
+    # Remember whether the axes are created here, since `axes` is reassigned below
+    show_figure = axes is None
+
+    # Create figure and axes if not provided
+    if show_figure:
+        figsize = figsize or (n_channels * panel_width, n_images * panel_width)
+        fig, axes = plt.subplots(n_images, n_channels, figsize=figsize, squeeze=False)
+
+    for i in range(n_images):
+        for j in range(n_channels):
+            img = images[i, j]
+            title = None
+
+            # Compute statistics if needed
+            if show_stats:
+                img_mean, img_std, img_min, img_max = np.mean(img), np.std(img), np.min(img), np.max(img)
+                title = f"μ: {img_mean:.2f} | σ: {img_std:.2f} | ⊥: {img_min:.2f} | ⊤: {img_max:.2f}"
+
+            # Use the helper function for plotting
+            plot_single_image(img, ax=axes[i, j], cmap=cmap, title=title, title_fontsize=title_fontsize)
+
+    plt.tight_layout()
+    # Only show the figure if it was created in this function
+    if show_figure:
+        plt.show()
+
+def plot_single_image(
+    image: np.ndarray,
+    ax: plt.Axes,
+    cmap: str = "gray",
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
+    title: Optional[str] = None,
+    title_fontsize: int = 10
+):
+    """
+    Plots a single image on the given matplotlib axis.
+
+    :param image: The image to plot (NumPy array).
+    :param ax: The matplotlib axis to plot on.
+    :param cmap: Colormap for visualization.
+    :param vmin: Minimum value for scaling.
+    :param vmax: Maximum value for scaling.
+    :param title: Optional title for the image.
+    :param title_fontsize: Font size of the title. 
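+
+    Example (an illustrative sketch; the random image stands in for a real patch):
+
+        >>> import numpy as np
+        >>> import matplotlib.pyplot as plt
+        >>> fig, ax = plt.subplots()
+        >>> plot_single_image(np.random.rand(64, 64), ax, title='patch')
+        >>> plt.show()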
+ """ + ax.imshow(np.squeeze(image), cmap=cmap, vmin=vmin, vmax=vmax) + ax.axis("off") + if title: + ax.set_title(title, fontsize=title_fontsize, fontweight="bold") + +def plot_patches( + dataset: torch.utils.data.Dataset, + n_patches: int = 5, + model: Optional[torch.nn.Module] = None, + patch_index: Optional[List[int]] = None, + random_seed: int = 42, + metrics: Optional[List[torch.nn.Module]] = None, + device: str = "cpu", + **kwargs +): + """ + Plots dataset patches with optional model predictions and evaluation metrics using GridSpec. + Uses `plot_single_image` to ensure consistency. + + :param dataset: A dataset that returns (input_tensor, target_tensor) tuples. + :param n_patches: Number of patches to visualize (default: 5). + :param model: Optional PyTorch model to run inference on patches. + :param patch_index: List of dataset indices to select specific patches. + :param random_seed: Random seed for reproducibility. + :param metrics: List of metric functions to evaluate model predictions. + :param device: Device to run model inference on, e.g., "cpu" or "cuda". + :param **kwargs: Additional customization options (e.g., `cmap`, `panel_width`, `show_plot`). + """ + + cmap = kwargs.get("cmap", "gray") + panel_width = kwargs.get("panel_width", 5) + show_plot = kwargs.get("show_plot", True) + save_path = kwargs.get("save_path", None) + title_fontsize = kwargs.get("title_fontsize", 12) + + # Select patches + if patch_index is None: + random.seed(random_seed) + patch_index = random.sample(range(len(dataset)), n_patches) + else: + patch_index = [i for i in patch_index if i < len(dataset)] + n_patches = len(patch_index) + + inputs, targets, raw_images, patch_coords = [], [], [], [] + for i in patch_index: + input_tensor, target_tensor = dataset[i] + inputs.append(input_tensor) + targets.append(target_tensor) + patch_coords.append(dataset.patch_coords) # Extract (x, y) coordinates + raw_images.append(np.array(Image.open(dataset.input_names[0]))) + + inputs = torch.stack(inputs) + targets = torch.stack(targets) + + # Run model predictions (if provided) + predictions = predict_image(dataset, model, device=device, indices=patch_index) if model else None + + # Convert tensors to NumPy arrays + inputs_numpy = process_tensor_image(inputs, invert_function=dataset.input_transform.invert) + targets_numpy = process_tensor_image(targets, dataset=dataset) + predictions_numpy = process_tensor_image(predictions, dataset=dataset) if predictions is not None else None + + # Compute evaluation metrics (if applicable) + if metrics and predictions is not None: + metric_values = evaluate_per_image_metric( + predictions=predictions, + targets=targets, + metrics=metrics + ) + else: + metric_values = None + + # Determine number of columns (Raw + Input + Target + Optional Predictions) + n_predictions = predictions_numpy.shape[1] if predictions_numpy is not None else 0 + n_columns = 3 + n_predictions # (Raw, Input, Target, Predictions) + + # Compute raw image global vmin/vmax + raw_vmin, raw_vmax = np.min(raw_images), np.max(raw_images) + + # Set up figure and GridSpec layout with an extra row for column titles + figsize = (panel_width * n_columns, panel_width * (n_patches + 1)) # Extra space for headers + fig = plt.figure(figsize=figsize) + gs = GridSpec(n_patches + 1, n_columns, figure=fig, height_ratios=[0.05] + [1] * n_patches, hspace=0.05, wspace=0.05) + + # Column headers (Shared Titles) + column_titles = ["Raw Image", "Input Patch", "Target Patch"] + [f"Predicted {i+1}" for i in range(n_predictions)] + for j, 
title in enumerate(column_titles): + ax = fig.add_subplot(gs[0, j]) + ax.set_xticks([]) + ax.set_yticks([]) + ax.axis("off") + ax.text(0.5, 0.5, title, ha="center", va="center", fontsize=title_fontsize, fontweight="bold") + + # Iterate through patches and plot each column separately + for i in range(n_patches): + row_offset = i + 1 # Offset by 1 to account for the title row + + # Extract patch coordinates + patch_x, patch_y = patch_coords[i] + patch_size = targets_numpy.shape[-1] # Infer patch size from target shape + + # Compute per-patch vmin/vmax + input_vmin, input_vmax = np.min(inputs_numpy[i]), np.max(inputs_numpy[i]) + target_vmin, target_vmax = np.min(targets_numpy[i]), np.max(targets_numpy[i]) + + # Plot raw image with patch annotation + ax = fig.add_subplot(gs[row_offset, 0]) + plot_single_image(raw_images[i], ax, cmap, raw_vmin, raw_vmax) + rect = Rectangle((patch_x, patch_y), patch_size, patch_size, linewidth=2, edgecolor="r", facecolor="none") + ax.add_patch(rect) + + # Plot input patch + ax = fig.add_subplot(gs[row_offset, 1]) + plot_single_image(inputs_numpy[i], ax, cmap, input_vmin, input_vmax) + + # Plot target patch + ax = fig.add_subplot(gs[row_offset, 2]) + plot_single_image(targets_numpy[i], ax, cmap, target_vmin, target_vmax) + + # Plot prediction patches (if available) with metrics + if predictions_numpy is not None: + for j in range(n_predictions): + ax = fig.add_subplot(gs[row_offset, 3 + j]) + plot_single_image(predictions_numpy[i, j], ax, cmap, target_vmin, target_vmax) + + # Display metric values below prediction + metric_str = "" + if metric_values is not None: + metric_value_row = metric_values.iloc[i, :] + metric_str = "\n".join( + [f"{metric_name}: {metric_val:.2f}" for metric_name, metric_val in metric_value_row.items()] + ) + + ax.set_title(metric_str, fontsize=title_fontsize - 2) + + # Adjust layout and save/show + if save_path: + plt.savefig(save_path) + if show_plot: + plt.show() + else: + plt.close() \ No newline at end of file From 75d5adc023747a2616a8f2efc9d9ba354a8f8621 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Tue, 25 Feb 2025 21:46:34 -0700 Subject: [PATCH 34/89] Fixed bug in patch dataset returning the wrong length of itself --- datasets/PatchDataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/PatchDataset.py b/datasets/PatchDataset.py index 7f5500b..029c5ef 100644 --- a/datasets/PatchDataset.py +++ b/datasets/PatchDataset.py @@ -97,7 +97,7 @@ def __init__( Overridden Iterator functions """ def __len__(self): - return len(self.__patch_coords) + return len(self.__iter_patch_id) def __getitem__(self, _idx: int)->Tuple[torch.Tensor, torch.Tensor]: """ From 071aa8717bddf6c3424161bdb21ac27d9eb1a218 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Wed, 26 Feb 2025 11:28:44 -0700 Subject: [PATCH 35/89] Modified callback to allow the plot frequency during training to be tuned --- callbacks/IntermediatePlot.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py index 8ca8414..755c16c 100644 --- a/callbacks/IntermediatePlot.py +++ b/callbacks/IntermediatePlot.py @@ -20,6 +20,7 @@ def __init__(self, dataset: PatchDataset, plot_n_patches: int=5, plot_metrics: List[nn.Module]=None, + every_n_epochs: int=5, **kwargs): """ Initialize the IntermediatePlot callback. @@ -36,6 +37,8 @@ def __init__(self, :type plot_metrics: List[nn.Module], optional :param kwargs: Additional keyword arguments to be passed to plot_patches. 
:type kwargs: dict
+ :param every_n_epochs: How frequently intermediate plots should be generated, defaults to 5
+ :type every_n_epochs: int
:raises TypeError: If the dataset is not an instance of PatchDataset.
"""
super().__init__(name)
@@ -47,15 +50,30 @@ def __init__(self,
# Additional kwargs passed to plot_patches
self.plot_n_patches = plot_n_patches
self.plot_metrics = plot_metrics
+ self.every_n_epochs = every_n_epochs
self.plot_kwargs = kwargs

def on_epoch_end(self):
"""
- Called at the end of each epoch.
+ Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
+ """
+ if (self.trainer.epoch + 1) % self.every_n_epochs == 0 or self.trainer.epoch + 1 == self.trainer.total_epochs:
+ self._plot()
+
+ def on_train_end(self):
+ """
+ Called at the end of training. Plots if not already done in the last epoch.
+ """
+ if (self.trainer.epoch + 1) % self.every_n_epochs != 0:
+ self._plot()

+ def _plot(self):
+ """
+ Helper method to generate and save plots.
Plot dataset with model predictions on n random images from dataset at the end of each epoch.
+ Called by the on_epoch_end and on_train_end methods.
"""
-
+
original_device = next(self.trainer.model.parameters()).device

plot_patches(
From d7f62b804bb58e4de9280904876dd608bab5cfd9 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Thu, 27 Feb 2025 16:01:34 -0700
Subject: [PATCH 36/89] Fixed bug where the epoch-end plot condition compared the current epoch against a nonexistent total-epoch attribute; the final epoch is already covered by on_train_end
---
callbacks/IntermediatePlot.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py
index 755c16c..fcf34ed 100644
--- a/callbacks/IntermediatePlot.py
+++ b/callbacks/IntermediatePlot.py
@@ -57,7 +57,7 @@ def on_epoch_end(self):
"""
Called at the end of each epoch to plot predictions if the epoch is a multiple of `every_n_epochs`.
"""
- if (self.trainer.epoch + 1) % self.every_n_epochs == 0 or self.trainer.epoch + 1 == self.trainer.total_epochs:
+ if (self.trainer.epoch + 1) % self.every_n_epochs == 0:
self._plot()

def on_train_end(self):
From 4ea24a93bb4e36d27524a59047831e58e4a4c3ec Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Thu, 27 Feb 2025 16:19:19 -0700
Subject: [PATCH 37/89] Update trainer classes to remove the best model attributes that are handled internally by the abstract trainer class when it updates the early termination counter; the overridden model property returns the generator.
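
The "overridden property" mentioned above follows roughly the pattern
sketched below (a sketch only, not part of this commit's diff; it assumes
the abstract trainer snapshots whatever `self.model` returns when it
updates its early stopping state, and uses the `_generator` attribute
shown elsewhere in this series):

    class WGaNTrainer(AbstractTrainer):
        ...
        @property
        def model(self):
            # Expose the generator as this trainer's "model" so that the
            # abstract class's best-model snapshotting tracks the
            # generator weights instead of a separately kept copy.
            return self._generator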
--- trainers/Trainer.py | 3 --- trainers/WGaNTrainer.py | 4 ---- 2 files changed, 7 deletions(-) diff --git a/trainers/Trainer.py b/trainers/Trainer.py index 5ebee41..bd5e63c 100644 --- a/trainers/Trainer.py +++ b/trainers/Trainer.py @@ -42,9 +42,6 @@ def __init__( self._backprop_loss = backprop_loss \ if isinstance(backprop_loss, list) else [backprop_loss] - # Make an initial copy of the model - self.best_model = self.model.state_dict().copy() - """ Overidden methods from the parent abstract class """ diff --git a/trainers/WGaNTrainer.py b/trainers/WGaNTrainer.py index 80482bb..e7e931e 100644 --- a/trainers/WGaNTrainer.py +++ b/trainers/WGaNTrainer.py @@ -71,10 +71,6 @@ def __init__(self, if self._gradient_penalty_fn is not None: self._gradient_penalty_fn.trainer = self - # Make an initial copy of the generator and discriminator models - self.best_generator = self._generator.state_dict().copy() - self.best_discriminator = self._discriminator.state_dict().copy() - # Global step counter and update frequencies self._discriminator_update_freq = discriminator_update_freq self._generator_update_freq = generator_update_freq From 6d362b0d58b9050e97c47e57fd41f78f8a74a930 Mon Sep 17 00:00:00 2001 From: Weishan Li <112203562+wli51@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:22:15 -0700 Subject: [PATCH 38/89] Update callbacks/IntermediatePlot.py for clearer documentation of what the plotter does Co-authored-by: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com> --- callbacks/IntermediatePlot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py index 8ca8414..14bbb30 100644 --- a/callbacks/IntermediatePlot.py +++ b/callbacks/IntermediatePlot.py @@ -10,8 +10,8 @@ class IntermediatePatchPlot(AbstractCallback): """ - Callback to plot model generated outputs alongside ground - truth and input at the end end of each epoch. + Callback to plot model generated outputs, ground + truth, and input stained image patches at the end of each epoch. """ def __init__(self, From 13cb298f794883bd2e13aa37573a0259b01642e0 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 27 Feb 2025 16:57:22 -0700 Subject: [PATCH 39/89] Removed call to super class method that does not do anything --- trainers/Trainer.py | 3 --- trainers/WGaNTrainer.py | 4 +--- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/trainers/Trainer.py b/trainers/Trainer.py index bd5e63c..46d19a5 100644 --- a/trainers/Trainer.py +++ b/trainers/Trainer.py @@ -119,9 +119,6 @@ def train_epoch(self): """ Train the model for one epoch. """ - - super().train_epoch() - self._model.train() losses = defaultdict(list) # Iterate over the train_loader diff --git a/trainers/WGaNTrainer.py b/trainers/WGaNTrainer.py index e7e931e..999cf2d 100644 --- a/trainers/WGaNTrainer.py +++ b/trainers/WGaNTrainer.py @@ -216,9 +216,7 @@ def evaluate_step(self, return loss def train_epoch(self): - - super().train_epoch() - + self._generator.train() self._discriminator.train() From 93b83f2262b266c3ac953b3f71a331d4f78b2159 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 27 Feb 2025 17:17:03 -0700 Subject: [PATCH 40/89] Modify MlflowLogger class so the default behavior is not to set new tracking uri. 
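
With this default, the callback no longer touches the MLflow tracking
configuration unless a URI is passed explicitly. A usage sketch (the
constructor arguments follow the diff below; the import path is an
assumption):

    import mlflow
    from callbacks.MlflowLogger import MlflowLogger

    # Option 1: configure the tracking server once, globally; the callback
    # then logs to whatever server is already active.
    mlflow.set_tracking_uri("http://localhost:5000")
    logger = MlflowLogger(name="mlflow_logger")

    # Option 2: pass a URI so the callback points MLflow at a local store.
    logger = MlflowLogger(name="mlflow_logger", mlflow_uri="mlruns")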
--- callbacks/MlflowLogger.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py index 9295829..ffeff65 100644 --- a/callbacks/MlflowLogger.py +++ b/callbacks/MlflowLogger.py @@ -17,7 +17,7 @@ def __init__(self, name: str, artifact_name: str = 'best_model_weights.pth', - mlflow_uri: Union[pathlib.Path, str] = 'mlruns', + mlflow_uri: Union[pathlib.Path, str] = None, mlflow_experiment_name: str = 'Default', mlflow_start_run_args: dict = {}, mlflow_log_params_args: dict = {}, @@ -30,7 +30,11 @@ def __init__(self, :type name: str :param artifact_name: Name of the artifact file to log, defaults to 'best_model_weights.pth'. :type artifact_name: str, optional - :param mlflow_uri: URI for the MLflow tracking server, defaults to 'mlruns' under current wd. + :param mlflow_uri: URI for the MLflow tracking server, defaults to None. + If a path is specified, the logger class will call set_tracking_uri to that supplied path + thereby initiating a new tracking server. + If None (default), the logger class will not tamper with mlflow server to enable logging to a global server + initialized outside of this class. :type mlflow_uri: pathlib.Path or str, optional :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'. :type mlflow_experiment_name: str, optional @@ -41,11 +45,16 @@ def __init__(self, """ super().__init__(name) + if mlflow_uri is not None: + try: + mlflow.set_tracking_uri(mlflow_uri) + except Exception as e: + raise RuntimeError(f"Error setting MLflow tracking URI: {e}") + try: - mlflow.set_tracking_uri(mlflow_uri) mlflow.set_experiment(mlflow_experiment_name) except Exception as e: - print(f"Error setting MLflow tracking URI: {e}") + raise RuntimeError(f"Error setting MLflow experiment: {e}") self._artifact_name = artifact_name self._mlflow_start_run_args = mlflow_start_run_args From 5380ac727bdaa1f7dea807b72b8b6e82ea24c2ca Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Thu, 27 Feb 2025 17:21:57 -0700 Subject: [PATCH 41/89] Modify MlflowLogger class so the default behavior is not to log any parameter or start_run argument unless explicitly specified --- callbacks/MlflowLogger.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py index ffeff65..7d9b166 100644 --- a/callbacks/MlflowLogger.py +++ b/callbacks/MlflowLogger.py @@ -1,7 +1,7 @@ import os import pathlib import tempfile -from typing import Union +from typing import Union, Dict import mlflow import torch @@ -19,8 +19,8 @@ def __init__(self, artifact_name: str = 'best_model_weights.pth', mlflow_uri: Union[pathlib.Path, str] = None, mlflow_experiment_name: str = 'Default', - mlflow_start_run_args: dict = {}, - mlflow_log_params_args: dict = {}, + mlflow_start_run_args: dict = None, + mlflow_log_params_args: dict = None, ): """ @@ -38,9 +38,9 @@ def __init__(self, :type mlflow_uri: pathlib.Path or str, optional :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'. :type mlflow_experiment_name: str, optional - :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to {}. + :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None. :type mlflow_start_run_args: dict, optional - :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to {}. 
+ :param mlflow_log_params_args: Additional arguments for logging parameters to MLflow, defaults to None. :type mlflow_log_params_args: dict, optional """ super().__init__(name) @@ -66,12 +66,24 @@ def on_train_start(self): Calls mlflow start run and logs params if provided """ - mlflow.start_run( - **self._mlflow_start_run_args - ) - mlflow.log_params( - self._mlflow_log_params_args - ) + + if self._mlflow_start_run_args is None: + pass + elif isinstance(self._mlflow_start_run_args, Dict): + mlflow.start_run( + **self._mlflow_start_run_args + ) + else: + raise TypeError("mlflow_start_run_args must be None or a dictionary.") + + if self._mlflow_log_params_args is None: + pass + elif isinstance(self._mlflow_log_params_args, Dict): + mlflow.log_params( + **self._mlflow_log_params_args + ) + else: + raise TypeError("mlflow_log_params_args must be None or a dictionary.") def on_epoch_end(self): """ From 9a27c2395382577f6931cdfd253bbda7372b95a7 Mon Sep 17 00:00:00 2001 From: Weishan Li <112203562+wli51@users.noreply.github.com> Date: Fri, 28 Feb 2025 09:41:39 -0700 Subject: [PATCH 42/89] Modify comment for improved clarity Co-authored-by: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com> --- evaluation/visualization_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py index cf37aca..f98c1b6 100644 --- a/evaluation/visualization_utils.py +++ b/evaluation/visualization_utils.py @@ -192,7 +192,7 @@ def plot_patches( vmax=vmax ) - ## Compute metrics for single set of target output pairs and add to subplot title + ## Compute metrics for single set of (target, output) pairs and add to subplot title metric_str = "" if _metrics is not None: for _metric in _metrics: From c284f054e087592071ebff009320f160e8371600 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 09:44:01 -0700 Subject: [PATCH 43/89] Removed TODO item as it is not going to be useful --- trainers/WGaNTrainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trainers/WGaNTrainer.py b/trainers/WGaNTrainer.py index 999cf2d..9119a09 100644 --- a/trainers/WGaNTrainer.py +++ b/trainers/WGaNTrainer.py @@ -197,7 +197,8 @@ def evaluate_step(self, # Compute losses discriminator_loss = self._discriminator_loss_fn(discriminator_real_score, discriminator_fake_score) - ## TODO: decide if gradient loss computation during eval mode is meaningful + ## Declare an empty tensor for the gradient penalty loss as + # it is not useful during evaluation gp_loss = torch.tensor(0.0, device=self.device) generator_loss = self._generator_loss_fn(discriminator_fake_score, generated_images, real_images, self.epoch) From a6f40f64644819cd1931a41157ebede02be0aa5e Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 09:53:45 -0700 Subject: [PATCH 44/89] Updated variable names and docstring for DiscriminatorLoss.py for improved clarity in what the wGAN discriminator loss does --- losses/DiscriminatorLoss.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/losses/DiscriminatorLoss.py b/losses/DiscriminatorLoss.py index 43a1f96..a9c5e74 100644 --- a/losses/DiscriminatorLoss.py +++ b/losses/DiscriminatorLoss.py @@ -5,30 +5,32 @@ class DiscriminatorLoss(AbstractLoss): """ This class implements the loss function for the discriminator in a Generative Adversarial Network (GAN). 
- The discriminator loss measures how well the discriminator is able to distinguish between real (ground truth) - images and fake (generated) images produced by the generator. + The discriminator loss measures how well the discriminator is able to distinguish between real (ground expected_truth) + images and fake (expected_generated) images produced by the generator. """ def __init__(self, _metric_name): super().__init__(_metric_name) - def forward(self, truth, generated): + def forward(self, expected_truth, expected_generated): """ - Computes the GaN discriminator loss given ground truth image and generated image + Computes the Wasserstein Discriminator Loss loss given ground expected_truth image and expected_generated image - :param truth: The tensor containing the ground truth image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type truth: torch.Tensor - :param generated: The tensor containing model generated image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type generated: torch.Tensor + :param expected_truth: The tensor containing the ground expected_truth + probability score predicted by the discriminator over a batch of real images (input target pair), + should be of shape [batch_size, 1]. + :type expected_truth: torch.Tensor + :param expected_generated: The tensor containing model expected_generated + probability score predicted by the discriminator over a batch of generated images (input generated pair), + should be of shape [batch_size, 1]. + :type expected_generated: torch.Tensor :return: The computed metric as a float value. :rtype: float """ # If the probability output is more than Scalar, take the mean of the output - if truth.dim() >= 3: - truth = torch.mean(truth, tuple(range(2, truth.dim()))) - if generated.dim() >= 3: - generated = torch.mean(generated, tuple(range(2, generated.dim()))) + if expected_truth.dim() >= 3: + expected_truth = torch.mean(expected_truth, tuple(range(2, expected_truth.dim()))) + if expected_generated.dim() >= 3: + expected_generated = torch.mean(expected_generated, tuple(range(2, expected_generated.dim()))) - return (generated - truth).mean() \ No newline at end of file + return (expected_generated - expected_truth).mean() \ No newline at end of file From 778144b5ce1e3d7d083f80b07eedb61b04de69bc Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:16:26 -0700 Subject: [PATCH 45/89] Modified MlflowLogger so that the experiment name is also not configured by default --- callbacks/MlflowLogger.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py index 7d9b166..3697c48 100644 --- a/callbacks/MlflowLogger.py +++ b/callbacks/MlflowLogger.py @@ -1,7 +1,7 @@ import os import pathlib import tempfile -from typing import Union, Dict +from typing import Union, Dict, Optional import mlflow import torch @@ -18,7 +18,7 @@ def __init__(self, name: str, artifact_name: str = 'best_model_weights.pth', mlflow_uri: Union[pathlib.Path, str] = None, - mlflow_experiment_name: str = 'Default', + mlflow_experiment_name: Optional[str] = None, mlflow_start_run_args: dict = None, mlflow_log_params_args: dict = None, @@ -36,7 +36,9 @@ def __init__(self, If None (default), the logger class will not tamper with mlflow server to enable logging to a global server initialized outside of this class. 
:type mlflow_uri: pathlib.Path or str, optional - :param mlflow_experiment_name: Name of the MLflow experiment, defaults to 'Default'. + :param mlflow_experiment_name: Name of the MLflow experiment, defaults to None, which will not call the + set_experiment method of mlflow and will use whichever experiment name that is globally configured. If a + name is provided, the logger class will call set_experiment to that supplied name. :type mlflow_experiment_name: str, optional :param mlflow_start_run_args: Additional arguments for starting an MLflow run, defaults to None. :type mlflow_start_run_args: dict, optional @@ -51,10 +53,11 @@ def __init__(self, except Exception as e: raise RuntimeError(f"Error setting MLflow tracking URI: {e}") - try: - mlflow.set_experiment(mlflow_experiment_name) - except Exception as e: - raise RuntimeError(f"Error setting MLflow experiment: {e}") + if mlflow_experiment_name is not None: + try: + mlflow.set_experiment(mlflow_experiment_name) + except Exception as e: + raise RuntimeError(f"Error setting MLflow experiment: {e}") self._artifact_name = artifact_name self._mlflow_start_run_args = mlflow_start_run_args From b270630dcb68741106eac417be87172c37609547 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:20:03 -0700 Subject: [PATCH 46/89] Removed outdated TODO --- datasets/CachedDataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py index 672f440..62fe28e 100644 --- a/datasets/CachedDataset.py +++ b/datasets/CachedDataset.py @@ -98,7 +98,6 @@ def input_names(self): Get the input names from the dataset object """ if self._current_idx is not None: - ## TODO: need to think over if this is at all necessary if self._current_idx in self.__cache_input_names: return self.__cache_input_names[self._current_idx] else: From 2eefe5c7a36aadfd244372cde35bc9fa65c062b4 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:21:05 -0700 Subject: [PATCH 47/89] Removed commented out code that is no longer needed --- trainers/Trainer.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/trainers/Trainer.py b/trainers/Trainer.py index 46d19a5..ee42103 100644 --- a/trainers/Trainer.py +++ b/trainers/Trainer.py @@ -148,17 +148,4 @@ def evaluate_epoch(self): # reduce loss return { key: sum(value) / len(value) for key, value in losses.items() - } - - # @property - # def log(self): - # """ - # Returns the training and validation losses and metrics. - # """ - # log ={ - # **{'epoch': list(range(1, self.epoch + 1))}, - # **self._train_metrics, - # **{f'val_{key}': val for key, val in self._val_metrics.items()} - # } - - # return log \ No newline at end of file + } \ No newline at end of file From 756429c946715f889fbc525693ecd4746fb30237 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:26:22 -0700 Subject: [PATCH 48/89] Modified function name to make metrics aggregation more clear. 
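
After this rename, the intended per-epoch lifecycle reads as sketched
below (`PSNR` stands in for a hypothetical concrete AbstractMetrics
subclass; the method names come from the diff that follows):

    psnr = PSNR("psnr")  # hypothetical AbstractMetrics subclass
    # ... batch-wise metric values accumulate during the epoch ...
    train_psnr, val_psnr = psnr.compute(aggregation="mean")  # delegates to aggregate_metrics
    psnr.reset()  # clear the stored batch values before the next epoch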
--- metrics/AbstractMetrics.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py index 54aeada..4ee52e6 100644 --- a/metrics/AbstractMetrics.py +++ b/metrics/AbstractMetrics.py @@ -58,8 +58,15 @@ def reset(self): self.__train_metric_values = [] self.__val_metric_values = [] - def compute(self, aggregation: Optional[str] = 'mean'): - """Computes the final metric value.""" + def compute(self, **kwargs): + """ + Calls the aggregate_metrics method to compute the metric value for now + In future may be used for more complex computations + """ + return self.aggregate_metrics(**kwargs) + + def aggregate_metrics(self, aggregation: Optional[str] = 'mean'): + """Aggregates the metric value over batches""" if aggregation == 'mean': return \ From 30f6b17debcc6b28979cf83a9c09454b5abc0b85 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:28:09 -0700 Subject: [PATCH 49/89] Update docstring --- metrics/AbstractMetrics.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py index 4ee52e6..9b8e0c9 100644 --- a/metrics/AbstractMetrics.py +++ b/metrics/AbstractMetrics.py @@ -66,7 +66,14 @@ def compute(self, **kwargs): return self.aggregate_metrics(**kwargs) def aggregate_metrics(self, aggregation: Optional[str] = 'mean'): - """Aggregates the metric value over batches""" + """ + Aggregates the metric value over batches + + :param aggregation: The aggregation method to use, by default 'mean' + :type aggregation: Optional[str] + :return: The aggregated metric value for training and validation + :rtype: Tuple[torch.tensor, torch.tensor] + """ if aggregation == 'mean': return \ From 7ac89b80d9d0718bd8207a390479d3aef30046c4 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 10:29:01 -0700 Subject: [PATCH 50/89] Remove description of where the code is adapted from for consistency --- metrics/AbstractMetrics.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/metrics/AbstractMetrics.py b/metrics/AbstractMetrics.py index 9b8e0c9..24fc571 100644 --- a/metrics/AbstractMetrics.py +++ b/metrics/AbstractMetrics.py @@ -4,9 +4,6 @@ import torch import torch.nn as nn -""" -Adapted from https://github.com/WayScience/nuclear_speckles_analysis -""" class AbstractMetrics(nn.Module, ABC): """Abstract class for metrics""" @@ -68,7 +65,7 @@ def compute(self, **kwargs): def aggregate_metrics(self, aggregation: Optional[str] = 'mean'): """ Aggregates the metric value over batches - + :param aggregation: The aggregation method to use, by default 'mean' :type aggregation: Optional[str] :return: The aggregated metric value for training and validation From ed7d9404168a04067076b87703fe9d8cbad87835 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 22:20:30 -0700 Subject: [PATCH 51/89] fixed bug, log_params should not take keyword arguments, should just be a dict --- callbacks/MlflowLogger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/MlflowLogger.py b/callbacks/MlflowLogger.py index 3697c48..0996304 100644 --- a/callbacks/MlflowLogger.py +++ b/callbacks/MlflowLogger.py @@ -83,7 +83,7 @@ def on_train_start(self): pass elif isinstance(self._mlflow_log_params_args, Dict): mlflow.log_params( - **self._mlflow_log_params_args + self._mlflow_log_params_args ) else: raise TypeError("mlflow_log_params_args must be None or a dictionary.") From cec0df69ad30a66d7ed651909d086f1ac02d2fe1 Mon Sep 
17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 22:56:03 -0700 Subject: [PATCH 52/89] Changed comment to one line for cleanness --- datasets/CachedDataset.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py index 62fe28e..e743dcf 100644 --- a/datasets/CachedDataset.py +++ b/datasets/CachedDataset.py @@ -46,9 +46,7 @@ def __init__( if prefill_cache: self.cache() - """ - Overriden methods for Dataset class - """ + """Overriden methods for Dataset class""" def __len__(self): """ Return the length of the dataset @@ -73,9 +71,8 @@ def __getitem__(self, _idx: int): self._update_cache(_idx) return self.__cache[_idx] - """ - Setters - """ + """Setters""" + def set_cache_size(self, cache_size: int): """ Set the cache size. Does not automatically repopulate the cache but @@ -89,9 +86,7 @@ def set_cache_size(self, cache_size: int): while len(self.__cache) > self.__cache_size: self._pop_cache() - """ - Properties to remain accessible - """ + """Properties to remain accessible""" @property def input_names(self): """ @@ -162,9 +157,7 @@ def dataset(self): """ return self.__dataset - """ - Cache method - """ + """Cache method""" def cache(self): """ Clears the current cache and re-populate cache with data from the dataset object From 7f20d7abd993784974715cfeab3e29b6f986de94 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Fri, 28 Feb 2025 22:58:08 -0700 Subject: [PATCH 53/89] Renamed cache related functions for clarity --- datasets/CachedDataset.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/datasets/CachedDataset.py b/datasets/CachedDataset.py index e743dcf..04bc237 100644 --- a/datasets/CachedDataset.py +++ b/datasets/CachedDataset.py @@ -44,7 +44,7 @@ def __init__( self._current_idx = None if prefill_cache: - self.cache() + self.populate_cache() """Overriden methods for Dataset class""" def __len__(self): @@ -68,7 +68,7 @@ def __getitem__(self, _idx: int): return self.__cache[_idx] else: # cache miss, load from parent class method dynamically - self._update_cache(_idx) + self._push_cache(_idx) return self.__cache[_idx] """Setters""" @@ -158,22 +158,20 @@ def dataset(self): return self.__dataset """Cache method""" - def cache(self): + def populate_cache(self): """ - Clears the current cache and re-populate cache with data from the dataset object - Iteratively calls the update cache method on a sequence of indices to fill the cache + Populates/clears the current cache and re-populate the cache with data from the dataset object + Iteratively calls the _push_cache method on a sequence of indices """ self._clear_cache() for _idx in range(min(self.__cache_size, len(self.__dataset))): - self._update_cache(_idx) + self._push_cache(_idx) - """ - Internal helper methods - """ + """Internal helper methods""" - def _update_cache(self, _idx: int): + def _push_cache(self, _idx: int): """ - Update the cache with data from the dataset object. + Update the cache with a single item retrieved from the dataset object. 
Calls the update cache metadata method as well to sync data and metadata Pops the cache if the cache size is exceeded on a first in, first out basis @@ -184,7 +182,7 @@ def _update_cache(self, _idx: int): self.__cache[_idx] = self.__dataset[_idx] if len(self.__cache) >= self.__cache_size: self._pop_cache() - self._update_cache_metadata(_idx) + self._push_cache_metadata(_idx) def _pop_cache(self): """ @@ -192,10 +190,10 @@ def _pop_cache(self): """ self.__cache.popitem(last=False) - def _update_cache_metadata(self, _idx: int): + def _push_cache_metadata(self, _idx: int): """ Update the cache metadata with data from the dataset object - Meant to be called by _update_cache method + Meant to be called by _push_cache method :param _idx: Index of the data to cache :type _idx: int From 825b745e3b75eb6565897fa1c606c574f0bb3321 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 01:20:10 -0700 Subject: [PATCH 54/89] Added comment to better describe the wGAN DiscriminatorLoss --- losses/DiscriminatorLoss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/losses/DiscriminatorLoss.py b/losses/DiscriminatorLoss.py index a9c5e74..caa057f 100644 --- a/losses/DiscriminatorLoss.py +++ b/losses/DiscriminatorLoss.py @@ -4,7 +4,7 @@ class DiscriminatorLoss(AbstractLoss): """ - This class implements the loss function for the discriminator in a Generative Adversarial Network (GAN). + This class implements the loss function for the discriminator in a Wasserstein Generative Adversarial Network (wGAN). The discriminator loss measures how well the discriminator is able to distinguish between real (ground expected_truth) images and fake (expected_generated) images produced by the generator. """ @@ -13,7 +13,7 @@ def __init__(self, _metric_name): def forward(self, expected_truth, expected_generated): """ - Computes the Wasserstein Discriminator Loss loss given ground expected_truth image and expected_generated image + Computes the Wasserstein Discriminator Loss given probability scores expected_truth and expected_generated from the discriminator :param expected_truth: The tensor containing the ground expected_truth probability score predicted by the discriminator over a batch of real images (input target pair), @@ -28,6 +28,8 @@ def forward(self, expected_truth, expected_generated): """ # If the probability output is more than Scalar, take the mean of the output + # For compatibility with both a Discriminator class that would output a scalar probability (currently implemented) + # and a Discriminator class that would output a 2d matrix of probabilities (currently not implemented) if expected_truth.dim() >= 3: expected_truth = torch.mean(expected_truth, tuple(range(2, expected_truth.dim()))) if expected_generated.dim() >= 3: From ed47a060dfe4e56e325bbf4ea46a0716df312d6d Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 01:25:43 -0700 Subject: [PATCH 55/89] Modified docstring for better clarity --- trainers/Trainer.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/trainers/Trainer.py b/trainers/Trainer.py index ee42103..5c6411c 100644 --- a/trainers/Trainer.py +++ b/trainers/Trainer.py @@ -8,20 +8,13 @@ class Trainer(AbstractTrainer): """ - Trainer class for single img2img convolutional models backpropagating on single loss items + Trainer class for generator while backpropagating on single or multiple loss functions. 
""" def __init__( self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, backprop_loss: Union[torch.nn.Module, List[torch.nn.Module]], - # rest of the arguments are passed to and handled by the parent class - # - dataset - # - batch_size - # - epochs - # - patience - # - callbacks - # - metrics **kwargs ): """ From 63829d7549ea83f2105606ecc69a10eeacafbeb9 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 01:32:19 -0700 Subject: [PATCH 56/89] Changes the default behavior of early termination to be disabled when early termination metric is not supplied --- trainers/AbstractTrainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/trainers/AbstractTrainer.py b/trainers/AbstractTrainer.py index 83a590d..5f152aa 100644 --- a/trainers/AbstractTrainer.py +++ b/trainers/AbstractTrainer.py @@ -62,6 +62,7 @@ def __init__( self._best_model = None self._best_loss = float("inf") + self._early_termination = None # switch for early termination self._early_stop_counter = 0 self._early_termination_metric = early_termination_metric @@ -210,9 +211,10 @@ def train(self): # Update early stopping if self._early_termination_metric is None: - # use the first loss function value as early stopping metric - early_term_metric = next(iter(val_loss.values())) + # Do not perform early stopping when no termination metric is specified + self._early_termination = False else: + self._early_termination = True # First look for the metric in validation loss if self._early_termination_metric in list(val_loss.keys()): early_term_metric = val_loss[self._early_termination_metric] @@ -225,7 +227,7 @@ def train(self): self.update_early_stop(early_term_metric) # Check if early stopping is needed - if self.early_stop_counter >= self.patience: + if self._early_termination and self.early_stop_counter >= self.patience: print(f"Early termination at epoch {epoch + 1} with best validation metric {self._best_loss}") break From 18dec2c21571598016db1b0b636eab600cc02e5b Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 01:52:32 -0700 Subject: [PATCH 57/89] Added parameter to determine if batch normalization will be used in discriminators, defaults to don't batch normalize --- models/discriminator.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/models/discriminator.py b/models/discriminator.py index 6f1b8e7..fa93b22 100644 --- a/models/discriminator.py +++ b/models/discriminator.py @@ -13,7 +13,8 @@ def __init__( n_in_channels: int, n_in_filters: int, _conv_depth: int=4, - _leaky_relu_alpha: float=0.2 + _leaky_relu_alpha: float=0.2, + _batch_norm: bool=False ): """ A patch-based discriminator for pix2pix GANs that outputs a feature map @@ -29,6 +30,8 @@ def __init__( :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. 
Must be between 0 and 1 :type _leaky_relu_alpha: float + :param _batch_norm: (bool) whether to use batch normalization, defaults to False + :type _batch_norm: bool """ super().__init__() @@ -55,7 +58,10 @@ def __init__( conv_layers.append( nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=1, padding=1) ) - conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + + if _batch_norm: + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) n_channels *= 2 self._conv_layers = nn.Sequential(*conv_layers) @@ -80,6 +86,7 @@ def __init__( n_in_filters: int, _conv_depth: int=4, _leaky_relu_alpha: float=0.2, + _batch_norm: bool=False, _pool_before_fc: bool=False ): """ @@ -96,6 +103,8 @@ def __init__( :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. Must be between 0 and 1 :type _leaky_relu_alpha: float + :param _batch_norm: (bool) whether to use batch normalization, defaults to False + :type _batch_norm: bool :param _pool_before_fc: (bool) whether to pool before the fully connected network Pooling before the fully connected network can reduce the number of parameters :type _pool_before_fc: bool @@ -116,7 +125,10 @@ def __init__( conv_layers.append( nn.Conv2d(n_channels, n_channels * 2, kernel_size=4, stride=2, padding=1) ) - conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + + if _batch_norm: + conv_layers.append(nn.BatchNorm2d(n_channels * 2)) + conv_layers.append(nn.LeakyReLU(_leaky_relu_alpha, inplace=True)) n_channels *= 2 From c0b00b66e6c1a89c3518238216292ca0807758f3 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 02:08:05 -0700 Subject: [PATCH 58/89] Do not retain graph when computing the gradient penalty loss for potential memory saving --- losses/GradientPenaltyLoss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/losses/GradientPenaltyLoss.py b/losses/GradientPenaltyLoss.py index 7442db9..e804a13 100644 --- a/losses/GradientPenaltyLoss.py +++ b/losses/GradientPenaltyLoss.py @@ -38,7 +38,7 @@ def forward(self, truth, generated): inputs=interpolated, grad_outputs=torch.ones_like(prob_interpolated), create_graph=True, - retain_graph=True, + retain_graph=False, )[0] gradients = gradients.view(batch_size, -1) From 90660ba26a47878855e976f942db23eb529a9de2 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Sat, 1 Mar 2025 02:10:21 -0700 Subject: [PATCH 59/89] Removed unecessary if statement --- losses/GeneratorLoss.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/losses/GeneratorLoss.py b/losses/GeneratorLoss.py index 04e2fe9..51695ae 100644 --- a/losses/GeneratorLoss.py +++ b/losses/GeneratorLoss.py @@ -28,29 +28,31 @@ def forward(self, discriminator_probs: torch.tensor, truth: torch.tensor, generated: torch.tensor, - epoch: Optional[int] = None + epoch: int = 0 ): """ Computes the loss for the GaN generator. :param discriminator_probs: The probabilities of the discriminator for the fake images being real. + :type discriminator_probs: torch.tensor :param truth: The tensor containing the ground truth image, should be of shape [batch_size, channel_number, img_height, img_width]. :type truth: torch.Tensor :param generated: The tensor containing model generated image, should be of shape [batch_size, channel_number, img_height, img_width]. + :type generated: torch.Tensor :param epoch: The current epoch number. 
- Used for a smoothing weight for the adversarial loss component
+ Used for a smoothing weight for the adversarial loss component.
+ Defaults to 0.
:type epoch: int
- :type generated: torch.Tensor
:return: The computed metric as a float value.
:rtype: float
"""
# Adversarial loss
adversarial_loss = -torch.mean(discriminator_probs)
- if epoch is not None:
- adversarial_loss = 0.01 * adversarial_loss/(epoch + 1)
+
+ adversarial_loss = 0.01 * adversarial_loss/(epoch + 1)
image_loss = self._reconstruction_loss(generated, truth)
From 80be3ddae134a8152bb86afddf77c49c27c31809 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 1 Mar 2025 02:19:34 -0700
Subject: [PATCH 60/89] Added reconstruction loss weight parameter that defaults to 1
---
losses/GeneratorLoss.py | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/losses/GeneratorLoss.py b/losses/GeneratorLoss.py
index 51695ae..9ec2bd5 100644
--- a/losses/GeneratorLoss.py
+++ b/losses/GeneratorLoss.py
@@ -11,18 +11,25 @@ class GeneratorLoss(AbstractLoss):
Combines an adversarial loss component with an image reconstruction loss.
"""
def __init__(self,
- _metric_name: str,
- reconstruction_loss: Optional[torch.tensor] = L1Loss()
+ _metric_name: str,
+ reconstruction_loss: Optional[torch.tensor] = L1Loss(),
+ reconstruction_weight: float = 1.0
):
"""
:param reconstruction_loss: The image reconstruction loss, defaults to L1Loss(reduce=False)
:type reconstruction_loss: torch.tensor
+ :param reconstruction_weight: The weight for the image reconstruction loss, defaults to 1.0
+ :type reconstruction_weight: float
"""
super().__init__(_metric_name)

self._reconstruction_loss = reconstruction_loss
+ if isinstance(reconstruction_weight, (int, float)):
+ self._reconstruction_weight = float(reconstruction_weight)
+ else:
+ raise ValueError("reconstruction_weight must be a numeric value")

def forward(self,
discriminator_probs: torch.tensor,
@@ -51,9 +58,8 @@ def forward(self,
# Adversarial loss
adversarial_loss = -torch.mean(discriminator_probs)
- adversarial_loss = 0.01 * adversarial_loss/(epoch + 1)

image_loss = self._reconstruction_loss(generated, truth)

- return adversarial_loss + image_loss.mean() \ No newline at end of file
+ return adversarial_loss + self._reconstruction_weight * image_loss.mean() \ No newline at end of file
From d7eb27e471ce8ced1be1ba511fd531bc93ea26f9 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 1 Mar 2025 10:35:25 -0700
Subject: [PATCH 61/89] Added raw_input and raw_target properties to PatchDataset to centralize access to raw images
---
datasets/PatchDataset.py | 36 ++++++++++++++++++++++++++++++++++++
1 file changed, 36 insertions(+)

diff --git a/datasets/PatchDataset.py b/datasets/PatchDataset.py
index 029c5ef..08cecae 100644
--- a/datasets/PatchDataset.py
+++ b/datasets/PatchDataset.py
@@ -165,6 +165,42 @@ def all_patch_coords(self):
def patch_coords(self):
return self.__current_patch_coords

+ @property
+ def raw_input(self):
+ """
+ Returns the stack of raw input channel images that the current patch is cropped
+ from. Relies on the parent class _ImageDataset__input_cache.
+ Returns None when the cache is empty. Raises an error if the input channel keys are not set.
+
+ :return: Stacked raw input channel images
+ :rtype: np.ndarray
+ """
+
+ if self._input_channel_keys is None:
+ raise ValueError("Input channel keys not set")
+
+ return np.stack(
+ [self._ImageDataset__input_cache[key] for key in self._input_channel_keys],
+ axis=0) if self._input_channel_keys is not None else None
+
+ @property
+ def raw_target(self):
+ """
+ Returns the stack of raw target channel images that the current patch is cropped
+ from. Relies on the parent class _ImageDataset__target_cache.
+ Returns None when the cache is empty. Raises an error if the target channel keys are not set.
+
+ :return: Stacked raw target channel images
+ :rtype: np.ndarray
+ """
+
+ if self._target_channel_keys is None:
+ raise ValueError("Target channel keys not set")
+
+ return np.stack(
+ [self._ImageDataset__target_cache[key] for key in self._target_channel_keys],
+ axis=0) if self._target_channel_keys is not None else None
+
"""
Internal Helper functions
"""
From 5470fd44fe92e82609da52b3b91fb1fa04936595 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 1 Mar 2025 23:34:54 -0700
Subject: [PATCH 62/89] Fixed bug where higher order derivatives could not be computed because the graph was not retained
---
losses/GradientPenaltyLoss.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/losses/GradientPenaltyLoss.py b/losses/GradientPenaltyLoss.py
index e804a13..7442db9 100644
--- a/losses/GradientPenaltyLoss.py
+++ b/losses/GradientPenaltyLoss.py
@@ -38,7 +38,7 @@ def forward(self, truth, generated):
inputs=interpolated,
grad_outputs=torch.ones_like(prob_interpolated),
create_graph=True,
- retain_graph=False,
+ retain_graph=True,
)[0]

gradients = gradients.view(batch_size, -1)
From 65556a35fecb9679022d4fb29f333ec1b1ebdfe0 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Sat, 1 Mar 2025 23:36:21 -0700
Subject: [PATCH 63/89] Fixed early termination enable/disable logic to ensure that when no early termination metric is supplied the trainer runs for the specified number of epochs.
---
trainers/AbstractTrainer.py | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/trainers/AbstractTrainer.py b/trainers/AbstractTrainer.py
index 5f152aa..6173062 100644
--- a/trainers/AbstractTrainer.py
+++ b/trainers/AbstractTrainer.py
@@ -44,8 +44,8 @@ def __init__(
:param device: (optional) The device to be used for training.
:type device: torch.device
:param early_termination_metric: (optional) The metric to be tracked and used to update early
- termination count on the validation dataset. If not configured, will be using the value
- computed by the first validation loss function
+ termination count on the validation dataset. If None, early termination is disabled and the
+ training will run for the specified number of epochs.
:type early_termination_metric: str
"""

@@ -62,9 +62,9 @@ def __init__(
self._best_model = None
self._best_loss = float("inf")
- self._early_termination = None # switch for early termination
self._early_stop_counter = 0
self._early_termination_metric = early_termination_metric
+ self._early_termination = early_termination_metric is not None

# Customize data splits
self._train_ratio = kwargs.get("train", 0.7)
@@ -212,9 +212,8 @@ def train(self):
# Update early stopping
if self._early_termination_metric is None:
# Do not perform early stopping when no termination metric is specified
- self._early_termination = False
+ early_term_metric = None
else:
- self._early_termination = True
# First look for the metric in validation loss
if self._early_termination_metric in list(val_loss.keys()):
early_term_metric = val_loss[self._early_termination_metric]
@@ -234,7 +233,7 @@ def train(self):
for callback in self.callbacks:
callback.on_train_end()

- def update_early_stop(self, val_loss: torch.Tensor):
+ def update_early_stop(self, val_loss: Optional[torch.Tensor]):
"""
Method to update the early stopping criterion

:param val_loss: The validation loss
:type val_loss: torch.Tensor
"""

+ # When early termination is disabled, snapshot the current model and skip the comparison below
+ if not self._early_termination and val_loss is None:
+ self.best_model = self.model.state_dict().copy()
+ return

if val_loss < self.best_loss:
self.best_loss = val_loss
self.early_stop_counter = 0
From 975143784efc85751d745a22b2e4af87d4a58d3e Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:01:32 -0700
Subject: [PATCH 64/89] Modified predict_image function so it returns the target tensor along with the prediction to streamline evaluation.
---
evaluation/predict_utils.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py
index 8e24fa5..eadc15f 100644
--- a/evaluation/predict_utils.py
+++ b/evaluation/predict_utils.py
@@ -32,8 +32,8 @@ def predict_image(
:param indices: Optional list of dataset indices to subset the dataset before inference.
:type indices: Optional[List[int]], optional

- :return: A stacked tensor of model predictions with shape (N, C, H, W), where N is the dataset size or subset size.
- :rtype: torch.Tensor
+ :return: Tuple of stacked target and prediction tensors.
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
"""
# Subset the dataset if indices are provided
if indices is not None:
@@ -45,20 +45,20 @@
model.to(device)
model.eval()

- predictions = [] # List to store predictions
+ predictions, targets = [], []

with torch.no_grad():
- for inputs, _ in dataloader: # Unpacking (input_tensor, target_tensor)
+ for inputs, target in dataloader: # Unpacking (input_tensor, target_tensor)
inputs = inputs.to(device) # Move input data to the specified device

# Forward pass
outputs = model(inputs)
-
- # Store predictions
+
+ # Collect both target and prediction tensors for downstream metric computation
+ targets.append(target.cpu())
predictions.append(outputs.cpu()) # Move to CPU for stacking

- # Stack all predictions into a single tensor
- return torch.cat(predictions, dim=0) if predictions else torch.empty(0)
+ return torch.cat(targets, dim=0), torch.cat(predictions, dim=0)
From 1b5d7ec0af9f3bdaf797345f90b8913fa46cdb6d Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:02:40 -0700
Subject: [PATCH 65/89] Modified evaluate_per_image_metric function so metrics can be computed on a subset of the target-prediction pairs.
---
evaluation/evaluation_utils.py | 35 +++++++++------------------------
1 file changed, 9 insertions(+), 26 deletions(-)

diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py
index accb77e..e9d5c48 100644
--- a/evaluation/evaluation_utils.py
+++ b/evaluation/evaluation_utils.py
@@ -1,37 +1,15 @@
-from collections import defaultdict
-from typing import List, Dict, Callable, Union
+from typing import List, Optional

import pandas as pd
import torch
from torch.nn import Module
from torch.utils.data import DataLoader

-def evaluate_metrics(
- _model: torch.nn.Module,
- _dataset: torch.utils.data.Dataset,
- _metrics: List[Union[Callable, torch.nn.Module]],
- _device:str='cpu'
-):
- metrics = defaultdict(list)
- _model.to(_device)
- _model.eval()
-
- data_loader = DataLoader(_dataset, batch_size=1, shuffle=False)
-
- with torch.no_grad():
- for input, target in data_loader:
- input = input.to(_device)
- target = target.to(_device)
- output = _model(input)
- for _metric in _metrics:
- metrics[_metric.__class__.__name__].append(_metric(output, target).item())
-
- return pd.DataFrame(metrics)
-
def evaluate_per_image_metric(
predictions: torch.Tensor,
targets: torch.Tensor,
- metrics: List[Module]
+ metrics: List[Module],
+ indices: Optional[List[int]] = None
) -> pd.DataFrame:
"""
Computes a set of metrics on a per-image basis and returns the results as a pandas DataFrame.

:param predictions: Tensor of shape (N, C, H, W) containing model predictions.
:type predictions: torch.Tensor
:param targets: Tensor of shape (N, C, H, W) containing ground truth images.
:type targets: torch.Tensor
:param metrics: List of metric functions to evaluate.
:type metrics: List[torch.nn.Module]
+ :param indices: Optional list of indices selecting which prediction-target pairs to evaluate. If None, all images are evaluated.
+ :type indices: Optional[List[int]], optional

:return: A DataFrame where each row corresponds to an image and each column corresponds to a metric.
:rtype: pd.DataFrame
"""

results = []

- for i in range(predictions.shape[0]): # Iterate over images
+ if indices is None:
+ indices = range(predictions.shape[0])
+
+ for i in indices: # Iterate over images/subset
pred, target = predictions[i].unsqueeze(0), targets[i].unsqueeze(0) # Keep batch dimension
metric_scores = {metric.__class__.__name__: metric.forward(target, pred).item() for metric in metrics}
results.append(metric_scores)
From a05211367b930288394a22828600b5d5b20e8ca5 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:05:19 -0700
Subject: [PATCH 66/89] Modified plot_patches function for compatibility with updated predict_image
---
evaluation/plot_utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/evaluation/plot_utils.py b/evaluation/plot_utils.py
index a1bc79f..c001078 100644
--- a/evaluation/plot_utils.py
+++ b/evaluation/plot_utils.py
@@ -202,7 +202,7 @@ def plot_patches(
targets = torch.stack(targets)

# Run model predictions (if provided)
- predictions = predict_image(dataset, model, device=device, indices=patch_index) if model else None
+ _, predictions = predict_image(dataset, model, device=device, indices=patch_index) if model else (None, None)
From 2ef624f797fadd555a447d9c702c3c60cbc6f9da Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:05:30 -0700
Subject: [PATCH 67/89] Update return type hint
---
evaluation/predict_utils.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py
index eadc15f..17ad487 100644
--- a/evaluation/predict_utils.py
+++ b/evaluation/predict_utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Union, Callable
+from typing import Optional, List, Tuple, Callable

import torch
import numpy as np
@@ -12,7 +12,7 @@ def predict_image(
device: str = "cpu",
num_workers: int = 0,
indices: Optional[List[int]] = None
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runs a model on a dataset, performing a forward pass on all (or a subset of) input images
in evaluation mode and returning a stacked tensor of predictions.
From 2e1532ebfe61dc7821c4674627bed649fcb67bdb Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:09:18 -0700
Subject: [PATCH 68/89] Added new functions to visualization_utils.py for visualizing selected input and target patches/images from a dataset alongside model predictions. One of the new plot functions, plot_predictions_grid_from_eval, operates downstream of existing inference/evaluation results to avoid redundant forward passes. A second plot function, plot_predictions_grid_from_model, internally performs inference and evaluation on a trained model and dataset, enabling visualization without requiring inference and evaluation beforehand.
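
A usage sketch for the model-driven variant (assuming a trained img2img
`model`, a dataset from this repo, and metric modules invoked as
metric(target, prediction); the import path is an assumption):

    import torch
    from evaluation.visualization_utils import plot_predictions_grid_from_model

    plot_predictions_grid_from_model(
        model=model,
        dataset=val_dataset,
        indices=[0, 5, 10],           # which samples to visualize
        metrics=[torch.nn.L1Loss()],  # per-image values shown in the prediction titles
        device="cpu",
        save_path="prediction_grid.png",
    )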
--- evaluation/visualization_utils.py | 162 +++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 1 deletion(-) diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py index f98c1b6..091f3b3 100644 --- a/evaluation/visualization_utils.py +++ b/evaluation/visualization_utils.py @@ -3,6 +3,7 @@ import random import numpy as np +import pandas as pd import torch from torch.utils.data import Dataset import matplotlib.pyplot as plt @@ -11,6 +12,10 @@ from albumentations import ImageOnlyTransform from albumentations.core.composition import Compose +from ..datasets.PatchDataset import PatchDataset +from ..evaluation.predict_utils import predict_image, process_tensor_image +from ..evaluation.evaluation_utils import evaluate_per_image_metric + def invert_transforms( numpy_img: np.ndarray, transforms: Union[ImageOnlyTransform, Compose] = None @@ -213,4 +218,159 @@ def plot_patches( if save_path is not None: plt.savefig(save_path) - plt.close() \ No newline at end of file + plt.close() + +def _plot_predictions_grid( + inputs: Union[np.ndarray, torch.Tensor], + targets: Union[np.ndarray, torch.Tensor], + predictions: Union[np.ndarray, torch.Tensor], + raw_images: Optional[Union[np.ndarray, torch.Tensor]] = None, + patch_coords: Optional[List[tuple]] = None, + metrics_df: Optional[pd.DataFrame] = None, + save_path: Optional[str] = None, + show: bool = True, +): + """ + Generalized function to plot a grid of images with predictions and optional raw images. + The Batch dimensions of (raw_image), input, target, and prediction should match and so should the length of metrics_df. + + :param inputs: Input images (N, C, H, W) or (N, H, W). + :param targets: Target images (N, C, H, W) or (N, H, W). + :param predictions: Model predictions (N, C, H, W) or (N, H, W). + :param raw_images: Optional raw images for PatchDataset (N, H, W). + :param patch_coords: Optional list of (x, y) coordinates for patches. + Only used if raw_images is provided. Length match the first dimension of inputs/targets/predictions. + :param metrics_df: Optional DataFrame with per-image metrics. + :param save_path: If provided, saves figure. + :param show: Whether to display the plot. 
+ """ + + num_samples = len(inputs) + is_patch_dataset = raw_images is not None + num_cols = 4 if is_patch_dataset else 3 # (Raw | Input | Target | Prediction) vs (Input | Target | Prediction) + + fig, axes = plt.subplots(num_samples, num_cols, figsize=(5 * num_cols, 5 * num_samples)) + column_titles = ["Raw Image", "Input", "Target", "Prediction"] if is_patch_dataset else ["Input", "Target", "Prediction"] + + for row_idx in range(num_samples): + img_set = [raw_images[row_idx]] if is_patch_dataset else [] + img_set.extend([inputs[row_idx], targets[row_idx], predictions[row_idx]]) + + for col_idx, img in enumerate(img_set): + ax = axes[row_idx, col_idx] + ax.imshow(img.squeeze(), cmap="gray") + ax.set_title(column_titles[col_idx]) + ax.axis("off") + + # Draw rectangle on raw image if PatchDataset + if is_patch_dataset and col_idx == 0 and patch_coords is not None: + patch_x, patch_y = patch_coords[row_idx] # (x, y) coordinates + patch_size = targets.shape[-1] # Assume square patches from target size + rect = Rectangle((patch_x, patch_y), patch_size, patch_size, linewidth=2, edgecolor="r", facecolor="none") + ax.add_patch(rect) + + # Display metrics if provided + if metrics_df is not None: + metric_values = metrics_df.iloc[row_idx] + metric_text = "\n".join([f"{key}: {value:.3f}" for key, value in metric_values.items()]) + axes[row_idx, -1].set_title(metric_text, fontsize=10, pad=10) + + # Save and/or show the plot + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=300) + if show: + plt.show() + else: + plt.close() + +def plot_predictions_grid_from_eval( + dataset: torch.utils.data.Dataset, + predictions: Union[torch.Tensor, np.ndarray], + indices: List[int], + metrics_df: Optional[pd.DataFrame] = None, + save_path: Optional[str] = None, + show: bool = True, +): + """ + Wrapper function to extract dataset samples and call `_plot_predictions_grid`. + This function operates on the outputs downstream of `evaluate_per_image_metric` + and `predict_image` to avoid unecessary forward pass. + + :param dataset: Dataset (either normal or PatchDataset). + :param predictions: Subsetted tensor/NumPy array of predictions. + :param indices: Indices corresponding to the subset. + :param metrics_df: DataFrame with per-image metrics for the subset. + :param save_path: If provided, saves figure. + :param show: Whether to display the plot. 
+ """ + + is_patch_dataset = isinstance(dataset, PatchDataset) + + # Extract input, target, and (optional) raw images & patch coordinates + raw_images, inputs, targets, patch_coords = [], [], [], [] + for i in indices: + inputs.append(dataset[i][0]) + targets.append(dataset[i][1]) + if is_patch_dataset: + raw_images.append(dataset.raw_input) + patch_coords.append(dataset.patch_coords) # Get patch location + + inputs_numpy = process_tensor_image(torch.stack(inputs), invert_function=dataset.input_transform.invert) + targets_numpy = process_tensor_image(torch.stack(targets), invert_function=dataset.target_transform.invert) + + # Pass everything to the core grid function + _plot_predictions_grid( + inputs_numpy, targets_numpy, predictions[indices], + raw_images if is_patch_dataset else None, + patch_coords if is_patch_dataset else None, + metrics_df, save_path, show + ) + +def plot_predictions_grid_from_model( + model: torch.nn.Module, + dataset: torch.utils.data.Dataset, + indices: List[int], + metrics: List[torch.nn.Module], + device: str = "cuda", + save_path: Optional[str] = None, + show: bool = True, +): + """ + Wrapper plot function that internally performs inference and evaluation with the following steps: + 1. Perform inference on a subset of the dataset given the model. + 2. Compute per-image metrics on that subset. + 3. Plot the results with core `_plot_predictions_grid` function. + + :param model: PyTorch model for inference. + :param dataset: The dataset to use for evaluation and plotting. + :param indices: List of dataset indices to evaluate and visualize. + :param metrics: List of metric functions to evaluate. + :param device: Device to run inference on ("cpu" or "cuda"). + :param save_path: Optional path to save the plot. + :param show: Whether to display the plot. 
+ """ + # Step 1: Run inference on the selected subset + predictions, targets = predict_image(dataset, model, indices=indices, device=device) + + # Step 2: Compute per-image metrics for the subset + metrics_df = evaluate_per_image_metric(predictions, targets, metrics) + + # Step 3: Extract subset of inputs & targets and plot + is_patch_dataset = isinstance(dataset, PatchDataset) + raw_images, inputs, targets, patch_coords = [], [], [], [] + for i in indices: + inputs.append(dataset[i][0]) + targets.append(dataset[i][1]) + if is_patch_dataset: + raw_images.append(dataset.raw_input) + patch_coords.append(dataset.patch_coords) # Get patch location + + _plot_predictions_grid( + torch.stack(inputs), + torch.stack(targets), + predictions, + raw_images=raw_images if is_patch_dataset else None, + patch_coords=patch_coords if is_patch_dataset else None, + metrics_df=metrics_df, + save_path=save_path, + show=show) \ No newline at end of file From fda7de710ac24c76bd15f46157355ca3a39ccf90 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Mon, 3 Mar 2025 00:48:19 -0700 Subject: [PATCH 69/89] Added kwargs support for plot parameters and fixed metrics in title --- evaluation/visualization_utils.py | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py index 091f3b3..6b07a51 100644 --- a/evaluation/visualization_utils.py +++ b/evaluation/visualization_utils.py @@ -228,7 +228,7 @@ def _plot_predictions_grid( patch_coords: Optional[List[tuple]] = None, metrics_df: Optional[pd.DataFrame] = None, save_path: Optional[str] = None, - show: bool = True, + **kwargs ): """ Generalized function to plot a grid of images with predictions and optional raw images. @@ -242,14 +242,20 @@ def _plot_predictions_grid( Only used if raw_images is provided. Length match the first dimension of inputs/targets/predictions. :param metrics_df: Optional DataFrame with per-image metrics. :param save_path: If provided, saves figure. - :param show: Whether to display the plot. + :param kwargs: Additional keyword arguments to pass to plt.subplots. 
""" + cmap = kwargs.get("cmap", "gray") + panel_width = kwargs.get("panel_width", 5) + show_plot = kwargs.get("show_plot", True) + fig_size = kwargs.get("fig_size", None) + num_samples = len(inputs) is_patch_dataset = raw_images is not None num_cols = 4 if is_patch_dataset else 3 # (Raw | Input | Target | Prediction) vs (Input | Target | Prediction) - fig, axes = plt.subplots(num_samples, num_cols, figsize=(5 * num_cols, 5 * num_samples)) + fig_size = (panel_width * num_cols, panel_width * num_samples) if fig_size is None else fig_size + fig, axes = plt.subplots(num_samples, num_cols, figsize=fig_size) column_titles = ["Raw Image", "Input", "Target", "Prediction"] if is_patch_dataset else ["Input", "Target", "Prediction"] for row_idx in range(num_samples): @@ -258,7 +264,7 @@ def _plot_predictions_grid( for col_idx, img in enumerate(img_set): ax = axes[row_idx, col_idx] - ax.imshow(img.squeeze(), cmap="gray") + ax.imshow(img.squeeze(), cmap=cmap) ax.set_title(column_titles[col_idx]) ax.axis("off") @@ -273,12 +279,13 @@ def _plot_predictions_grid( if metrics_df is not None: metric_values = metrics_df.iloc[row_idx] metric_text = "\n".join([f"{key}: {value:.3f}" for key, value in metric_values.items()]) - axes[row_idx, -1].set_title(metric_text, fontsize=10, pad=10) + axes[row_idx, -1].set_title( + axes[row_idx, -1].get_title() + "\n" + metric_text, fontsize=10, pad=10) # Save and/or show the plot if save_path: plt.savefig(save_path, bbox_inches="tight", dpi=300) - if show: + if show_plot: plt.show() else: plt.close() @@ -289,7 +296,7 @@ def plot_predictions_grid_from_eval( indices: List[int], metrics_df: Optional[pd.DataFrame] = None, save_path: Optional[str] = None, - show: bool = True, + **kwargs ): """ Wrapper function to extract dataset samples and call `_plot_predictions_grid`. @@ -301,7 +308,7 @@ def plot_predictions_grid_from_eval( :param indices: Indices corresponding to the subset. :param metrics_df: DataFrame with per-image metrics for the subset. :param save_path: If provided, saves figure. - :param show: Whether to display the plot. + :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`. """ is_patch_dataset = isinstance(dataset, PatchDataset) @@ -323,7 +330,7 @@ def plot_predictions_grid_from_eval( inputs_numpy, targets_numpy, predictions[indices], raw_images if is_patch_dataset else None, patch_coords if is_patch_dataset else None, - metrics_df, save_path, show + metrics_df, save_path, **kwargs ) def plot_predictions_grid_from_model( @@ -333,7 +340,7 @@ def plot_predictions_grid_from_model( metrics: List[torch.nn.Module], device: str = "cuda", save_path: Optional[str] = None, - show: bool = True, + **kwargs ): """ Wrapper plot function that internally performs inference and evaluation with the following steps: @@ -347,10 +354,10 @@ def plot_predictions_grid_from_model( :param metrics: List of metric functions to evaluate. :param device: Device to run inference on ("cpu" or "cuda"). :param save_path: Optional path to save the plot. - :param show: Whether to display the plot. + :param kwargs: Additional keyword arguments to pass to `_plot_predictions_grid`. 
""" # Step 1: Run inference on the selected subset - predictions, targets = predict_image(dataset, model, indices=indices, device=device) + targets, predictions = predict_image(dataset, model, indices=indices, device=device) # Step 2: Compute per-image metrics for the subset metrics_df = evaluate_per_image_metric(predictions, targets, metrics) @@ -372,5 +379,5 @@ def plot_predictions_grid_from_model( raw_images=raw_images if is_patch_dataset else None, patch_coords=patch_coords if is_patch_dataset else None, metrics_df=metrics_df, - save_path=save_path, - show=show) \ No newline at end of file + save_path=save_path, + **kwargs) \ No newline at end of file From f909e342a805b0b6b793f07c47ea44b99216c62d Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Mon, 3 Mar 2025 00:49:09 -0700 Subject: [PATCH 70/89] Updated IntermedaitePlot callback class for compatibility with new plotting function --- callbacks/IntermediatePlot.py | 48 ++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py index 6264fae..fe2c102 100644 --- a/callbacks/IntermediatePlot.py +++ b/callbacks/IntermediatePlot.py @@ -1,11 +1,13 @@ from typing import List, Union +import random import torch import torch.nn as nn +from torch.utils.data import Dataset from .AbstractCallback import AbstractCallback from ..datasets.PatchDataset import PatchDataset - +from ..evaluation.visualization_utils import plot_predictions_grid_from_model from ..evaluation.visualization_utils import plot_patches class IntermediatePatchPlot(AbstractCallback): @@ -17,28 +19,38 @@ class IntermediatePatchPlot(AbstractCallback): def __init__(self, name: str, path: str, - dataset: PatchDataset, + dataset: Union[Dataset, PatchDataset], plot_n_patches: int=5, + indices: Union[List[int], None]=None, plot_metrics: List[nn.Module]=None, every_n_epochs: int=5, + random_seed: int=42, **kwargs): """ - Initialize the IntermediatePlot callback. + Initialize the IntermediatePlot callback. + This callback, when passed into the trainer, will plot the model predictions on a subset of the provided dataset at the end of each epoch. :param name: Name of the callback. :type name: str :param path: Path to save the model weights. :type path: str :param dataset: Dataset to be used for plotting intermediate results. - :type dataset: PatchDataset - :param plot_n_patches: Number of patches to plot, defaults to 5. + :type dataset: Union[Dataset, PatchDataset] + :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5. + The exact patches/images being plotted may vary due to a difference in seed or dataset size. + To ensure best reproducibility and consistency, please use a fixed dataset and indices argument instead. :type plot_n_patches: int, optional + :param indices: Optional list of specific indices to subset the dataset before inference. + Overrides the plot_n_patches and random_seed arguments and uses the indices list to subset. + :type indices: Union[List[int], None] :param plot_metrics: List of metrics to compute and display in plot title, defaults to None. :type plot_metrics: List[nn.Module], optional :param kwargs: Additional keyword arguments to be passed to plot_patches. :type kwargs: dict :param every_n_epochs: How frequent should intermediate plots should be plotted, defaults to 5 :type every_n_epochs: int + :param random_seed: Random seed for reproducibility for random patch/image selection, defaults to 42. 
From f909e342a805b0b6b793f07c47ea44b99216c62d Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:49:09 -0700
Subject: [PATCH 70/89] Updated IntermediatePlot callback class for compatibility with new plotting function
---
 callbacks/IntermediatePlot.py | 48 ++++++++++++++++++++++++++---------
 1 file changed, 36 insertions(+), 12 deletions(-)

diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py
index 6264fae..fe2c102 100644
--- a/callbacks/IntermediatePlot.py
+++ b/callbacks/IntermediatePlot.py
@@ -1,11 +1,13 @@
 from typing import List, Union
+import random

 import torch
 import torch.nn as nn
+from torch.utils.data import Dataset

 from .AbstractCallback import AbstractCallback
 from ..datasets.PatchDataset import PatchDataset
-
+from ..evaluation.visualization_utils import plot_predictions_grid_from_model
 from ..evaluation.visualization_utils import plot_patches

 class IntermediatePatchPlot(AbstractCallback):
@@ -17,28 +19,38 @@ class IntermediatePatchPlot(AbstractCallback):
 def __init__(self,
 name: str,
 path: str,
- dataset: PatchDataset,
+ dataset: Union[Dataset, PatchDataset],
 plot_n_patches: int=5,
+ indices: Union[List[int], None]=None,
 plot_metrics: List[nn.Module]=None,
 every_n_epochs: int=5,
+ random_seed: int=42,
 **kwargs):
 """
- Initialize the IntermediatePlot callback.
+ Initialize the IntermediatePlot callback.
+ This callback, when passed into the trainer, will plot the model predictions on a subset of the provided dataset at the end of each epoch.

 :param name: Name of the callback.
 :type name: str
 :param path: Path to save the model weights.
 :type path: str
 :param dataset: Dataset to be used for plotting intermediate results.
- :type dataset: PatchDataset
- :param plot_n_patches: Number of patches to plot, defaults to 5.
+ :type dataset: Union[Dataset, PatchDataset]
+ :param plot_n_patches: Number of patches to randomly select and plot, defaults to 5.
+ The exact patches/images being plotted may vary with the random seed and dataset size.
+ For full reproducibility and consistency, use a fixed dataset and the indices argument instead.
 :type plot_n_patches: int, optional
+ :param indices: Optional list of specific indices to subset the dataset before inference.
+ Overrides the plot_n_patches and random_seed arguments and uses the indices list to subset.
+ :type indices: Union[List[int], None]
 :param plot_metrics: List of metrics to compute and display in plot title, defaults to None.
 :type plot_metrics: List[nn.Module], optional
 :param kwargs: Additional keyword arguments to be passed to plot_patches.
 :type kwargs: dict
 :param every_n_epochs: How frequently intermediate plots should be plotted, defaults to 5
 :type every_n_epochs: int
+ :param random_seed: Random seed for reproducibility of random patch/image selection, defaults to 42.
+ :type random_seed: int
 :raises TypeError: If the dataset is not an instance of PatchDataset.
 """
 super().__init__(name)
 self._path = path
@@ -48,10 +60,21 @@ def __init__(self,
 self._dataset = dataset

 # Additional kwargs passed to plot_patches
- self.plot_n_patches = plot_n_patches
 self.plot_metrics = plot_metrics
 self.every_n_epochs = every_n_epochs
 self.plot_kwargs = kwargs
+
+ if indices is not None:
+ # Check if indices are within bounds
+ for i in indices:
+ if i >= len(self._dataset):
+ raise ValueError(f"Index {i} out of bounds for dataset of size {len(self._dataset)}")
+ self._dataset_subset_indices = indices
+ else:
+ # Generate random indices to subset given seed and plot_n_patches
+ plot_n_patches = min(plot_n_patches, len(self._dataset))
+ random.seed(random_seed)
+ self._dataset_subset_indices = random.sample(range(len(self._dataset)), plot_n_patches)

 def on_epoch_end(self):
 """
@@ -76,12 +99,13 @@ def _plot(self):

 original_device = next(self.trainer.model.parameters()).device

- plot_patches(
- _dataset = self._dataset,
- _n_patches = self.plot_n_patches,
- _model = self.trainer.model,
- _metrics = self.plot_metrics,
- save_path = f"{self._path}/epoch_{self.trainer.epoch}.png",
+ plot_predictions_grid_from_model(
+ model=self.trainer.model,
+ dataset=self._dataset,
+ indices=self._dataset_subset_indices,
+ metrics=self.plot_metrics,
+ save_path=f"{self._path}/epoch_{self.trainer.epoch}.png",
 device=original_device,
+ show_plot=False,
 **self.plot_kwargs
 )
\ No newline at end of file
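The index-selection logic added in the callback above reduces to seeded sampling with an optional explicit override. A standalone sketch of the same idea (hypothetical helper, not part of the repo):

    import random
    from typing import List, Optional

    def select_plot_indices(
        dataset_len: int,
        n: int = 5,
        indices: Optional[List[int]] = None,
        seed: int = 42,
    ) -> List[int]:
        # Explicit indices win, but are validated against dataset bounds first.
        if indices is not None:
            for i in indices:
                if not 0 <= i < dataset_len:
                    raise ValueError(f"Index {i} out of bounds for dataset of size {dataset_len}")
            return list(indices)
        # Otherwise draw a reproducible sample without replacement.
        rng = random.Random(seed)
        return rng.sample(range(dataset_len), min(n, dataset_len))

    # Same seed, same subset on every epoch and run:
    assert select_plot_indices(100) == select_plot_indices(100)

Using a dedicated random.Random instance, rather than seeding the global random module as the callback does, avoids perturbing other code that relies on global RNG state during training.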
From 14132809dc6ef624806135f51bceee7f1f071b7a Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:52:16 -0700
Subject: [PATCH 71/89] Renamed callback and modified type checking to reflect that it supports both PatchDataset and standard ImageDataset
---
 callbacks/IntermediatePlot.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/callbacks/IntermediatePlot.py b/callbacks/IntermediatePlot.py
index fe2c102..22c2260 100644
--- a/callbacks/IntermediatePlot.py
+++ b/callbacks/IntermediatePlot.py
@@ -8,9 +8,8 @@
 from .AbstractCallback import AbstractCallback
 from ..datasets.PatchDataset import PatchDataset
 from ..evaluation.visualization_utils import plot_predictions_grid_from_model
-from ..evaluation.visualization_utils import plot_patches

-class IntermediatePatchPlot(AbstractCallback):
+class IntermediatePlot(AbstractCallback):
 """
 Callback to plot model generated outputs, ground truth, and input stained image patches at the end of each epoch.
 """
@@ -27,7 +26,9 @@ def __init__(self,
 random_seed: int=42,
 **kwargs):
 """
- Initialize the IntermediatePlot callback.
+ Initialize the IntermediatePlot callback.
+ Allows plots of predictions to be generated during training to monitor training progress.
+ Supports both PatchDataset and Dataset classes for plotting.
 This callback, when passed into the trainer, will plot the model predictions on a subset of the provided dataset at the end of each epoch.

 :param name: Name of the callback.
@@ -55,8 +56,9 @@ def __init__(self,
 """
 super().__init__(name)
 self._path = path
- if not isinstance(dataset, PatchDataset):
- raise TypeError(f"Expected PatchDataset, got {type(dataset)}")
+ if not isinstance(dataset, (Dataset, PatchDataset)):
+ raise TypeError(f"Expected Dataset or PatchDataset, got {type(dataset)}")
+
 self._dataset = dataset

 # Additional kwargs passed to plot_patches
From 0da8875ca5675254ed0653e62c939b71e14cb272 Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:55:53 -0700
Subject: [PATCH 72/89] Update comment to better reflect what the functions are doing
---
 evaluation/predict_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/evaluation/predict_utils.py b/evaluation/predict_utils.py
index 17ad487..bba623b 100644
--- a/evaluation/predict_utils.py
+++ b/evaluation/predict_utils.py
@@ -90,10 +90,11 @@ def process_tensor_image(
 if dtype is not None:
 output_images = output_images.astype(dtype)

- # Apply inverse invert function when supplied or transformation if dataset supplied and target_transform is valid
+ # Apply the supplied invert function if given; otherwise fall back to inverting the dataset's target transform
 if invert_function is not None and isinstance(invert_function, Callable):
 output_images = np.array([invert_function(img) for img in output_images])
 elif dataset is not None and hasattr(dataset, "target_transform"):
+ # Apply inverted target transformation if available
 target_transform = dataset.target_transform
 if isinstance(target_transform, (ImageOnlyTransform, Compose)):
 # Apply the transformation on each image
From 8d5a616c5f4ea25bf701e0bd2152435f4ee7dd2a Mon Sep 17 00:00:00 2001
From: Weishan Li
Date: Mon, 3 Mar 2025 00:58:32 -0700
Subject: [PATCH 73/89] Removed unneeded files and functions
---
 evaluation/plot_utils.py | 291 ------------------------------
 evaluation/visualization_utils.py | 215 +---------------------
 2 files changed, 3 insertions(+), 503 deletions(-)
 delete mode 100644 evaluation/plot_utils.py

diff --git a/evaluation/plot_utils.py b/evaluation/plot_utils.py
deleted file mode 100644
index c001078..0000000
--- a/evaluation/plot_utils.py
+++ /dev/null
@@ -1,291 +0,0 @@
-import random
-
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-from typing import Union, Optional, List
-from PIL import Image
-from torch.utils.data import Dataset
-from matplotlib.gridspec import GridSpec
-from matplotlib.patches import Rectangle
-
-from ..datasets.PatchDataset import PatchDataset
-from ..evaluation.predict_utils import predict_image, process_tensor_image
-from ..evaluation.evaluation_utils import evaluate_per_image_metric
-
-def plot_single_image(
- image: Union[np.ndarray, torch.Tensor],
- ax: Optional[plt.Axes] = None,
- cmap: str = "gray",
- vmin: Optional[float] = None,
- vmax: Optional[float] = None,
- title: Optional[str] = None,
- title_fontsize: int = 10
-):
- """
- Plots a single image on the given matplotlib axis or creates a new figure if no axis is provided.
-
- :param image: The image to plot, either as a NumPy array or a PyTorch tensor.
- :type image: Union[np.ndarray, torch.Tensor]
- :param ax: Optional existing axis to plot on. If None, a new figure is created.
- :type ax: Optional[plt.Axes], default is None
- :param cmap: Colormap for visualization (default: "gray").
- :type cmap: str, optional
- :param vmin: Minimum value for image scaling. If None, defaults to image min.
- :type vmin: Optional[float], optional
- :param vmax: Maximum value for image scaling. If None, defaults to image max.
- :type vmax: Optional[float], optional - :param title: Optional title for the image. - :type title: Optional[str], optional - :param title_fontsize: Font size of the title. - :type title_fontsize: int, optional - """ - - # Convert tensor to NumPy - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - - # Handle grayscale vs multi-channel - if image.ndim == 3 and image.shape[0] in [1, 3]: # (C, H, W) format - image = np.transpose(image, (1, 2, 0)) # Convert to (H, W, C) - - # Create a new figure if no axis is provided - if ax is None: - fig, ax = plt.subplots(figsize=(5, 5)) - - # Plot the image - im = ax.imshow(image, cmap=cmap, vmin=vmin or image.min(), vmax=vmax or image.max()) - - # Hide axis - ax.axis("off") - - # Add title if provided - if title: - ax.set_title(title, fontsize=title_fontsize, fontweight="bold") - - return im # Return image object for further customization if needed - -def visualize_images_with_stats( - images: Union[torch.Tensor, np.ndarray], - cmap: str = "gray", - figsize: Optional[tuple] = None, - panel_width: int = 3, - show_stats: bool = True, - channel_names: Optional[list] = None, - title_fontsize: int = 10, - axes: Optional[np.ndarray] = None -): - """ - Visualizes images using matplotlib, handling various shapes: - - (H, W) → Single grayscale image. - - (C, H, W) → Multi-channel image. - - (N, C, H, W) → Multiple images, multiple channels. - - Supports external axes input for easier integration. - - :param images: Input images as PyTorch tensor or NumPy array. - :param cmap: Colormap for visualization. - :param figsize: Optional figure size. - :param panel_width: Width of each panel. - :param show_stats: Whether to display statistics (μ, σ, ⊥, ⊤) in titles. - :param channel_names: List of channel names for first row. - :param title_fontsize: Font size for titles. - :param axes: Optional pre-existing matplotlib Axes. - """ - if isinstance(images, torch.Tensor): - images = images.cpu().numpy() - - ndim = images.ndim - if ndim == 2: - images = images[np.newaxis, np.newaxis, ...] # Convert to (1, 1, H, W) - elif ndim == 3: - images = images[np.newaxis, ...] # Convert to (1, C, H, W) - elif ndim != 4: - raise ValueError(f"Unsupported shape {images.shape}. Expected (H, W), (C, H, W), or (N, C, H, W).") - - n_images, n_channels, _, _ = images.shape - - # Create figure and axes if not provided - if axes is None: - figsize = figsize or (n_channels * panel_width, n_images * panel_width) - fig, axes = plt.subplots(n_images, n_channels, figsize=figsize, squeeze=False) - - for i in range(n_images): - for j in range(n_channels): - img = images[i, j] - title = None - - # Compute statistics if needed - if show_stats: - img_mean, img_std, img_min, img_max = np.mean(img), np.std(img), np.min(img), np.max(img) - title = f"μ: {img_mean:.2f} | σ: {img_std:.2f} | ⊥: {img_min:.2f} | ⊤: {img_max:.2f}" - - # Use the helper function for plotting - plot_single_image(img, ax=axes[i, j], cmap=cmap, title=title, title_fontsize=title_fontsize) - - plt.tight_layout() - if axes is None: - plt.show() - -def plot_single_image( - image: np.ndarray, - ax: plt.Axes, - cmap: str = "gray", - vmin: Optional[float] = None, - vmax: Optional[float] = None, - title: Optional[str] = None, - title_fontsize: int = 10 -): - """ - Plots a single image on the given matplotlib axis. - - :param image: The image to plot (NumPy array). - :param ax: The matplotlib axis to plot on. - :param cmap: Colormap for visualization. - :param vmin: Minimum value for scaling. 
- :param vmax: Maximum value for scaling. - :param title: Optional title for the image. - :param title_fontsize: Font size of the title. - """ - ax.imshow(np.squeeze(image), cmap=cmap, vmin=vmin, vmax=vmax) - ax.axis("off") - if title: - ax.set_title(title, fontsize=title_fontsize, fontweight="bold") - -def plot_patches( - dataset: torch.utils.data.Dataset, - n_patches: int = 5, - model: Optional[torch.nn.Module] = None, - patch_index: Optional[List[int]] = None, - random_seed: int = 42, - metrics: Optional[List[torch.nn.Module]] = None, - device: str = "cpu", - **kwargs -): - """ - Plots dataset patches with optional model predictions and evaluation metrics using GridSpec. - Uses `plot_single_image` to ensure consistency. - - :param dataset: A dataset that returns (input_tensor, target_tensor) tuples. - :param n_patches: Number of patches to visualize (default: 5). - :param model: Optional PyTorch model to run inference on patches. - :param patch_index: List of dataset indices to select specific patches. - :param random_seed: Random seed for reproducibility. - :param metrics: List of metric functions to evaluate model predictions. - :param device: Device to run model inference on, e.g., "cpu" or "cuda". - :param **kwargs: Additional customization options (e.g., `cmap`, `panel_width`, `show_plot`). - """ - - cmap = kwargs.get("cmap", "gray") - panel_width = kwargs.get("panel_width", 5) - show_plot = kwargs.get("show_plot", True) - save_path = kwargs.get("save_path", None) - title_fontsize = kwargs.get("title_fontsize", 12) - - # Select patches - if patch_index is None: - random.seed(random_seed) - patch_index = random.sample(range(len(dataset)), n_patches) - else: - patch_index = [i for i in patch_index if i < len(dataset)] - n_patches = len(patch_index) - - inputs, targets, raw_images, patch_coords = [], [], [], [] - for i in patch_index: - input_tensor, target_tensor = dataset[i] - inputs.append(input_tensor) - targets.append(target_tensor) - patch_coords.append(dataset.patch_coords) # Extract (x, y) coordinates - raw_images.append(np.array(Image.open(dataset.input_names[0]))) - - inputs = torch.stack(inputs) - targets = torch.stack(targets) - - # Run model predictions (if provided) - _, predictions = predict_image(dataset, model, device=device, indices=patch_index) if model else None - - # Convert tensors to NumPy arrays - inputs_numpy = process_tensor_image(inputs, invert_function=dataset.input_transform.invert) - targets_numpy = process_tensor_image(targets, dataset=dataset) - predictions_numpy = process_tensor_image(predictions, dataset=dataset) if predictions is not None else None - - # Compute evaluation metrics (if applicable) - if metrics and predictions is not None: - metric_values = evaluate_per_image_metric( - predictions=predictions, - targets=targets, - metrics=metrics - ) - else: - metric_values = None - - # Determine number of columns (Raw + Input + Target + Optional Predictions) - n_predictions = predictions_numpy.shape[1] if predictions_numpy is not None else 0 - n_columns = 3 + n_predictions # (Raw, Input, Target, Predictions) - - # Compute raw image global vmin/vmax - raw_vmin, raw_vmax = np.min(raw_images), np.max(raw_images) - - # Set up figure and GridSpec layout with an extra row for column titles - figsize = (panel_width * n_columns, panel_width * (n_patches + 1)) # Extra space for headers - fig = plt.figure(figsize=figsize) - gs = GridSpec(n_patches + 1, n_columns, figure=fig, height_ratios=[0.05] + [1] * n_patches, hspace=0.05, wspace=0.05) - - # Column headers 
(Shared Titles) - column_titles = ["Raw Image", "Input Patch", "Target Patch"] + [f"Predicted {i+1}" for i in range(n_predictions)] - for j, title in enumerate(column_titles): - ax = fig.add_subplot(gs[0, j]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.axis("off") - ax.text(0.5, 0.5, title, ha="center", va="center", fontsize=title_fontsize, fontweight="bold") - - # Iterate through patches and plot each column separately - for i in range(n_patches): - row_offset = i + 1 # Offset by 1 to account for the title row - - # Extract patch coordinates - patch_x, patch_y = patch_coords[i] - patch_size = targets_numpy.shape[-1] # Infer patch size from target shape - - # Compute per-patch vmin/vmax - input_vmin, input_vmax = np.min(inputs_numpy[i]), np.max(inputs_numpy[i]) - target_vmin, target_vmax = np.min(targets_numpy[i]), np.max(targets_numpy[i]) - - # Plot raw image with patch annotation - ax = fig.add_subplot(gs[row_offset, 0]) - plot_single_image(raw_images[i], ax, cmap, raw_vmin, raw_vmax) - rect = Rectangle((patch_x, patch_y), patch_size, patch_size, linewidth=2, edgecolor="r", facecolor="none") - ax.add_patch(rect) - - # Plot input patch - ax = fig.add_subplot(gs[row_offset, 1]) - plot_single_image(inputs_numpy[i], ax, cmap, input_vmin, input_vmax) - - # Plot target patch - ax = fig.add_subplot(gs[row_offset, 2]) - plot_single_image(targets_numpy[i], ax, cmap, target_vmin, target_vmax) - - # Plot prediction patches (if available) with metrics - if predictions_numpy is not None: - for j in range(n_predictions): - ax = fig.add_subplot(gs[row_offset, 3 + j]) - plot_single_image(predictions_numpy[i, j], ax, cmap, target_vmin, target_vmax) - - # Display metric values below prediction - metric_str = "" - if metric_values is not None: - metric_value_row = metric_values.iloc[i, :] - metric_str = "\n".join( - [f"{metric_name}: {metric_val:.2f}" for metric_name, metric_val in metric_value_row.items()] - ) - - ax.set_title(metric_str, fontsize=title_fontsize - 2) - - # Adjust layout and save/show - if save_path: - plt.savefig(save_path) - if show_plot: - plt.show() - else: - plt.close() \ No newline at end of file diff --git a/evaluation/visualization_utils.py b/evaluation/visualization_utils.py index 6b07a51..1b27699 100644 --- a/evaluation/visualization_utils.py +++ b/evaluation/visualization_utils.py @@ -1,6 +1,4 @@ -import pathlib -from typing import Tuple, List, Union, Optional -import random +from typing import List, Union, Optional import numpy as np import pandas as pd @@ -8,218 +6,11 @@ from torch.utils.data import Dataset import matplotlib.pyplot as plt from matplotlib.patches import Rectangle -from PIL import Image -from albumentations import ImageOnlyTransform -from albumentations.core.composition import Compose from ..datasets.PatchDataset import PatchDataset from ..evaluation.predict_utils import predict_image, process_tensor_image from ..evaluation.evaluation_utils import evaluate_per_image_metric -def invert_transforms( - numpy_img: np.ndarray, - transforms: Union[ImageOnlyTransform, Compose] = None - ) -> np.ndarray: - - if isinstance(transforms, ImageOnlyTransform): - return transforms.invert(numpy_img) - elif isinstance(transforms, Compose): - for transform in reversed(transforms.transforms): - numpy_img = transform.invert(numpy_img) - elif transforms is None: - return numpy_img - else: - raise ValueError(f"Invalid transforms type: {type(transforms)}") - - return numpy_img - -def format_img( - _tensor_img: torch.Tensor, - cast_to_type: torch.dtype = None - ) -> np.ndarray: - - if 
cast_to_type is not None: - _tensor_img = _tensor_img.to(cast_to_type) - - img = torch.squeeze(_tensor_img).cpu().numpy() - - return img - -def evaluate_and_format_imgs( - _input: torch.Tensor, - _target: torch.Tensor, - model=None, - _input_transform: Optional[Union[Compose, ImageOnlyTransform]]=None, - _target_transform: Optional[Union[Compose, ImageOnlyTransform]]=None, - device: str='cpu' - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - input_transform = invert_transforms( - format_img(_input), - _input_transform - ) - target_transform = invert_transforms( - format_img(_target), - _target_transform - ) - - if model is not None: - model.to(device) - model.eval() - with torch.no_grad(): - # Forward Pass - output = model(_input.unsqueeze(1).to(device)) - - output_transform = invert_transforms( - format_img(output), - _target_transform - ) - else: - output_transform = None - - return input_transform, target_transform, output_transform - -def plot_patch( - _raw_img: np.ndarray, - _patch_size: int, - _patch_coords: Tuple[int, int], - _input: torch.Tensor, - _target: torch.Tensor, - _output: torch.Tensor = None, - axes: List = None, - **kwargs -): - ## Plot keyword arguments - cmap = kwargs.get("cmap", "gray") - vmin = kwargs.get("vmin", None) - vmax = kwargs.get("vmax", None) - figsize = kwargs.get("figsize", None) - if figsize is None: - panel_width = kwargs.get("panel_width", 5) - figsize = (panel_width, panel_width * 3 if _output is None else 4) - else: - panel_width = None - - if axes is None: - fig, ax = plt.subplots(1, 3 if _output is None else 4, figsize=figsize) - else: - ax = axes - - # plot image - ax[0].imshow(_raw_img, cmap=cmap) - ax[0].set_title("Raw Image") - ax[0].axis("off") - - rect = Rectangle( - _patch_coords, - _patch_size, - _patch_size, - linewidth=1, - edgecolor="r", - facecolor="none" - ) - - if vmin is None: - vmin = min(_output.min(), _target.min()) - if vmax is None: - vmax = max(_output.max(), _target.max()) - - ax[0].add_patch(rect) - - # plot input - ax[1].imshow(_input, cmap=cmap) - ax[1].set_title("Input") - ax[1].axis("off") - - # plot target - ax[2].imshow(_target, cmap=cmap, vmin=vmin, vmax=vmax) - ax[2].set_title("Target") - ax[2].axis("off") - - if _output is not None: - ax[3].imshow(_output, cmap=cmap, vmin=vmin, vmax=vmax) - ax[3].set_title("Output") - ax[3].axis("off") - -def plot_patches( - _dataset: Dataset, - _n_patches: int=5, - _model: torch.nn.Module=None, - _patch_index: List[int]=None, - _random_seed: int=42, - _metrics: List[torch.nn.Module]=None, - device: str='cpu', - **kwargs -): - ## Plot keyword arguments - cmap = kwargs.get("cmap", "gray") - vmin = kwargs.get("vmin", None) - vmax = kwargs.get("vmax", None) - panel_width = kwargs.get("panel_width", 5) - save_path = kwargs.get("save_path", None) - show_plot = kwargs.get("show_plot", True) - - ## Generate random patch indices to visualize - if _patch_index is None: - random.seed(_random_seed) - _patch_index = random.sample(range(len(_dataset)), _n_patches) - else: - _patch_index = [i for i in _patch_index if i < len(_dataset)] - _n_patches = len(_patch_index) - - figsize = kwargs.get("figsize", None) - if figsize is None: - figsize = (panel_width * _n_patches, panel_width * 3 if _model is None else 4, ) - fig, axes = plt.subplots(_n_patches, 3 if _model is None else 4, figsize=figsize) - - for i, row_axes in zip(_patch_index, axes): - _input, _target = _dataset[i] - _raw_image = np.array(Image.open( - _dataset.input_names[0] - )) - _input, _target, _output = 
evaluate_and_format_imgs( - _input, - _target, - _model, - device=device - ) - - plot_patch( - _raw_img=_raw_image, - _patch_size=_input.shape[-1], - _patch_coords=_dataset.patch_coords, - _input=_input, - _target=_target, - _output=_output, - axes=row_axes, - cmap=cmap, - vmin=vmin, - vmax=vmax - ) - - ## Compute metrics for single set of (target, output) pairs and add to subplot title - metric_str = "" - if _metrics is not None: - for _metric in _metrics: - metric_val = _metric( - torch.tensor(_output).unsqueeze(0).unsqueeze(0), - torch.tensor(_target).unsqueeze(0).unsqueeze(0) - ).item() - metric_str = f"{metric_str}\n{_metric.__class__.__name__}: {metric_val:.2f}" - row_axes[-1].set_title( - row_axes[-1].get_title() + metric_str - ) - - plt.tight_layout() - - if show_plot: - plt.show() - - if save_path is not None: - plt.savefig(save_path) - - plt.close() - def _plot_predictions_grid( inputs: Union[np.ndarray, torch.Tensor], targets: Union[np.ndarray, torch.Tensor], @@ -291,7 +82,7 @@ def _plot_predictions_grid( plt.close() def plot_predictions_grid_from_eval( - dataset: torch.utils.data.Dataset, + dataset: Dataset, predictions: Union[torch.Tensor, np.ndarray], indices: List[int], metrics_df: Optional[pd.DataFrame] = None, @@ -335,7 +126,7 @@ def plot_predictions_grid_from_eval( def plot_predictions_grid_from_model( model: torch.nn.Module, - dataset: torch.utils.data.Dataset, + dataset: Dataset, indices: List[int], metrics: List[torch.nn.Module], device: str = "cuda", From baeb80bc9cac0cf051652bd69dca613eb6cb4c07 Mon Sep 17 00:00:00 2001 From: Weishan Li Date: Mon, 3 Mar 2025 01:03:09 -0700 Subject: [PATCH 74/89] Update example so it is consistent with the updated evaluation/plotting functions --- examples/minimal_example.ipynb | 357 ++++++++++++++++++++++++--------- 1 file changed, 261 insertions(+), 96 deletions(-) diff --git a/examples/minimal_example.ipynb b/examples/minimal_example.ipynb index f968a70..b0ea058 100644 --- a/examples/minimal_example.ipynb +++ b/examples/minimal_example.ipynb @@ -20,6 +20,14 @@ "text": [ "/home/weishanli/Waylab\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/weishanli/anaconda3/envs/cp_gan_env/lib/python3.9/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.5' (you have '2.0.4'). Upgrade using: pip install -U albumentations. 
To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", + " check_for_updates()\n" + ] } ], "source": [ @@ -60,7 +68,7 @@ "\n", "## callback\n", "from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger\n", - "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot\n" + "from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot\n" ] }, { @@ -317,27 +325,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", - "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-20 10:22:43,813 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", - "2025-02-20 10:22:43,813 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", - "2025-02-20 10:22:43,813 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Well', 'Metadata_Plate']\n", - "2025-02-20 10:22:43,813 - DEBUG - Dataframe supplied for sc_feature, using as is\n", - "2025-02-20 10:22:43,850 - DEBUG - Inferring channel keys from loaddata csv\n", - "2025-02-20 10:22:43,851 - DEBUG - Channel keys: {'OrigER', 'OrigMito', 'OrigBrightfield', 'OrigRNA', 'OrigDNA', 'OrigAGP'} inferred from loaddata csv\n", - "2025-02-20 10:22:43,851 - DEBUG - Setting input channel(s) ...\n", - "2025-02-20 10:22:43,851 - DEBUG - No channel keys specified, skip\n", - "2025-02-20 10:22:43,851 - DEBUG - Setting target channel(s) ...\n", - "2025-02-20 10:22:43,851 - DEBUG - No channel keys specified, skip\n", - "2025-02-20 10:22:43,851 - DEBUG - Setting input transform ...\n", - "2025-02-20 10:22:43,851 - DEBUG - Setting target transform ...\n", - "2025-02-20 10:22:43,851 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", - "2025-02-20 10:22:43,875 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", - "2025-02-20 10:22:43,875 - DEBUG - Generating patches that contain cells\n", - "2025-02-20 10:22:43,899 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", - "2025-02-20 10:22:44,318 - DEBUG - Generated 461 patches for 93 site/view\n", - "2025-02-20 10:22:44,319 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", - "2025-02-20 10:22:44,319 - DEBUG - Set target channel(s) as ['OrigDNA']\n" + "2025-03-03 00:56:37,394 - DEBUG - Dataframe supplied for loaddata_csv, using as is\n", + "2025-03-03 00:56:37,394 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-03-03 00:56:37,394 - DEBUG - X and Y columns Metadata_Cells_Location_Center_X, Metadata_Cells_Location_Center_Y detected in sc_feature dataframe, using as the coordinates for cell centers\n", + "2025-03-03 00:56:37,395 - DEBUG - Both loaddata_csv and sc_feature supplied, inferring merge fields to associate the two dataframes\n", + "2025-03-03 00:56:37,395 - DEBUG - Merge fields inferred: ['Metadata_Site', 'Metadata_Plate', 'Metadata_Well']\n", + "2025-03-03 00:56:37,395 - DEBUG - Dataframe supplied for sc_feature, using as is\n", + "2025-03-03 00:56:37,419 - DEBUG - Inferring channel keys from loaddata csv\n", + "2025-03-03 00:56:37,420 - DEBUG - Channel keys: {'OrigAGP', 'OrigBrightfield', 
'OrigER', 'OrigDNA', 'OrigMito', 'OrigRNA'} inferred from loaddata csv\n", + "2025-03-03 00:56:37,420 - DEBUG - Setting input channel(s) ...\n", + "2025-03-03 00:56:37,420 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 00:56:37,420 - DEBUG - Setting target channel(s) ...\n", + "2025-03-03 00:56:37,420 - DEBUG - No channel keys specified, skip\n", + "2025-03-03 00:56:37,420 - DEBUG - Setting input transform ...\n", + "2025-03-03 00:56:37,420 - DEBUG - Setting target transform ...\n", + "2025-03-03 00:56:37,420 - DEBUG - Extracting image channel paths of site/view and associatedcell coordinates (if applicable) from loaddata csv\n", + "2025-03-03 00:56:37,575 - DEBUG - Extracted images of all input and target channels for 93 unique sites/view and 10090 cells\n", + "2025-03-03 00:56:37,575 - DEBUG - Generating patches that contain cells\n", + "2025-03-03 00:56:37,590 - DEBUG - Image size inferred: 1080 for all images to force redetect image sizes for each view/site set consistent_img_size=False\n", + "2025-03-03 00:56:37,968 - DEBUG - Generated 461 patches for 93 site/view\n", + "2025-03-03 00:56:37,968 - DEBUG - Set input channel(s) as ['OrigBrightfield']\n", + "2025-03-03 00:56:37,969 - DEBUG - Set target channel(s) as ['OrigDNA']\n" ] } ], @@ -399,6 +407,7 @@ " epochs = 10,\n", " patience = 5,\n", " callbacks=None,\n", + " early_termination_metric = 'L1Loss',\n", " metrics={'psnr': PSNR(_metric_name=\"psnr\"), 'ssim': SSIM(_metric_name=\"ssim\")},\n", " device = 'cuda'\n", ")\n", @@ -413,6 +422,158 @@ "outputs": [ { "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "epoch", + "rawType": "int64", + "type": "integer" + }, + { + "name": "L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_L1Loss", + "rawType": "float64", + "type": "float" + }, + { + "name": "psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "ssim", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_psnr", + "rawType": "float64", + "type": "float" + }, + { + "name": "val_ssim", + "rawType": "float64", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "4b5efabf-7981-406e-9cdd-46083a756f22", + "rows": [ + [ + "0", + "1", + "0.3949891839708601", + "0.35189451575279235", + "8.089171409606934", + "0.050400376319885254", + "9.02688217163086", + "0.03860057145357132" + ], + [ + "1", + "2", + "0.2313196864866075", + "0.1457641005516052", + "12.610918998718262", + "0.07135152816772461", + "16.650663375854492", + "0.07930725812911987" + ], + [ + "2", + "3", + "0.13591892910855158", + "0.07575097531080247", + "16.98736572265625", + "0.11438453942537308", + "21.61182975769043", + "0.14077268540859222" + ], + [ + "3", + "4", + "0.08485395816110429", + "0.06444989740848542", + "20.824344635009766", + "0.1992468386888504", + "23.16204071044922", + "0.2225249856710434" + ], + [ + "4", + "5", + "0.05403244122862816", + "0.04417358413338661", + "24.34528160095215", + "0.3096939027309418", + "25.62473487854004", + "0.334009051322937" + ], + [ + "5", + "6", + "0.03749226778745651", + "0.029600178450345994", + "26.830354690551758", + "0.42897143959999084", + "28.255292892456055", + "0.44296297430992126" + ], + [ + "6", + "7", + "0.028990560787774268", + "0.021629180014133453", + "28.25572395324707", + "0.49695098400115967", + "30.104782104492188", + "0.5570668578147888" + ], + [ + "7", + "8", + "0.023570057448177112", + 
"0.01590950321406126", + "29.08185386657715", + "0.554701566696167", + "31.01228141784668", + "0.6276947855949402" + ], + [ + "8", + "9", + "0.01982145525869869", + "0.014556870982050895", + "29.655765533447266", + "0.5952008962631226", + "29.830678939819336", + "0.4739395081996918" + ], + [ + "9", + "10", + "0.018070991106686137", + "0.013890763558447362", + "29.52162742614746", + "0.5716341137886047", + "30.02247428894043", + "0.4999874234199524" + ] + ], + "shape": { + "columns": 7, + "rows": 10 + } + }, "text/html": [ "
\n", "