From e876d65431243048e4795afb2a2b46a3e21ff752 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 May 2025 17:20:53 +0200 Subject: [PATCH 1/4] :hammer: DeepLab and BTS as singl e functions --- tests/baselines/test_standard.py | 12 ------- tests/models/test_deeplab.py | 18 ++++------ .../baselines/segmentation/deeplab.py | 18 ++-------- torch_uncertainty/models/depth/__init__.py | 2 ++ torch_uncertainty/models/depth/bts.py | 29 +++------------- .../models/segmentation/deeplab.py | 34 +++---------------- 6 files changed, 21 insertions(+), 92 deletions(-) diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index e59e0a5c..bece8ecd 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -163,15 +163,3 @@ def test_standard(self): separable=True, ).eval() _ = net(torch.rand(1, 3, 32, 32)) - - def test_errors(self): - with pytest.raises(ValueError): - DeepLabBaseline( - num_classes=10, - loss=nn.CrossEntropyLoss(), - version="test", - style="v3", - output_stride=16, - arch=50, - separable=True, - ) diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py index 54729c27..564c2591 100644 --- a/tests/models/test_deeplab.py +++ b/tests/models/test_deeplab.py @@ -1,11 +1,7 @@ import pytest import torch -from torch_uncertainty.models.segmentation.deeplab import ( - _DeepLabV3, - deep_lab_v3_resnet50, - deep_lab_v3_resnet101, -) +from torch_uncertainty.models.segmentation.deeplab import _DeepLabV3, deep_lab_v3_resnet class TestDeeplab: @@ -13,17 +9,17 @@ class TestDeeplab: @torch.no_grad() def test_main(self): - model = deep_lab_v3_resnet50(10, "v3", 16, True, False).eval() + model = deep_lab_v3_resnet(10, 50, "v3", 16, True, False).eval() model(torch.randn(1, 3, 32, 32)) - model = deep_lab_v3_resnet50(10, "v3", 16, False, False).eval() - model = deep_lab_v3_resnet101(10, "v3+", 8, True, False).eval() + model = deep_lab_v3_resnet(10, 50, "v3", 16, False, False).eval() + model = deep_lab_v3_resnet(10, 101, "v3+", 8, True, False).eval() model(torch.randn(1, 3, 32, 32)) - model = deep_lab_v3_resnet101(10, "v3+", 8, False, False).eval() + model = deep_lab_v3_resnet(10, 101, "v3+", 8, False, False).eval() def test_errors(self): with pytest.raises(ValueError, match="Unknown backbone:"): _DeepLabV3(10, "other", "v3", 16, True, False) with pytest.raises(ValueError, match="output_stride: "): - deep_lab_v3_resnet50(10, "v3", 15, True, False) + deep_lab_v3_resnet(10, 50, "v3", 15, True, False) with pytest.raises(ValueError, match="Unknown style: "): - deep_lab_v3_resnet50(10, "v2", 16, True, False) + deep_lab_v3_resnet(10, 50, "v2", 16, True, False) diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py index 65bc4630..156849e3 100644 --- a/torch_uncertainty/baselines/segmentation/deeplab.py +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -2,21 +2,11 @@ from torch import nn -from torch_uncertainty.models.segmentation.deeplab import ( - deep_lab_v3_resnet50, - deep_lab_v3_resnet101, -) +from torch_uncertainty.models.segmentation import deep_lab_v3_resnet from torch_uncertainty.routines.segmentation import SegmentationRoutine class DeepLabBaseline(SegmentationRoutine): - single = ["std"] - versions = { - "std": [ - deep_lab_v3_resnet50, - deep_lab_v3_resnet101, - ] - } archs = [50, 101] def __init__( @@ -35,6 +25,7 @@ def __init__( ) -> None: params = { "num_classes": num_classes, + "arch": arch, "style": style, "output_stride": output_stride, "separable": separable, @@ -43,10 +34,7 @@ def __init__( format_batch_fn = nn.Identity() - if version not in self.versions: - raise ValueError(f"Unknown version {version}") - - model = self.versions[version][self.archs.index(arch)](**params) + model = deep_lab_v3_resnet(**params) super().__init__( num_classes=num_classes, model=model, diff --git a/torch_uncertainty/models/depth/__init__.py b/torch_uncertainty/models/depth/__init__.py index e69de29b..be9b092e 100644 --- a/torch_uncertainty/models/depth/__init__.py +++ b/torch_uncertainty/models/depth/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .bts import bts_resnet diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py index 408b2344..99184ee0 100644 --- a/torch_uncertainty/models/depth/bts.py +++ b/torch_uncertainty/models/depth/bts.py @@ -576,7 +576,8 @@ def _bts( return _BTS(backbone_name, max_depth, bts_size, dist_family, pretrained_backbone) -def bts_resnet50( +def bts_resnet( + arch: int, max_depth: float, bts_size: int = 512, dist_family: str | None = None, @@ -585,36 +586,14 @@ def bts_resnet50( """BTS model with ResNet-50 backbone. Args: + arch (int): The number of layers of the underlying ResNet model: 50 or 101. max_depth (float): Maximum predicted depth. bts_size (int): BTS feature size. Defaults to 512. dist_family (str): Distribution family name. Defaults to None. pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. """ return _bts( - "resnet50", - max_depth, - bts_size=bts_size, - dist_family=dist_family, - pretrained_backbone=pretrained_backbone, - ) - - -def bts_resnet101( - max_depth: float, - bts_size: int = 512, - dist_family: str | None = None, - pretrained_backbone: bool = True, -) -> _BTS: - """BTS model with ResNet-101 backbone. - - Args: - max_depth (float): Maximum predicted depth. - bts_size (int): BTS feature size. Defaults to 512. - dist_family (str): Distribution family name. Defaults to None. - pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. - """ - return _bts( - "resnet101", + f"resnet{arch}", max_depth, bts_size=bts_size, dist_family=dist_family, diff --git a/torch_uncertainty/models/segmentation/deeplab.py b/torch_uncertainty/models/segmentation/deeplab.py index b38bb776..1d364b4b 100644 --- a/torch_uncertainty/models/segmentation/deeplab.py +++ b/torch_uncertainty/models/segmentation/deeplab.py @@ -352,17 +352,19 @@ def forward(self, x: Tensor) -> Tensor: ) -def deep_lab_v3_resnet50( +def deep_lab_v3_resnet( num_classes: int, + arch: int, style: Literal["v3", "v3+"], output_stride: int = 16, separable: bool = False, pretrained_backbone: bool = True, ) -> _DeepLabV3: - """DeepLab V3(+) model with ResNet-50 backbone. + """DeepLab V3(+) model with ResNet-50/101 backbone. Args: num_classes (int): Number of classes. + arch (int): Number of layers of the underlying ResNet model: 50 or 101. style (Literal["v3", "v3+"]): Whether to use a DeepLab V3 or V3+ model. output_stride (int, optional): Output stride. Defaults to 16. separable (bool, optional): Use separable convolutions. Defaults to @@ -372,33 +374,7 @@ def deep_lab_v3_resnet50( """ return _DeepLabV3( num_classes, - "resnet50", - style, - output_stride=output_stride, - separable=separable, - pretrained_backbone=pretrained_backbone, - ) - - -def deep_lab_v3_resnet101( - num_classes: int, - style: Literal["v3", "v3+"], - output_stride: int = 16, - separable: bool = False, - pretrained_backbone: bool = True, -) -> _DeepLabV3: - """DeepLab V3(+) model with ResNet-50 backbone. - - Args: - num_classes (int): Number of classes. - style (Literal["v3", "v3+"]): Whether to use a DeepLab V3 or V3+ model. - output_stride (int, optional): Output stride. Defaults to 16. - separable (bool, optional): Use separable convolutions. Defaults to False. - pretrained_backbone (bool, optional): Use pretrained backbone. Defaults to True. - """ - return _DeepLabV3( - num_classes, - "resnet101", + f"resnet{arch}", style, output_stride=output_stride, separable=separable, From d969550d72d96e349d450cee7ae8c011c7519531 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 May 2025 17:21:10 +0200 Subject: [PATCH 2/4] :hammer: Put plotting function in .plotting --- torch_uncertainty/utils/misc.py | 44 ----------------------------- torch_uncertainty/utils/plotting.py | 42 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 44 deletions(-) diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index 328626a6..4dc53d80 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,11 +1,6 @@ import csv from pathlib import Path -import matplotlib.pyplot as plt -import torch -from matplotlib.axes import Axes -from matplotlib.figure import Figure - def csv_writer(path: Path, dic: dict) -> None: """Write a dictionary to a csv file. @@ -27,42 +22,3 @@ def csv_writer(path: Path, dic: dict) -> None: if append_mode is False: writer.writerow(dic.keys()) writer.writerow([f"{elem:.4f}" for elem in dic.values()]) - - -def plot_hist( - conf: list[torch.Tensor], - bins: int = 20, - title: str = "Histogram with 'auto' bins", - dpi: int = 60, -) -> tuple[Figure, Axes]: - """Plot a confidence histogram. - - Args: - conf (Any): The confidence values. - bins (int, optional): The number of bins. Defaults to 20. - title (str, optional): The title of the plot. Defaults to "Histogram - with 'auto' bins". - dpi (int, optional): The dpi of the plot. Defaults to 60. - - Returns: - Tuple[Figure, Axes]: The figure and axes of the plot. - """ - plt.rc("axes", axisbelow=True) - fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) - for i in [1, 0]: - ax.hist( - conf[i], - bins=bins, - density=True, - label=["In-distribution", "Out-of-Distribution"][i], - alpha=0.4, - linewidth=1, - edgecolor=["#0d559f", "#d45f00"][i], - color=["#1f77b4", "#ff7f0e"][i], - ) - - ax.set_title(title) - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - plt.legend() - fig.tight_layout() - return fig, ax diff --git a/torch_uncertainty/utils/plotting.py b/torch_uncertainty/utils/plotting.py index 0e5fa482..5fe4ceeb 100644 --- a/torch_uncertainty/utils/plotting.py +++ b/torch_uncertainty/utils/plotting.py @@ -1,6 +1,9 @@ import matplotlib.pyplot as plt import numpy as np +import torch import torchvision.transforms.functional as F +from matplotlib.axes import Axes +from matplotlib.figure import Figure from torch import Tensor @@ -17,3 +20,42 @@ def show(prediction: Tensor, target: Tensor): axs[1].set(title="Ground Truth") return fig + + +def plot_hist( + conf: list[torch.Tensor], + bins: int = 20, + title: str = "Histogram with 'auto' bins", + dpi: int = 60, +) -> tuple[Figure, Axes]: + """Plot a confidence histogram. + + Args: + conf (Any): The confidence values. + bins (int, optional): The number of bins. Defaults to 20. + title (str, optional): The title of the plot. Defaults to "Histogram + with 'auto' bins". + dpi (int, optional): The dpi of the plot. Defaults to 60. + + Returns: + Tuple[Figure, Axes]: The figure and axes of the plot. + """ + plt.rc("axes", axisbelow=True) + fig, ax = plt.subplots(1, figsize=(7, 5), dpi=dpi) + for i in [1, 0]: + ax.hist( + conf[i], + bins=bins, + density=True, + label=["In-distribution", "Out-of-Distribution"][i], + alpha=0.4, + linewidth=1, + edgecolor=["#0d559f", "#d45f00"][i], + color=["#1f77b4", "#ff7f0e"][i], + ) + + ax.set_title(title) + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + plt.legend() + fig.tight_layout() + return fig, ax From fb3e1df5b6e8e7c03e5efdc4b4d9581542494f7e Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 May 2025 17:31:59 +0200 Subject: [PATCH 3/4] :hammer: Everything as a second order import --- torch_uncertainty/datamodules/__init__.py | 3 ++- .../datamodules/classification/__init__.py | 1 + torch_uncertainty/datasets/__init__.py | 24 +++++++++++++++++++ torch_uncertainty/layers/__init__.py | 13 ++++++++++ torch_uncertainty/models/__init__.py | 21 ++++++++++++++++ .../models/classification/__init__.py | 16 +++++++++---- .../models/classification/lenet.py | 2 +- .../models/classification/resnet/__init__.py | 2 +- .../models/classification/vgg/__init__.py | 2 +- .../classification/wideresnet/__init__.py | 12 +++++----- .../models/segmentation/__init__.py | 2 ++ torch_uncertainty/transforms/__init__.py | 18 ++++++++++++++ torch_uncertainty/utils/__init__.py | 5 +++- 13 files changed, 105 insertions(+), 16 deletions(-) diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 02a28e9f..0bfae4a1 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -11,6 +11,7 @@ OnlineShoppersDataModule, SpamBaseDataModule, TinyImageNetDataModule, + UCIClassificationDataModule, ) -from .segmentation import CamVidDataModule, CityscapesDataModule +from .segmentation import CamVidDataModule, CityscapesDataModule, MUADDataModule from .uci_regression import UCIRegressionDataModule diff --git a/torch_uncertainty/datamodules/classification/__init__.py b/torch_uncertainty/datamodules/classification/__init__.py index 20f19650..838097f8 100644 --- a/torch_uncertainty/datamodules/classification/__init__.py +++ b/torch_uncertainty/datamodules/classification/__init__.py @@ -10,4 +10,5 @@ HTRU2DataModule, OnlineShoppersDataModule, SpamBaseDataModule, + UCIClassificationDataModule, ) diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 5acc7735..43320fdc 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,7 +1,31 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset +from .classification import ( + CIFAR10C, + CIFAR10H, + CIFAR10N, + CIFAR100C, + CIFAR100N, + CUB, + HTRU2, + MNISTC, + BankMarketing, + DOTA2Games, + ImageNetA, + ImageNetC, + ImageNetO, + ImageNetR, + NotMNIST, + OnlineShoppers, + OpenImageO, + SpamBase, + TinyImageNet, + TinyImageNetC, +) from .fractals import Fractals from .frost import FrostImages from .kitti import KITTIDepth from .muad import MUAD from .nyu import NYUv2 +from .regression import UCIRegression +from .segmentation import CamVid, Cityscapes diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index 689943ff..c5627f27 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -2,6 +2,19 @@ from .batch_ensemble import BatchConv2d, BatchLinear from .bayesian import BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear from .channel_layer_norm import ChannelLayerNorm +from .distributions import ( + CauchyConvNd, + CauchyLinear, + LaplaceConvNd, + LaplaceLinear, + NormalConvNd, + NormalInverseGammaConvNd, + NormalInverseGammaLinear, + NormalLinear, + StudentTConvNd, + StudentTLinear, +) +from .filter_response_norm import FilterResponseNorm1d, FilterResponseNorm2d, FilterResponseNorm3d from .masksembles import MaskedConv2d, MaskedLinear from .modules import Identity from .packed import ( diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index d293d910..688fe11d 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,4 +1,24 @@ # ruff: noqa: F401 +from .classification import ( + batched_resnet, + batched_wideresnet28x10, + batchensemble_lenet, + bayesian_lenet, + lenet, + lpbnn_resnet, + masked_resnet, + masked_wideresnet28x10, + mimo_resnet, + mimo_wideresnet28x10, + packed_lenet, + packed_resnet, + packed_vgg, + packed_wideresnet28x10, + resnet, + vgg, + wideresnet28x10, +) +from .depth import bts_resnet from .wrappers import ( EMA, EPOCH_UPDATE_MODEL, @@ -9,6 +29,7 @@ CheckpointCollector, MCDropout, StochasticModel, + Zero, batch_ensemble, deep_ensembles, mc_dropout, diff --git a/torch_uncertainty/models/classification/__init__.py b/torch_uncertainty/models/classification/__init__.py index 8780ab72..9be30914 100644 --- a/torch_uncertainty/models/classification/__init__.py +++ b/torch_uncertainty/models/classification/__init__.py @@ -1,5 +1,11 @@ -# ruff: noqa: F401, F403 -from .lenet import * -from .resnet import * -from .vgg import * -from .wideresnet import * +# ruff: noqa: F401 +from .lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet +from .resnet import batched_resnet, lpbnn_resnet, masked_resnet, mimo_resnet, packed_resnet, resnet +from .vgg import packed_vgg, vgg +from .wideresnet import ( + batched_wideresnet28x10, + masked_wideresnet28x10, + mimo_wideresnet28x10, + packed_wideresnet28x10, + wideresnet28x10, +) diff --git a/torch_uncertainty/models/classification/lenet.py b/torch_uncertainty/models/classification/lenet.py index d6edbe28..3bdb0c70 100644 --- a/torch_uncertainty/models/classification/lenet.py +++ b/torch_uncertainty/models/classification/lenet.py @@ -8,8 +8,8 @@ from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear -from torch_uncertainty.models import StochasticModel from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble +from torch_uncertainty.models.wrappers.stochastic import StochasticModel __all__ = ["batchensemble_lenet", "bayesian_lenet", "lenet", "packed_lenet"] diff --git a/torch_uncertainty/models/classification/resnet/__init__.py b/torch_uncertainty/models/classification/resnet/__init__.py index cdff770e..3a6e2893 100644 --- a/torch_uncertainty/models/classification/resnet/__init__.py +++ b/torch_uncertainty/models/classification/resnet/__init__.py @@ -1,4 +1,4 @@ -# ruff: noqa: F401, F403 +# ruff: noqa: F401 from .batched import batched_resnet from .lpbnn import lpbnn_resnet from .masked import masked_resnet diff --git a/torch_uncertainty/models/classification/vgg/__init__.py b/torch_uncertainty/models/classification/vgg/__init__.py index 05beb228..707f5e4d 100644 --- a/torch_uncertainty/models/classification/vgg/__init__.py +++ b/torch_uncertainty/models/classification/vgg/__init__.py @@ -1,3 +1,3 @@ -# ruff: noqa: F401, F403 +# ruff: noqa: F401 from .packed import packed_vgg from .std import vgg diff --git a/torch_uncertainty/models/classification/wideresnet/__init__.py b/torch_uncertainty/models/classification/wideresnet/__init__.py index 883b98db..109957ab 100644 --- a/torch_uncertainty/models/classification/wideresnet/__init__.py +++ b/torch_uncertainty/models/classification/wideresnet/__init__.py @@ -1,6 +1,6 @@ -# ruff: noqa: F401, F403 -from .batched import * -from .masked import * -from .mimo import * -from .packed import * -from .std import * +# ruff: noqa: F401 +from .batched import batched_wideresnet28x10 +from .masked import masked_wideresnet28x10 +from .mimo import mimo_wideresnet28x10 +from .packed import packed_wideresnet28x10 +from .std import wideresnet28x10 diff --git a/torch_uncertainty/models/segmentation/__init__.py b/torch_uncertainty/models/segmentation/__init__.py index e69de29b..a31c8d85 100644 --- a/torch_uncertainty/models/segmentation/__init__.py +++ b/torch_uncertainty/models/segmentation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .deeplab import deep_lab_v3_resnet diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py index d3aae6ec..be65afce 100644 --- a/torch_uncertainty/transforms/__init__.py +++ b/torch_uncertainty/transforms/__init__.py @@ -1,5 +1,23 @@ # ruff: noqa: F401 from .batch import MIMOBatchFormat, RepeatTarget +from .corruption import Brightness as BrightnessCorruption +from .corruption import Contrast as ContrastCorruption +from .corruption import ( + DefocusBlur, + Elastic, + Fog, + Frost, + GaussianNoise, + GlassBlur, + ImpulseNoise, + JPEGCompression, + MotionBlur, + Pixelate, + ShotNoise, + Snow, + ZoomBlur, + corruption_transforms, +) from .cutout import Cutout from .image import ( AutoContrast, diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index d7d32e94..bdc5bb10 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,7 +1,10 @@ # ruff: noqa: F401 from .checkpoints import get_version from .cli import TULightningCLI +from .distributions import NormalInverseGamma, get_dist_class, get_dist_estimate +from .evaluation_loop import TUEvaluationLoop from .hub import load_hf -from .misc import csv_writer, plot_hist +from .misc import csv_writer +from .plotting import plot_hist, show from .trainer import TUTrainer from .transforms import interpolation_modes_from_str From 99574e1070e970439f39c7eee11e5100f3ec2ff5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 14 May 2025 17:37:31 +0200 Subject: [PATCH 4/4] :bug: Fix BTS Baseline --- torch_uncertainty/baselines/depth/bts.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/torch_uncertainty/baselines/depth/bts.py b/torch_uncertainty/baselines/depth/bts.py index 31b54051..c8bca7ea 100644 --- a/torch_uncertainty/baselines/depth/bts.py +++ b/torch_uncertainty/baselines/depth/bts.py @@ -2,18 +2,11 @@ from torch import nn -from torch_uncertainty.models.depth.bts import bts_resnet50, bts_resnet101 +from torch_uncertainty.models.depth.bts import bts_resnet from torch_uncertainty.routines import PixelRegressionRoutine class BTSBaseline(PixelRegressionRoutine): - single = ["std"] - versions = { - "std": [ - bts_resnet50, - bts_resnet101, - ] - } archs = [50, 101] def __init__( @@ -23,10 +16,10 @@ def __init__( arch: int, max_depth: float, dist_family: str | None = None, - num_estimators: int = 1, pretrained_backbone: bool = True, ) -> None: params = { + "arch": arch, "dist_family": dist_family, "max_depth": max_depth, "pretrained_backbone": pretrained_backbone, @@ -37,12 +30,11 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version {version}") - model = self.versions[version][self.archs.index(arch)](**params) + model = bts_resnet(**params) super().__init__( - output_dim=1, model=model, + output_dim=1, loss=loss, - num_estimators=num_estimators, format_batch_fn=format_batch_fn, dist_family=dist_family, )