Skip to content

🪜 Small fixes and improvements #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 21, 2025
7 changes: 3 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ lightning_logs/
auto_tutorials_source/*.png
docs/*/generated/
docs/*/auto_tutorials/
*.pth
*.ckpt
*.out
docs/source/sg_execution_times.rst
test
**/*.pth
**/*.ckpt
**/*.out
**/*.csv
pyrightconfig.json

Expand Down
4 changes: 2 additions & 2 deletions experiments/depth/kitti/bts.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import torch
from lightning.pytorch.cli import LightningArgumentParser
from torch.optim.lr_scheduler import PolynomialLR

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.depth import BTSBaseline
from torch_uncertainty.datamodules.depth import KITTIDataModule
from torch_uncertainty.utils.learning_rate import PolyLR


class BTSCLI(TULightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_optimizer_args(torch.optim.AdamW)
parser.add_lr_scheduler_args(PolyLR)
parser.add_lr_scheduler_args(PolynomialLR)


def cli_main() -> BTSCLI:
Expand Down
4 changes: 2 additions & 2 deletions experiments/depth/nyu/bts.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import torch
from lightning.pytorch.cli import LightningArgumentParser
from torch.optim.lr_scheduler import PolynomialLR

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.depth import BTSBaseline
from torch_uncertainty.datamodules.depth import NYUv2DataModule
from torch_uncertainty.utils.learning_rate import PolyLR


class BTSCLI(TULightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_optimizer_args(torch.optim.AdamW)
parser.add_lr_scheduler_args(PolyLR)
parser.add_lr_scheduler_args(PolynomialLR)


def cli_main() -> BTSCLI:
Expand Down
4 changes: 2 additions & 2 deletions experiments/segmentation/camvid/deeplab.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import torch
from lightning.pytorch.cli import LightningArgumentParser
from torch.optim.lr_scheduler import PolynomialLR

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.segmentation import DeepLabBaseline
from torch_uncertainty.datamodules.segmentation import CamVidDataModule
from torch_uncertainty.utils.learning_rate import PolyLR


class DeepLabV3CLI(TULightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_optimizer_args(torch.optim.SGD)
parser.add_lr_scheduler_args(PolyLR)
parser.add_lr_scheduler_args(PolynomialLR)


def cli_main() -> DeepLabV3CLI:
Expand Down
4 changes: 2 additions & 2 deletions experiments/segmentation/cityscapes/deeplab.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import torch
from lightning.pytorch.cli import LightningArgumentParser
from torch.optim.lr_scheduler import PolynomialLR

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.segmentation import DeepLabBaseline
from torch_uncertainty.datamodules.segmentation import CityscapesDataModule
from torch_uncertainty.utils.learning_rate import PolyLR


class DeepLabV3CLI(TULightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_optimizer_args(torch.optim.SGD)
parser.add_lr_scheduler_args(PolyLR)
parser.add_lr_scheduler_args(PolynomialLR)


def cli_main() -> DeepLabV3CLI:
Expand Down
4 changes: 2 additions & 2 deletions tests/losses/test_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_prob_regression_training_step(self):
)
inputs = torch.randn(1, 10)
targets = torch.randn(1, 4)
routine.training_step((inputs, targets), 0)
routine.training_step((inputs, targets))

def test_training_step(self):
model = BayesLinear(10, 4)
Expand All @@ -62,7 +62,7 @@ def test_training_step(self):

inputs = torch.randn(1, 10)
targets = torch.randn(1, 4)
routine.training_step((inputs, targets), 0)
routine.training_step((inputs, targets))

def test_failures(self):
model = BayesLinear(1, 1)
Expand Down
5 changes: 4 additions & 1 deletion torch_uncertainty/baselines/classification/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
last_layer_dropout: bool = False,
width_multiplier: float = 1.0,
groups: int = 1,
conv_bias: bool = False,
scale: float | None = None,
alpha: int | None = None,
gamma: int = 1,
Expand Down Expand Up @@ -128,6 +129,8 @@ def __init__(
last_layer_dropout (bool): whether to apply dropout to the last layer only.
groups (int, optional): Number of groups in convolutions. Defaults to
``1``.
conv_bias (bool, optional): Whether to include bias in the convolutional
layers. Defaults to ``False``.
scale (float, optional): Expansion factor affecting the width of the
estimators. Only used if :attr:`version` is ``"masked"``. Defaults
to ``None``.
Expand Down Expand Up @@ -174,7 +177,7 @@ def __init__(
"""
params = {
"arch": arch,
"conv_bias": False,
"conv_bias": conv_bias,
"dropout_rate": dropout_rate,
"groups": groups,
"width_multiplier": width_multiplier,
Expand Down
76 changes: 71 additions & 5 deletions torch_uncertainty/datasets/classification/cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

import torch
from einops import rearrange
from torch import Tensor
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
Expand All @@ -20,6 +21,7 @@ def __init__(
train: bool = True,
transform: Callable | None = None,
target_transform: Callable | None = None,
return_attributes: bool = False,
download: bool = False,
):
"""The Caltech-UCSD Birds-200-2011 dataset.
Expand All @@ -32,9 +34,12 @@ def __init__(
returns a transformed version. E.g, transforms.RandomCrop. Defaults to None.
target_transform (callable, optional): A function/transform that takes in the target
and transforms it. Defaults to None.
return_attributes (bool, optional): If True, returns the attributes instead of the images.
Defaults to False.
download (bool, optional): If True, downloads the dataset from the internet and puts it
in root directory. If dataset is already downloaded, it is not downloaded again.
Defaults to
Defaults to False.

Reference:
Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S. Caltech-UCSD
Birds 200.
Expand All @@ -52,27 +57,88 @@ def __init__(
super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform)

training_idx = self._load_train_idx()
self.attributes, self.uncertainties = self._load_attributes()
self.attribute_names = self._load_attribute_names()
self.classnames = self._load_classnames()

self.samples = [sample for i, sample in enumerate(self.samples) if training_idx[i] == train]
self._labels = [label for i, label in enumerate(self.targets) if training_idx[i] == train]
self.attributes = rearrange(
torch.masked_select(self.attributes, training_idx.unsqueeze(-1) == train),
"(n c) -> n c",
c=312,
)
self.uncertainties = rearrange(
torch.masked_select(self.uncertainties, training_idx.unsqueeze(-1) == train),
"(n c) -> n c",
c=312,
)

if return_attributes:
self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False)
self.loader = torch.nn.Identity()

def _load_classnames(self) -> list[str]:
"""Load the classnames of the dataset.

Returns:
list[str]: the list containing the names of the 200 classes.
"""
with Path(self.folder_root / "CUB_200_2011" / "classes.txt").open("r") as f:
self.class_names = [
return [
line.split(" ")[1].split(".")[1].replace("\n", "").replace("_", " ") for line in f
]

def _load_train_idx(self) -> Tensor:
is_training_img = []
"""Load the index of the training data to make the split.

Returns:
Tensor: whether the images belong to the training or test split.
"""
with (self.folder_root / "CUB_200_2011" / "train_test_split.txt").open("r") as f:
is_training_img = [int(line.split(" ")[1]) for line in f]
return torch.as_tensor(is_training_img)
return torch.as_tensor([int(line.split(" ")[1]) for line in f])

def _load_attributes(self) -> tuple[Tensor, Tensor]:
"""Load the attributes associated to each image.

Returns:
tuple[Tensor, Tensor]: The presence of the 312 attributes along with their uncertainty.
The uncertainty is 0 for certain samples and 1 for non-visible attributes.
"""
attributes, uncertainty = [], []
with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open(
"r"
) as f:
for line in f:
attributes.append(int(line.split(" ")[2]))
uncertainty.append(1 - (int(line.split(" ")[3]) - 1) / 3)
return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange(
torch.as_tensor(uncertainty), "(n c) -> n c", c=312
)

def _load_attribute_names(self) -> list[str]:
"""Load the names of the attributes.

Returns:
list[str]: The list of the names of the 312 attributes.
"""
with (self.folder_root / "attributes.txt").open("r") as f:
return [line.split(" ")[1].replace("\n", "").replace("_", " ") for line in f]

def _check_integrity(self) -> bool:
"""Check the integrity of the dataset.

Returns:
bool: True when the md5 of the archive corresponds.
"""
fpath = self.folder_root / self.filename
return check_integrity(
fpath,
self.tgz_md5,
)

def _download(self):
"""Download the dataset from caltec.edu."""
if self._check_integrity():
logging.info("Files already downloaded and verified")
return
Expand Down
14 changes: 7 additions & 7 deletions torch_uncertainty/datasets/regression/uci_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ class UCIRegression(Dataset):
"4e6727f462779e2d396e8f7d2ddb79a3",
]
urls = [
"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/" "housing.data",
"https://archive.ics.uci.edu/static/public/165/concrete+compressive+" "strength.zip",
"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data",
"https://archive.ics.uci.edu/static/public/165/concrete+compressive+strength.zip",
"https://archive.ics.uci.edu/static/public/242/energy+efficiency.zip",
"https://archive.ics.uci.edu/static/public/374/appliances+energy+" "prediction.zip",
"https://www.openml.org/data/get_csv/3626/dataset_2175_kin8nm.arff",
"https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-" "plant/master/data.csv",
"https://archive.ics.uci.edu/static/public/294/combined+cycle+power+" "plant.zip",
"https://archive.ics.uci.edu/static/public/374/appliances+energy+prediction.zip",
"https://zenodo.org/records/14645866/files/kin8nm.csv",
"https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-plant/master/data.csv",
"https://archive.ics.uci.edu/static/public/294/combined+cycle+power+plant.zip",
"https://archive.ics.uci.edu/static/public/265/physicochemical+"
"properties+of+protein+tertiary+structure.zip",
"https://archive.ics.uci.edu/static/public/186/wine+quality.zip",
"https://archive.ics.uci.edu/static/public/243/yacht+" "hydrodynamics.zip",
"https://archive.ics.uci.edu/static/public/243/yacht+hydrodynamics.zip",
]

def __init__(
Expand Down
Loading
Loading