Skip to content

✨ Everything as a second level import #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions tests/baselines/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,3 @@ def test_standard(self):
separable=True,
).eval()
_ = net(torch.rand(1, 3, 32, 32))

def test_errors(self):
with pytest.raises(ValueError):
DeepLabBaseline(
num_classes=10,
loss=nn.CrossEntropyLoss(),
version="test",
style="v3",
output_stride=16,
arch=50,
separable=True,
)
18 changes: 7 additions & 11 deletions tests/models/test_deeplab.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
import pytest
import torch

from torch_uncertainty.models.segmentation.deeplab import (
_DeepLabV3,
deep_lab_v3_resnet50,
deep_lab_v3_resnet101,
)
from torch_uncertainty.models.segmentation.deeplab import _DeepLabV3, deep_lab_v3_resnet


class TestDeeplab:
"""Testing the Deeplab class."""

@torch.no_grad()
def test_main(self):
model = deep_lab_v3_resnet50(10, "v3", 16, True, False).eval()
model = deep_lab_v3_resnet(10, 50, "v3", 16, True, False).eval()
model(torch.randn(1, 3, 32, 32))
model = deep_lab_v3_resnet50(10, "v3", 16, False, False).eval()
model = deep_lab_v3_resnet101(10, "v3+", 8, True, False).eval()
model = deep_lab_v3_resnet(10, 50, "v3", 16, False, False).eval()
model = deep_lab_v3_resnet(10, 101, "v3+", 8, True, False).eval()
model(torch.randn(1, 3, 32, 32))
model = deep_lab_v3_resnet101(10, "v3+", 8, False, False).eval()
model = deep_lab_v3_resnet(10, 101, "v3+", 8, False, False).eval()

def test_errors(self):
with pytest.raises(ValueError, match="Unknown backbone:"):
_DeepLabV3(10, "other", "v3", 16, True, False)
with pytest.raises(ValueError, match="output_stride: "):
deep_lab_v3_resnet50(10, "v3", 15, True, False)
deep_lab_v3_resnet(10, 50, "v3", 15, True, False)
with pytest.raises(ValueError, match="Unknown style: "):
deep_lab_v3_resnet50(10, "v2", 16, True, False)
deep_lab_v3_resnet(10, 50, "v2", 16, True, False)
16 changes: 4 additions & 12 deletions torch_uncertainty/baselines/depth/bts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@

from torch import nn

from torch_uncertainty.models.depth.bts import bts_resnet50, bts_resnet101
from torch_uncertainty.models.depth.bts import bts_resnet
from torch_uncertainty.routines import PixelRegressionRoutine


class BTSBaseline(PixelRegressionRoutine):
single = ["std"]
versions = {
"std": [
bts_resnet50,
bts_resnet101,
]
}
archs = [50, 101]

def __init__(
Expand All @@ -23,10 +16,10 @@ def __init__(
arch: int,
max_depth: float,
dist_family: str | None = None,
num_estimators: int = 1,
pretrained_backbone: bool = True,
) -> None:
params = {
"arch": arch,
"dist_family": dist_family,
"max_depth": max_depth,
"pretrained_backbone": pretrained_backbone,
Expand All @@ -37,12 +30,11 @@ def __init__(
if version not in self.versions:
raise ValueError(f"Unknown version {version}")

model = self.versions[version][self.archs.index(arch)](**params)
model = bts_resnet(**params)
super().__init__(
output_dim=1,
model=model,
output_dim=1,
loss=loss,
num_estimators=num_estimators,
format_batch_fn=format_batch_fn,
dist_family=dist_family,
)
Expand Down
18 changes: 3 additions & 15 deletions torch_uncertainty/baselines/segmentation/deeplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,11 @@

from torch import nn

from torch_uncertainty.models.segmentation.deeplab import (
deep_lab_v3_resnet50,
deep_lab_v3_resnet101,
)
from torch_uncertainty.models.segmentation import deep_lab_v3_resnet
from torch_uncertainty.routines.segmentation import SegmentationRoutine


class DeepLabBaseline(SegmentationRoutine):
single = ["std"]
versions = {
"std": [
deep_lab_v3_resnet50,
deep_lab_v3_resnet101,
]
}
archs = [50, 101]

def __init__(
Expand All @@ -35,6 +25,7 @@ def __init__(
) -> None:
params = {
"num_classes": num_classes,
"arch": arch,
"style": style,
"output_stride": output_stride,
"separable": separable,
Expand All @@ -43,10 +34,7 @@ def __init__(

format_batch_fn = nn.Identity()

if version not in self.versions:
raise ValueError(f"Unknown version {version}")

model = self.versions[version][self.archs.index(arch)](**params)
model = deep_lab_v3_resnet(**params)
super().__init__(
num_classes=num_classes,
model=model,
Expand Down
3 changes: 2 additions & 1 deletion torch_uncertainty/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
OnlineShoppersDataModule,
SpamBaseDataModule,
TinyImageNetDataModule,
UCIClassificationDataModule,
)
from .segmentation import CamVidDataModule, CityscapesDataModule
from .segmentation import CamVidDataModule, CityscapesDataModule, MUADDataModule
from .uci_regression import UCIRegressionDataModule
1 change: 1 addition & 0 deletions torch_uncertainty/datamodules/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
HTRU2DataModule,
OnlineShoppersDataModule,
SpamBaseDataModule,
UCIClassificationDataModule,
)
24 changes: 24 additions & 0 deletions torch_uncertainty/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
# ruff: noqa: F401
from .aggregated_dataset import AggregatedDataset
from .classification import (
CIFAR10C,
CIFAR10H,
CIFAR10N,
CIFAR100C,
CIFAR100N,
CUB,
HTRU2,
MNISTC,
BankMarketing,
DOTA2Games,
ImageNetA,
ImageNetC,
ImageNetO,
ImageNetR,
NotMNIST,
OnlineShoppers,
OpenImageO,
SpamBase,
TinyImageNet,
TinyImageNetC,
)
from .fractals import Fractals
from .frost import FrostImages
from .kitti import KITTIDepth
from .muad import MUAD
from .nyu import NYUv2
from .regression import UCIRegression
from .segmentation import CamVid, Cityscapes
13 changes: 13 additions & 0 deletions torch_uncertainty/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
from .batch_ensemble import BatchConv2d, BatchLinear
from .bayesian import BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear
from .channel_layer_norm import ChannelLayerNorm
from .distributions import (
CauchyConvNd,
CauchyLinear,
LaplaceConvNd,
LaplaceLinear,
NormalConvNd,
NormalInverseGammaConvNd,
NormalInverseGammaLinear,
NormalLinear,
StudentTConvNd,
StudentTLinear,
)
from .filter_response_norm import FilterResponseNorm1d, FilterResponseNorm2d, FilterResponseNorm3d
from .masksembles import MaskedConv2d, MaskedLinear
from .modules import Identity
from .packed import (
Expand Down
21 changes: 21 additions & 0 deletions torch_uncertainty/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
# ruff: noqa: F401
from .classification import (
batched_resnet,
batched_wideresnet28x10,
batchensemble_lenet,
bayesian_lenet,
lenet,
lpbnn_resnet,
masked_resnet,
masked_wideresnet28x10,
mimo_resnet,
mimo_wideresnet28x10,
packed_lenet,
packed_resnet,
packed_vgg,
packed_wideresnet28x10,
resnet,
vgg,
wideresnet28x10,
)
from .depth import bts_resnet
from .wrappers import (
EMA,
EPOCH_UPDATE_MODEL,
Expand All @@ -9,6 +29,7 @@
CheckpointCollector,
MCDropout,
StochasticModel,
Zero,
batch_ensemble,
deep_ensembles,
mc_dropout,
Expand Down
16 changes: 11 additions & 5 deletions torch_uncertainty/models/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# ruff: noqa: F401, F403
from .lenet import *
from .resnet import *
from .vgg import *
from .wideresnet import *
# ruff: noqa: F401
from .lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet
from .resnet import batched_resnet, lpbnn_resnet, masked_resnet, mimo_resnet, packed_resnet, resnet
from .vgg import packed_vgg, vgg
from .wideresnet import (
batched_wideresnet28x10,
masked_wideresnet28x10,
mimo_wideresnet28x10,
packed_wideresnet28x10,
wideresnet28x10,
)
2 changes: 1 addition & 1 deletion torch_uncertainty/models/classification/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear
from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear
from torch_uncertainty.models import StochasticModel
from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble
from torch_uncertainty.models.wrappers.stochastic import StochasticModel

__all__ = ["batchensemble_lenet", "bayesian_lenet", "lenet", "packed_lenet"]

Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/models/classification/resnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# ruff: noqa: F401, F403
# ruff: noqa: F401
from .batched import batched_resnet
from .lpbnn import lpbnn_resnet
from .masked import masked_resnet
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/models/classification/vgg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# ruff: noqa: F401, F403
# ruff: noqa: F401
from .packed import packed_vgg
from .std import vgg
12 changes: 6 additions & 6 deletions torch_uncertainty/models/classification/wideresnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ruff: noqa: F401, F403
from .batched import *
from .masked import *
from .mimo import *
from .packed import *
from .std import *
# ruff: noqa: F401
from .batched import batched_wideresnet28x10
from .masked import masked_wideresnet28x10
from .mimo import mimo_wideresnet28x10
from .packed import packed_wideresnet28x10
from .std import wideresnet28x10
2 changes: 2 additions & 0 deletions torch_uncertainty/models/depth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# ruff: noqa: F401
from .bts import bts_resnet
29 changes: 4 additions & 25 deletions torch_uncertainty/models/depth/bts.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,8 @@ def _bts(
return _BTS(backbone_name, max_depth, bts_size, dist_family, pretrained_backbone)


def bts_resnet50(
def bts_resnet(
arch: int,
max_depth: float,
bts_size: int = 512,
dist_family: str | None = None,
Expand All @@ -585,36 +586,14 @@ def bts_resnet50(
"""BTS model with ResNet-50 backbone.

Args:
arch (int): The number of layers of the underlying ResNet model: 50 or 101.
max_depth (float): Maximum predicted depth.
bts_size (int): BTS feature size. Defaults to 512.
dist_family (str): Distribution family name. Defaults to None.
pretrained_backbone (bool): Use a pretrained backbone. Defaults to True.
"""
return _bts(
"resnet50",
max_depth,
bts_size=bts_size,
dist_family=dist_family,
pretrained_backbone=pretrained_backbone,
)


def bts_resnet101(
max_depth: float,
bts_size: int = 512,
dist_family: str | None = None,
pretrained_backbone: bool = True,
) -> _BTS:
"""BTS model with ResNet-101 backbone.

Args:
max_depth (float): Maximum predicted depth.
bts_size (int): BTS feature size. Defaults to 512.
dist_family (str): Distribution family name. Defaults to None.
pretrained_backbone (bool): Use a pretrained backbone. Defaults to True.
"""
return _bts(
"resnet101",
f"resnet{arch}",
max_depth,
bts_size=bts_size,
dist_family=dist_family,
Expand Down
2 changes: 2 additions & 0 deletions torch_uncertainty/models/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# ruff: noqa: F401
from .deeplab import deep_lab_v3_resnet
Loading