Skip to content

✨ Add UCI classification datasets, improve binary classification, add LS support for BCEWithLogitsLoss, remove dependencies #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
482db82
:shirt: Fix title in corruption tutorial
o-laurent Oct 22, 2024
760dd23
:bug: Code now runs without glest
o-laurent Nov 1, 2024
86adf87
:shirt: Improve the look of ECE's diagrams
o-laurent Nov 9, 2024
f5ed453
:white_check_mark: Update tests
o-laurent Nov 9, 2024
fdfcaf2
:shirt: Reformat TUTrainer
o-laurent Nov 13, 2024
1e23ebe
:bug: Use only lightning.pytorch
o-laurent Nov 13, 2024
a1d1fcb
:sparkles: Add binary classification metrics when needed
o-laurent Nov 13, 2024
a61b050
:bug: Fix Brier Score for Binary cls.
o-laurent Nov 13, 2024
a71f4cf
:sparkles: Add BCE with LS
o-laurent Nov 13, 2024
cde2997
:hammer: Rename UCI regression dm
o-laurent Nov 17, 2024
60e7644
:fire: Remove useless code line
o-laurent Nov 17, 2024
3da7ce3
:sparkles: Add UCI cls datasets
o-laurent Nov 17, 2024
0ef69ce
:sparkles: Add UCI cls dm & forgotten init
o-laurent Nov 17, 2024
24aef4d
:fire: Remove torchinfo dep.
o-laurent Nov 17, 2024
0862136
:heavy_minus_sign: Remove tensorboard dep.
o-laurent Nov 17, 2024
1c7b8d6
:heavy_minus_sign: Remove hf hub from necessary dep.
o-laurent Nov 17, 2024
7cd9ed9
:heavy_minus_sign: Remove sklearn hub from necessary dep.
o-laurent Nov 17, 2024
5702fae
:bug: Fix small errors & make tests pass
o-laurent Nov 17, 2024
d9c039b
:white_check_mark: Add some tests and improve coverage
o-laurent Nov 17, 2024
5e27706
:white_check_mark: Add ds test and fix loss test
o-laurent Nov 17, 2024
43d099c
:fire: Delete useless dataset tests
o-laurent Nov 17, 2024
239b96a
:bug: Fix BCEWithLogitsLSLoss so target and pred types are the same
alafage Nov 17, 2024
241fc5f
:hammer: Fix DOTA case & add all to API
o-laurent Nov 17, 2024
c71db08
Merge branch 'dev' of github.com:ENSTA-U2IS-AI/torch-uncertainty into…
o-laurent Nov 17, 2024
2244fb3
:white_check_mark: Improve tests
o-laurent Nov 17, 2024
41bc492
:bug: Fix val split & loss args
o-laurent Nov 18, 2024
779e426
:fire: Remove segformer dead code
o-laurent Nov 18, 2024
7dbecb8
:white_check_mark: Continue improving cov.
o-laurent Nov 18, 2024
8b80532
:zap: Update version for release
o-laurent Nov 18, 2024
4022f8a
:books: Fix typo in Quickstart
alafage Nov 18, 2024
cce5e1d
:zap: Update ruff and rm np from explicit dep.
o-laurent Nov 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions auto_tutorials_source/tutorial_corruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def show_images(transforms):

# %%
# 4. Other Corruptions
# ~~~~~~~~~~~~~~~~~~~~

from torch_uncertainty.transforms.corruption import (
Brightness, Contrast, Elastic, JPEGCompression, Pixelate)
Expand Down
36 changes: 35 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ Losses
ConfidencePenaltyLoss
KLDiv
ELBOLoss
BCEWithLogitsLSLoss

Post-Processing Methods
-----------------------
Expand Down Expand Up @@ -379,14 +380,28 @@ Classification
TinyImageNetDataModule
ImageNetDataModule

UCI Tabular Classification
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

BankMarketingDataModule
DOTA2GamesDataModule
HTRU2DataModule
OnlineShoppersDataModule
SpamBaseDataModule

Regression
^^^^^^^^^^
.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

UCIDataModule
UCIRegressionDataModule

.. currentmodule:: torch_uncertainty.datamodules.segmentation

Expand Down Expand Up @@ -432,6 +447,25 @@ Classification
TinyImageNetC
OpenImageO


UCI Tabular Classification
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: torch_uncertainty.datasets.classification.uci


.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

BankMarketing
DOTA2Games
HTRU2
OnlineShoppers
SpamBase


Regression
^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
)
author = "Adrien Lafage and Olivier Laurent"
release = "0.3.0"
release = "0.3.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ CIFAR10 datamodule.
from lightning.pytorch import TUTrainer

dm = CIFAR10DataModule(root="data", batch_size=32)
trainer = TUTTrainer(gpus=1, max_epochs=100)
trainer = TUTrainer(gpus=1, max_epochs=100)
trainer.fit(routine, dm)
trainer.test(routine, dm)

Expand Down
11 changes: 11 additions & 0 deletions docs/source/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,17 @@ For the conflictual loss, consider citing:
* Authors: *Mohammed Fellaji, Frédéric Pennerath, Brieuc Conan-Guez, and Miguel Couceiro*
* Paper: `ArXiv 2024 <https://arxiv.org/pdf/2407.12211>`__.

Binary Cross-Entropy with Logits Loss with Label Smoothing
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For the binary cross-entropy with logits loss with label smoothing, consider citing:

**Rethinking the Inception Architecture for Computer Vision**

* Authors: *Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna*
* Paper: `CVPR 2016 <https://arxiv.org/pdf/1512.00567.pdf>`__.


Metrics
-------

Expand Down
6 changes: 3 additions & 3 deletions experiments/regression/uci_datasets/deep_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.baselines import DeepEnsemblesBaseline
from torch_uncertainty.datamodules import UCIDataModule
from torch_uncertainty.datamodules import UCIRegressionDataModule

if __name__ == "__main__":
args = init_args(DeepEnsemblesBaseline, UCIDataModule)
args = init_args(DeepEnsemblesBaseline, UCIRegressionDataModule)
if args.root == "./data/":
root = Path(__file__).parent.absolute().parents[2]
else:
Expand All @@ -15,7 +15,7 @@

# datamodule
args.root = str(root / "data")
dm = UCIDataModule(dataset_name="kin8nm", **vars(args))
dm = UCIRegressionDataModule(dataset_name="kin8nm", **vars(args))

# model
args.task = "regression"
Expand Down
4 changes: 2 additions & 2 deletions experiments/regression/uci_datasets/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.regression import MLPBaseline
from torch_uncertainty.datamodules import UCIDataModule
from torch_uncertainty.datamodules import UCIRegressionDataModule


class MLPCLI(TULightningCLI):
Expand All @@ -12,7 +12,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:


def cli_main() -> MLPCLI:
return MLPCLI(MLPBaseline, UCIDataModule)
return MLPCLI(MLPBaseline, UCIRegressionDataModule)


if __name__ == "__main__":
Expand Down
17 changes: 7 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "torch_uncertainty"
version = "0.3.0"
version = "0.3.1"
authors = [
{ name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
{ name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
Expand All @@ -28,20 +28,14 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"timm",
"lightning[pytorch-extra]>=2.0",
"torchvision>=0.16",
"tensorboard",
"einops",
"torchinfo",
"huggingface-hub",
"scikit-learn",
"matplotlib",
"numpy",
"rich>=10.2.2",
"seaborn",
]
Expand All @@ -55,8 +49,10 @@ image = [
]
tabular = ["pandas"]
dev = [
"scikit-learn",
"huggingface-hub",
"torch_uncertainty[image]",
"ruff==0.6.9",
"ruff==0.7.4",
"pytest-cov",
"pre-commit",
"pre-commit-hooks",
Expand All @@ -74,6 +70,7 @@ all = [
"laplace-torch",
"glest==0.0.1a1",
"scipy",
"tensorboard",
]

[project.urls]
Expand All @@ -89,7 +86,7 @@ line-length = 80
target-version = "py310"
lint.extend-select = [
"A",
"ARG",
"ARG",
"B",
"C4",
"D",
Expand Down Expand Up @@ -170,5 +167,5 @@ include = ["*/torch-uncertainty/*"]
omit = ["*/tests/*", "*/datasets/*"]

[tool.coverage.report]
exclude_lines = ["coverage: ignore", "raise NotImplementedError"]
exclude_lines = ["coverage: ignore", "raise NotImplementedError", "raise ImportError"]
ignore_errors = true
4 changes: 0 additions & 4 deletions tests/baselines/test_batched.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import nn
from torchinfo import summary

from torch_uncertainty.baselines.classification import (
ResNetBaseline,
Expand All @@ -23,7 +22,6 @@ def test_batched_18(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 32, 32))

def test_batched_50(self):
Expand All @@ -38,7 +36,6 @@ def test_batched_50(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 40, 40))


Expand All @@ -56,5 +53,4 @@ def test_batched(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 32, 32))
4 changes: 0 additions & 4 deletions tests/baselines/test_masked.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import torch
from torch import nn
from torchinfo import summary

from torch_uncertainty.baselines.classification import (
ResNetBaseline,
Expand All @@ -25,7 +24,6 @@ def test_masked_18(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 32, 32))

def test_masked_50(self):
Expand All @@ -41,7 +39,6 @@ def test_masked_50(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 40, 40))

def test_masked_errors(self):
Expand Down Expand Up @@ -87,5 +84,4 @@ def test_masked(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 32, 32))
4 changes: 0 additions & 4 deletions tests/baselines/test_mc_dropout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import nn
from torchinfo import summary

from torch_uncertainty.baselines.classification import (
ResNetBaseline,
Expand All @@ -24,7 +23,6 @@ def test_standard(self):
style="cifar",
groups=1,
)
summary(net)
net(torch.rand(1, 3, 32, 32))


Expand All @@ -42,7 +40,6 @@ def test_standard(self):
style="cifar",
groups=1,
)
summary(net)
net(torch.rand(1, 3, 32, 32))


Expand All @@ -61,7 +58,6 @@ def test_standard(self):
groups=1,
last_layer_dropout=True,
)
summary(net)
net(torch.rand(1, 3, 32, 32))

net = VGGBaseline(
Expand Down
4 changes: 0 additions & 4 deletions tests/baselines/test_mimo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import nn
from torchinfo import summary

from torch_uncertainty.baselines.classification import (
ResNetBaseline,
Expand All @@ -25,7 +24,6 @@ def test_mimo_50(self):
groups=1,
).eval()

summary(net)
_ = net(torch.rand(1, 3, 32, 32))

def test_mimo_18(self):
Expand All @@ -42,7 +40,6 @@ def test_mimo_18(self):
groups=2,
).eval()

summary(net)
_ = net(torch.rand(1, 3, 40, 40))


Expand All @@ -62,5 +59,4 @@ def test_mimo(self):
groups=1,
).eval()

summary(net)
_ = net(torch.rand(1, 3, 32, 32))
8 changes: 0 additions & 8 deletions tests/baselines/test_packed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import torch
from torch import nn
from torchinfo import summary

from torch_uncertainty.baselines.classification import (
ResNetBaseline,
Expand All @@ -28,8 +27,6 @@ def test_packed_50(self):
groups=1,
)

summary(net)

_ = net(torch.rand(1, 3, 32, 32))

def test_packed_18(self):
Expand All @@ -46,7 +43,6 @@ def test_packed_18(self):
groups=2,
)

summary(net)
_ = net(torch.rand(1, 3, 40, 40))

def test_packed_exception(self):
Expand Down Expand Up @@ -95,7 +91,6 @@ def test_packed(self):
groups=1,
)

summary(net)
_ = net(torch.rand(1, 3, 32, 32))


Expand All @@ -114,8 +109,6 @@ def test_packed(self):
gamma=1,
groups=1,
)

summary(net)
_ = net(torch.rand(2, 3, 32, 32))


Expand All @@ -133,5 +126,4 @@ def test_packed(self):
alpha=2,
gamma=1,
)
summary(net)
_ = net(torch.rand(1, 3))
Loading