diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml
new file mode 100644
index 0000000..f9b2976
--- /dev/null
+++ b/.github/workflows/linters.yml
@@ -0,0 +1,26 @@
+name: pylint
+
+on:
+  pull_request:
+  workflow_dispatch:
+
+jobs:
+  checks:
+    runs-on: ubuntu-20.04
+    strategy:
+      max-parallel: 4
+      matrix:
+        python-version: [3.9, "3.10", "3.11", "3.12"]
+
+    steps:
+    - uses: actions/checkout@v1
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install tox
+    - name: Check lint
+      run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)-lint
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index ae821e9..ae943de 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -2,7 +2,7 @@ name: tests
 
 on:
   push:
-    branches: ["main", "develop"]
+    branches: ["main"]
   pull_request:
     branches: ["main", "develop"]
 
@@ -44,7 +44,7 @@ jobs:
           set -xe
           python -VV
           python -c "import torch; print(torch.__version__)"
-          pytest --cov=orthogonium.layers.conv.AOC --cov-report lcov tests
+          pytest --cov=orthogonium.layers --cov-report lcov tests
         shell: bash
 
       - name: Upload Coverage to Coveralls
diff --git a/.gitignore b/.gitignore
index 69b0de5..b5f1975 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,8 +7,7 @@ __pycache__
 # IDE / Tools files
 .vscode
 .tox
-env_flashlipschitz
-env_d16_flashlipschitz
+env_orthogonium
 *wandb
 
 # Files generated:
@@ -19,14 +18,11 @@ docs/build
 # mkdocs documentation
 /site
 
-env_opacus/
-
 # all pth files
 *.pth
 
 lightning_logs/
 scripts/data/
-scripts/test_gan/data/
 output/
 *.tar
 *_args.txt
@@ -37,6 +33,5 @@ wandb/
 *.ckpt
 *.pth
 scripts/data/
-scripts/test_gan/data/
 scripts/ortho_gan/data/
 *.csv
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 456c9d3..c8d9024 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -4,17 +4,17 @@ Thanks for taking the time to contribute!
 
 From opening a bug report to creating a pull request: every contribution is
 appreciated and welcome. If you're planning to implement a new feature or change
-the api please create an [issue first](https://https://github.com/deel-ai/dp-lipschitz/issues/new). This way we can ensure that your precious
+the api please create an [issue first](https://github.com/thib-s/orthogonium/issues/new). This way we can ensure that your precious
 work is not in vain.
 
 
 ## Setup with make
 
-- Clone the repo `git clone https://github.com/deel-ai/dp-lipschitz.git`.
-- Go to your freshly downloaded repo `cd lipdp`
+- Clone the repo `git clone https://github.com/thib-s/orthogonium.git`.
+- Go to your freshly downloaded repo `cd orthogonium`
 - Create a virtual environment and install the necessary dependencies for development:
 
-  `make prepare-dev && source lipdp_dev_env/bin/activate`.
+  `make prepare-dev && source env_orthogonium/bin/activate`.
 
 Welcome to the team !
 
@@ -42,7 +42,7 @@ Basically, it will check that your code follow a certain number of convention. A
 
 After getting some feedback, push to your fork and submit a pull request. We
 may suggest some changes or improvements or alternatives, but for small changes
-your pull request should be accepted quickly (see [Governance policy](https://github.com/deel-ai/lipdp/blob/master/GOVERNANCE.md)).
+your pull request should be accepted quickly (see [Governance policy](https://github.com/thib-s/orthogonium/blob/master/GOVERNANCE.md)).
 
 Something that will increase the chance that your pull request is accepted:
 
diff --git a/GOVERNANCE.md b/GOVERNANCE.md
index 5aacd5f..b290f54 100644
--- a/GOVERNANCE.md
+++ b/GOVERNANCE.md
@@ -1,6 +1,6 @@
 # GOVERNANCE
 
-`lipdp` is developped as part of the Artificial and Natural Intelligence Toulouse Institute (DEEL/ANITI) program.
+`orthogonium` is developed as part of the Artificial and Natural Intelligence Toulouse Institute (DEEL/ANITI) program.
 
 DEEL/ANITI is this repository owner and the write rights manager.
 
@@ -10,7 +10,7 @@ These management rules are intended to be collaborative and all those involved i
 
 ###	Governance committee
 
-The governance committee is initially composed of DEEL members who contributed to the first version of `lipdp` and are the only contributors to the master branch.
+The governance committee is initially composed of DEEL members who contributed to the first version of `orthogonium` and are the only contributors to the master branch.
 
 The governance committee is responsible for the master branch that contains the code of the version of the library that is officially recognised.
 
diff --git a/Makefile b/Makefile
index ceb1920..07664fd 100644
--- a/Makefile
+++ b/Makefile
@@ -19,28 +19,28 @@ help:
 
 prepare-dev:
 	python3 -m pip install virtualenv
-	python3 -m venv env_flashlipschitz
-	. env_flashlipschitz/bin/activate && pip install --upgrade pip
-	. env_flashlipschitz/bin/activate && pip install -e .[dev]
-	. env_flashlipschitz/bin/activate && pre-commit install
-	. env_flashlipschitz/bin/activate && pre-commit install-hooks
-	. env_flashlipschitz/bin/activate && pre-commit install --hook-type commit-msg
+	python3 -m venv env_orthogonium
+	. env_orthogonium/bin/activate && pip install --upgrade pip
+	. env_orthogonium/bin/activate && pip install -e .[dev]
+	. env_orthogonium/bin/activate && pre-commit install
+	. env_orthogonium/bin/activate && pre-commit install-hooks
+	. env_orthogonium/bin/activate && pre-commit install --hook-type commit-msg
 
 test:
-	. env_flashlipschitz/bin/activate && tox
+	. env_orthogonium/bin/activate && tox
 
 check_all:
-	. env_flashlipschitz/bin/activate && pre-commit run --all-files
+	. env_orthogonium/bin/activate && pre-commit run --all-files
 
 updatetools:
-	. env_flashlipschitz/bin/activate && pre-commit autoupdate
+	. env_orthogonium/bin/activate && pre-commit autoupdate
 
 test-disable-gpu:
-	. env_flashlipschitz/bin/activate && CUDA_VISIBLE_DEVICES=-1 tox
+	. env_orthogonium/bin/activate && CUDA_VISIBLE_DEVICES=-1 tox
 
 doc:
-	. env_flashlipschitz/bin/activate && mkdocs build
-	. env_flashlipschitz/bin/activate && mkdocs gh-deploy
+	. env_orthogonium/bin/activate && mkdocs build
+	. env_orthogonium/bin/activate && mkdocs gh-deploy
 
 serve-doc:
-	. env_flashlipschitz/bin/activate && CUDA_VISIBLE_DEVICES=-1 mkdocs serve
+	. env_orthogonium/bin/activate && CUDA_VISIBLE_DEVICES=-1 mkdocs serve
diff --git a/README.md b/README.md
index c3058e5..363ea59 100644
--- a/README.md
+++ b/README.md
@@ -6,19 +6,22 @@
 
 
     
-         +
+         -
     
     
-         +
+         
     
-    
-         +    
+
+    
+         
     
-    
-         +    
+
+    
+         
     
-    
-         +    
+
+    
+         +
+    
+    
+         
     
     
          @@ -115,20 +118,20 @@ pip install -e .
 
 ```python
 from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
-from orthogonium.layers.linear.reparametrizers import DEFAULT_ORTHO_PARAMS
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS
 
 # use OrthoConv2d with the same params as torch.nn.Conv2d
 
 conv = AdaptiveOrthoConv2d(
-  kernel_size=kernel_size,
-  in_channels=256,
-  out_channels=256,
-  stride=2,
-  groups=16,
-  bias=True,
-  padding=(kernel_size // 2, kernel_size // 2),
-  padding_mode="circular",
-  ortho_params=DEFAULT_ORTHO_PARAMS
+    kernel_size=kernel_size,
+    in_channels=256,
+    out_channels=256,
+    stride=2,
+    groups=16,
+    bias=True,
+    padding=(kernel_size // 2, kernel_size // 2),
+    padding_mode="circular",
+    ortho_params=DEFAULT_ORTHO_PARAMS
 )
 ```
 
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
new file mode 100644
index 0000000..e6a426d
--- /dev/null
+++ b/docs/CONTRIBUTING.md
@@ -0,0 +1,52 @@
+# Contributing
+
+Thanks for taking the time to contribute!
+
+From opening a bug report to creating a pull request: every contribution is
+appreciated and welcome. If you're planning to implement a new feature or change
+the api please create an [issue first](https://github.com/thib-s/orthogonium/issues). This way we can ensure that your precious
+work is not in vain.
+
+
+## Setup with make
+
+- Clone the repo `git clone git@github.com:thib-s/orthogonium.git`.
+- Go to your freshly downloaded repo `cd orthogonium`
+- Create a virtual environment and install the necessary dependencies for development:
+
+  `make prepare-dev && source env_orthogonium/bin/activate`.
+
+Welcome to the team !
+
+
+## Tests
+
+To run test `make test`
+This command activates your virtual environment and launches the `tox` command.
+
+
+`tox` on the other hand will do the following:
+- run pytest on the tests folder
+- run pylint on the orthogonium main files
+> Note: It is possible that pylint throws false-positive errors. If the linting test fails, please check the pylint output first to point out the reasons.
+
+Please, make sure you run all the tests at least once before opening a pull request.
+
+A word toward [Pylint](https://pypi.org/project/pylint/) for those that don't know it:
+> Pylint is a Python static code analysis tool which looks for programming errors, helps enforcing a coding standard, sniffs for code smells and offers simple refactoring suggestions.
+
+Basically, it will check that your code follows a certain number of conventions. Any Pull Request will go through a GitHub workflow ensuring that your code respects the Pylint conventions (most of them at least).
+
+## Submitting Changes
+
+After getting some feedback, push to your fork and submit a pull request. We
+may suggest some changes or improvements or alternatives, but for small changes
+your pull request should be accepted quickly (see [Governance policy](https://github.com/thib-s/orthogonium/blob/release-no-advertising/GOVERNANCE.md)).
+
+Something that will increase the chance that your pull request is accepted:
+
+- Write tests and ensure that the existing ones pass.
+- If `make test` is successful, you have fair chances to pass the CI workflows (linting and test)
+- Follow the existing coding style and run `make check_all` to check all files format.
+- Write a [good commit message](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html) (we follow a lowercase convention).
+- For a major fix/feature make sure your PR has an issue and if it doesn't, please create one. This would help discussion with the community, and polishing ideas in case of a new feature.
diff --git a/docs/api/aoc.md b/docs/api/aoc.md
new file mode 100644
index 0000000..54e9838
--- /dev/null
+++ b/docs/api/aoc.md
@@ -0,0 +1,28 @@
+The most scalable method to build orthogonal convolution. Allows control of kernel size, 
+stride, groups, dilation and transposed convolutions.
+
+The classes `AdaptiveOrthoConv2d` and `AdaptiveOrthoConvTranspose2d` are not classes,
+but factory functions to select between 3 different parametrizations, as depicted
+in the following figure:
+
+
@@ -115,20 +118,20 @@ pip install -e .
 
 ```python
 from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
-from orthogonium.layers.linear.reparametrizers import DEFAULT_ORTHO_PARAMS
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS
 
 # use OrthoConv2d with the same params as torch.nn.Conv2d
 
 conv = AdaptiveOrthoConv2d(
-  kernel_size=kernel_size,
-  in_channels=256,
-  out_channels=256,
-  stride=2,
-  groups=16,
-  bias=True,
-  padding=(kernel_size // 2, kernel_size // 2),
-  padding_mode="circular",
-  ortho_params=DEFAULT_ORTHO_PARAMS
+    kernel_size=kernel_size,
+    in_channels=256,
+    out_channels=256,
+    stride=2,
+    groups=16,
+    bias=True,
+    padding=(kernel_size // 2, kernel_size // 2),
+    padding_mode="circular",
+    ortho_params=DEFAULT_ORTHO_PARAMS
 )
 ```
 
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
new file mode 100644
index 0000000..e6a426d
--- /dev/null
+++ b/docs/CONTRIBUTING.md
@@ -0,0 +1,52 @@
+# Contributing
+
+Thanks for taking the time to contribute!
+
+From opening a bug report to creating a pull request: every contribution is
+appreciated and welcome. If you're planning to implement a new feature or change
+the api please create an [issue first](https://github.com/thib-s/orthogonium/issues). This way we can ensure that your precious
+work is not in vain.
+
+
+## Setup with make
+
+- Clone the repo `git clone git@github.com:thib-s/orthogonium.git`.
+- Go to your freshly downloaded repo `cd orthogonium`
+- Create a virtual environment and install the necessary dependencies for development:
+
+  `make prepare-dev && source env_orthogonium/bin/activate`.
+
+Welcome to the team !
+
+
+## Tests
+
+To run test `make test`
+This command activates your virtual environment and launches the `tox` command.
+
+
+`tox` on the other hand will do the following:
+- run pytest on the tests folder
+- run pylint on the orthogonium main files
+> Note: It is possible that pylint throws false-positive errors. If the linting test fails, please check the pylint output first to point out the reasons.
+
+Please, make sure you run all the tests at least once before opening a pull request.
+
+A word toward [Pylint](https://pypi.org/project/pylint/) for those that don't know it:
+> Pylint is a Python static code analysis tool which looks for programming errors, helps enforcing a coding standard, sniffs for code smells and offers simple refactoring suggestions.
+
+Basically, it will check that your code follows a certain number of conventions. Any Pull Request will go through a GitHub workflow ensuring that your code respects the Pylint conventions (most of them at least).
+
+## Submitting Changes
+
+After getting some feedback, push to your fork and submit a pull request. We
+may suggest some changes or improvements or alternatives, but for small changes
+your pull request should be accepted quickly (see [Governance policy](https://github.com/thib-s/orthogonium/blob/release-no-advertising/GOVERNANCE.md)).
+
+Something that will increase the chance that your pull request is accepted:
+
+- Write tests and ensure that the existing ones pass.
+- If `make test` is successful, you have fair chances to pass the CI workflows (linting and test)
+- Follow the existing coding style and run `make check_all` to check all files format.
+- Write a [good commit message](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html) (we follow a lowercase convention).
+- For a major fix/feature make sure your PR has an issue and if it doesn't, please create one. This would help discussion with the community, and polishing ideas in case of a new feature.
diff --git a/docs/api/aoc.md b/docs/api/aoc.md
new file mode 100644
index 0000000..54e9838
--- /dev/null
+++ b/docs/api/aoc.md
@@ -0,0 +1,28 @@
+The most scalable method to build orthogonal convolution. Allows control of kernel size, 
+stride, groups, dilation and transposed convolutions.
+
+The classes `AdaptiveOrthoConv2d` and `AdaptiveOrthoConvTranspose2d` are not classes,
+but factory functions to select between 3 different parametrizations, as depicted
+in the following figure:
+
+ +
+::: orthogonium.layers.conv.AOC.ortho_conv
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
+
+
+::: orthogonium.layers.conv.AOC.fast_block_ortho_conv
+    rendering:
+        show_root_toc_entry: true
+    selection:
+        inherited_members: true
+
+::: orthogonium.layers.conv.AOC.rko_conv
+    rendering:
+        show_root_toc_entry: true
+    selection:
+        inherited_members: true
+
diff --git a/docs/api/conv.md b/docs/api/conv.md
new file mode 100644
index 0000000..f317d46
--- /dev/null
+++ b/docs/api/conv.md
@@ -0,0 +1,13 @@
+::: orthogonium.layers.conv.AOC.ortho_conv
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
+
+
+
+::: orthogonium.layers.conv.SLL.sll_layer
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/api/linear.md b/docs/api/linear.md
new file mode 100644
index 0000000..1eeaa62
--- /dev/null
+++ b/docs/api/linear.md
@@ -0,0 +1,5 @@
+::: orthogonium.layers.linear.ortho_linear
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/api/reparametrizers.md b/docs/api/reparametrizers.md
new file mode 100644
index 0000000..c36df77
--- /dev/null
+++ b/docs/api/reparametrizers.md
@@ -0,0 +1,5 @@
+::: orthogonium.reparametrizers
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/assets/banner.png b/docs/assets/banner.png
new file mode 100644
index 0000000..4501fd5
Binary files /dev/null and b/docs/assets/banner.png differ
diff --git a/docs/assets/flowchart_v4.png b/docs/assets/flowchart_v4.png
new file mode 100644
index 0000000..df509ec
Binary files /dev/null and b/docs/assets/flowchart_v4.png differ
diff --git a/docs/css/custom.css b/docs/css/custom.css
new file mode 100644
index 0000000..9ddf790
--- /dev/null
+++ b/docs/css/custom.css
@@ -0,0 +1,194 @@
+:root {
+    --primary: #3f51b5;
+    --dark: #313131;
+
+    --synthax-blue: #3b78e7;
+    --synthax-red: #ec407a;
+
+}
+
+[data-md-color-scheme="slate"] {
+    /* dark theme */
+    --primary: #ecb22e;
+    --dark: #313131;
+    --synthax-blue: #ecb22e;
+    /* --synthax-red: #03DAC6; */
+    --md-typeset-a-color: var(--synthax-blue);
+    --md-primary-fg-color: var(--dark);
+    --md-primary-fg-color--light: #5d6cc0;
+    --md-primary-fg-color--dark: var(--primary);
+    --md-accent-fg-color: var(--synthax-blue);
+}
+
+
+html {
+    font-size: 115%;
+}
+
+.md-header {
+    background-color: var(--md-footer-bg-color);
+    color: #fff;
+}
+
+.md-typeset ul li p {
+    margin: 0;
+}
+
+.md-typeset h4,
+.md-typeset h5,
+.md-typeset h6 {
+    line-height: 2.0;
+}
+
+li {
+    margin-bottom: 0 !important;
+}
+
+.autodoc-signature>em:first-of-type {
+    font-style: normal;
+    font-weight: bold;
+    color: var(--primary);
+}
+
+h2 code {
+    color: var(--synthax-red) !important;
+}
+
+h4 code {
+    color: var(--synthax-blue) !important;
+    font-weight: bold !important;
+    background-color: rgba(31, 31, 31, 0.05) !important;
+    padding: 5px !important;
+    border-radius: 2px !important;
+    cursor: pointer !important;
+    font-size: 95% !important;
+}
+
+p {
+    text-align: justify
+}
+
+[data-md-color-scheme="slate"] h4 code {
+    background-color: rgba(240, 240, 240, 0.05) !important;
+}
+
+.admonition {
+    font-size: 95% !important;
+}
+
+img[alt*="Colab"] {
+    transform: translateY(3px);
+    padding-left: 5px;
+    width: 150px;
+}
+
+.md-typeset img,
+.md-typeset svg {
+    max-width: none;
+}
+
+.md-typeset .task-list-control [type=checkbox]:checked+.task-list-indicator::before {
+    background: var(--primary) !important;
+}
+
+.md-typeset ul li,
+.md-typeset ol li {
+    margin-bottom: .4em !important;
+}
+
+h2.numkdoc~p:not(.footnote),
+h2.numkdoc~ul:not(.footnote),
+h2.numkdoc~ul li ul:not(.footnote),
+h2.numkdoc~ul li:not(.footnote) {
+    --margin: 0.06em;
+    margin-bottom: var(--margin) !important;
+    margin-top: var(--margin) !important;
+}
+
+table {
+    border: solid 2px rgba(255, 255, 255, 0.1);
+}
+
+.md-typeset table:not([class]) th {
+    background-color: var(--dark);
+}
+
+[data-md-color-scheme="slate"] .md-typeset table:not([class]) th {
+    color: var(--dark);
+    background-color: white;
+}
+
+.md-typeset__table {
+    width: 100%;
+}
+
+.md-typeset table:not([class]) {
+    display: table;
+}
+
+.md-typeset thead:not([class]) {
+    color: white;
+}
+
+table th a {
+    color: var(--synthax-red) !important;
+    word-break: break-word !important;
+}
+
+span.parameter-name {
+    color: var(--primary);
+}
+
+li .parameter-type {
+    font-weight: bold;
+}
+
+[data-md-color-scheme="slate"] .parameter-type {
+    color: var(--synthax-red);
+}
+
+.parameter-self {
+    color: var(--synthax-red);
+}
+
+.md-typeset table:not([class]) {
+    font-size: 0.8rem;
+}
+
+.md-header__button.md-logo img,
+.md-header__button.md-logo svg {
+    width: 2.0rem !important;
+    height: auto !important;
+}
+
+.md-clipboard {
+    transition: all 0.25s ease;
+    color: var(--synthax-red) !important;
+    opacity: 0.25;
+}
+
+.md-clipboard:focus,
+.md-clipboard:hover {
+    opacity: 1.0;
+}
+
+.authors-container {
+    display: flex;
+    flex-direction: row;
+    align-items: flex-start;
+    justify-content: center;
+}
+
+.author-block {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    justify-content: center;
+    margin: 0 1em;
+    text-align: center;
+    width: 100%;
+}
+
+.author-block img {
+    width: 100px;
+}
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..0cc5a1f
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,200 @@
+
+
+::: orthogonium.layers.conv.AOC.ortho_conv
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
+
+
+::: orthogonium.layers.conv.AOC.fast_block_ortho_conv
+    rendering:
+        show_root_toc_entry: true
+    selection:
+        inherited_members: true
+
+::: orthogonium.layers.conv.AOC.rko_conv
+    rendering:
+        show_root_toc_entry: true
+    selection:
+        inherited_members: true
+
diff --git a/docs/api/conv.md b/docs/api/conv.md
new file mode 100644
index 0000000..f317d46
--- /dev/null
+++ b/docs/api/conv.md
@@ -0,0 +1,13 @@
+::: orthogonium.layers.conv.AOC.ortho_conv
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
+
+
+
+::: orthogonium.layers.conv.SLL.sll_layer
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/api/linear.md b/docs/api/linear.md
new file mode 100644
index 0000000..1eeaa62
--- /dev/null
+++ b/docs/api/linear.md
@@ -0,0 +1,5 @@
+::: orthogonium.layers.linear.ortho_linear
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/api/reparametrizers.md b/docs/api/reparametrizers.md
new file mode 100644
index 0000000..c36df77
--- /dev/null
+++ b/docs/api/reparametrizers.md
@@ -0,0 +1,5 @@
+::: orthogonium.reparametrizers
+    rendering:
+        show_root_toc_entry: True
+    selection:
+        inherited_members: True
diff --git a/docs/assets/banner.png b/docs/assets/banner.png
new file mode 100644
index 0000000..4501fd5
Binary files /dev/null and b/docs/assets/banner.png differ
diff --git a/docs/assets/flowchart_v4.png b/docs/assets/flowchart_v4.png
new file mode 100644
index 0000000..df509ec
Binary files /dev/null and b/docs/assets/flowchart_v4.png differ
diff --git a/docs/css/custom.css b/docs/css/custom.css
new file mode 100644
index 0000000..9ddf790
--- /dev/null
+++ b/docs/css/custom.css
@@ -0,0 +1,194 @@
+:root {
+    --primary: #3f51b5;
+    --dark: #313131;
+
+    --synthax-blue: #3b78e7;
+    --synthax-red: #ec407a;
+
+}
+
+[data-md-color-scheme="slate"] {
+    /* dark theme */
+    --primary: #ecb22e;
+    --dark: #313131;
+    --synthax-blue: #ecb22e;
+    /* --synthax-red: #03DAC6; */
+    --md-typeset-a-color: var(--synthax-blue);
+    --md-primary-fg-color: var(--dark);
+    --md-primary-fg-color--light: #5d6cc0;
+    --md-primary-fg-color--dark: var(--primary);
+    --md-accent-fg-color: var(--synthax-blue);
+}
+
+
+html {
+    font-size: 115%;
+}
+
+.md-header {
+    background-color: var(--md-footer-bg-color);
+    color: #fff;
+}
+
+.md-typeset ul li p {
+    margin: 0;
+}
+
+.md-typeset h4,
+.md-typeset h5,
+.md-typeset h6 {
+    line-height: 2.0;
+}
+
+li {
+    margin-bottom: 0 !important;
+}
+
+.autodoc-signature>em:first-of-type {
+    font-style: normal;
+    font-weight: bold;
+    color: var(--primary);
+}
+
+h2 code {
+    color: var(--synthax-red) !important;
+}
+
+h4 code {
+    color: var(--synthax-blue) !important;
+    font-weight: bold !important;
+    background-color: rgba(31, 31, 31, 0.05) !important;
+    padding: 5px !important;
+    border-radius: 2px !important;
+    cursor: pointer !important;
+    font-size: 95% !important;
+}
+
+p {
+    text-align: justify
+}
+
+[data-md-color-scheme="slate"] h4 code {
+    background-color: rgba(240, 240, 240, 0.05) !important;
+}
+
+.admonition {
+    font-size: 95% !important;
+}
+
+img[alt*="Colab"] {
+    transform: translateY(3px);
+    padding-left: 5px;
+    width: 150px;
+}
+
+.md-typeset img,
+.md-typeset svg {
+    max-width: none;
+}
+
+.md-typeset .task-list-control [type=checkbox]:checked+.task-list-indicator::before {
+    background: var(--primary) !important;
+}
+
+.md-typeset ul li,
+.md-typeset ol li {
+    margin-bottom: .4em !important;
+}
+
+h2.numkdoc~p:not(.footnote),
+h2.numkdoc~ul:not(.footnote),
+h2.numkdoc~ul li ul:not(.footnote),
+h2.numkdoc~ul li:not(.footnote) {
+    --margin: 0.06em;
+    margin-bottom: var(--margin) !important;
+    margin-top: var(--margin) !important;
+}
+
+table {
+    border: solid 2px rgba(255, 255, 255, 0.1);
+}
+
+.md-typeset table:not([class]) th {
+    background-color: var(--dark);
+}
+
+[data-md-color-scheme="slate"] .md-typeset table:not([class]) th {
+    color: var(--dark);
+    background-color: white;
+}
+
+.md-typeset__table {
+    width: 100%;
+}
+
+.md-typeset table:not([class]) {
+    display: table;
+}
+
+.md-typeset thead:not([class]) {
+    color: white;
+}
+
+table th a {
+    color: var(--synthax-red) !important;
+    word-break: break-word !important;
+}
+
+span.parameter-name {
+    color: var(--primary);
+}
+
+li .parameter-type {
+    font-weight: bold;
+}
+
+[data-md-color-scheme="slate"] .parameter-type {
+    color: var(--synthax-red);
+}
+
+.parameter-self {
+    color: var(--synthax-red);
+}
+
+.md-typeset table:not([class]) {
+    font-size: 0.8rem;
+}
+
+.md-header__button.md-logo img,
+.md-header__button.md-logo svg {
+    width: 2.0rem !important;
+    height: auto !important;
+}
+
+.md-clipboard {
+    transition: all 0.25s ease;
+    color: var(--synthax-red) !important;
+    opacity: 0.25;
+}
+
+.md-clipboard:focus,
+.md-clipboard:hover {
+    opacity: 1.0;
+}
+
+.authors-container {
+    display: flex;
+    flex-direction: row;
+    align-items: flex-start;
+    justify-content: center;
+}
+
+.author-block {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    justify-content: center;
+    margin: 0 1em;
+    text-align: center;
+    width: 100%;
+}
+
+.author-block img {
+    width: 100px;
+}
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..0cc5a1f
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,200 @@
+
+    

+
+
+
+
+
+
+# ✨ Orthogonium: Improved implementations of orthogonal layers
+
+This library aims to centralize, standardize and improve methods to
+build orthogonal layers, with a focus on convolutional layers. We noticed that a layer's implementation plays a
+significant role in the final performance: a more efficient implementation
+allows larger networks and more training steps within the same compute
+budget. So our implementation differs from the original papers in order to
+be faster, to consume less memory or to be more flexible.
+
+# 📦 What is included in this library?
+
+| Layer name          | Description                                                                                                                        | Orthogonal ? | Usage                                                                                                                              | Status         |
+|---------------------|------------------------------------------------------------------------------------------------------------------------------------|--------------|------------------------------------------------------------------------------------------------------------------------------------|----------------|
+| AOC (Adaptive-BCOP) | The most scalable method to build orthogonal convolution. Allows control of kernel size, stride, groups, dilation and convtranspose | Orthogonal   | A flexible method for complex architectures. Preserves orthogonality and works on large scale images.                              | done           |
+| Adaptive-SC-Fac     | Same as previous layer but based on SC-Fac instead of BCOP, which claims a complete parametrization of separable convolutions      | Orthogonal   | Same as above                                                                                                                      | pending        |
+| Adaptive-SOC        | SOC modified to be: i) faster and memory efficient ii) handle stride, groups, dilation & convtranspose                             | Orthogonal   | Good for depthwise convolutions and cases where control over the kernel size is not required                                       | in progress    |
+| SLL                 | The original SLL layer, which is already quite efficient.                                                                          | 1-Lipschitz  | Well suited for residual blocks, it also contains ReLU activations.                                                                | done           |
+| SLL-AOC             | SLL-AOC is to the downsampling block what SLL is to the residual block (see ResNet paper)                                          | 1-Lipschitz  | Allows to construct a "strided" residual block that can change the number of channels. It adds a convolution in the residual path. | done           |
+| Sandwish-AOC        | Sandwish convolutions that use AOC to replace the FFT, allowing it to scale to large images.                                       | 1-Lipschitz  |                                                                                                                                    | pending        |
+| Adaptive-ECO        | ECO modified to i) handle stride, groups & convtranspose                                                                           | Orthogonal   |                                                                                                                                    | (low priority) |
+
+## directory structure
+
+```
+orthogonium
+├── layers
+│   ├── conv
+│   │   ├── AOC
+│   │   │   └── ortho_conv.py # contains AdaptiveOrthoConv2d layer
+│   │   ├── AdaptiveSOC
+│   │   │   └── ortho_conv.py # contains AdaptiveSOCConv2d layer (untested)
+│   │   └── SLL
+│   │       └── sll_layer.py # contains SDPBasedLipschitzConv, SDPBasedLipschitzDense, SLLxAOCLipschitzResBlock
+│   ├── legacy
+│   │   └── original code of BCOP, SOC, Cayley etc.
+│   ├── linear
+│   │   └── ortho_linear.py # contains OrthoLinear layer (can be used with BB, QR and Exp parametrization)
+│   ├── normalization.py # contains Batch centering and Layer centering
+│   ├── custom_activations.py # contains custom activations for 1 lipschitz networks
+│   └── channel_shuffle.py # contains channel shuffle layer
+├── model_factory.py # factory function to construct various models for the zoo
+└── losses # loss functions, VRA estimation
+```
+
+## AOC:
+
+AOC is a method that allows building orthogonal convolutions with
+an explicit kernel, supporting all features like stride, transposed convolutions,
+grouped convolutions and dilation (and all compositions of these parameters). This approach is highly scalable, and can
+be applied to problems like ImageNet-1K.
+
+## Adaptive-SC-FAC:
+
+As AOC is built on top of BCOP method, we can construct an equivalent method constructed on top of SC-Fac instead.
+This will allow to compare performance of the two methods given that they have very similar parametrization. (See our 
+paper for discussions about the similarities and differences between the two methods).
+
+## Adaptive-SOC:
+
+Adaptive-SOC blends the approach of AOC and SOC. It differs from SOC in that it is more memory efficient and
+sometimes faster. It also handles stride, groups, dilation and transposed convolutions. However, it does not allow
+controlling the kernel size explicitly, as the resulting kernel size is larger than the requested kernel size.
+This is due to the computation of the exponential of a kernel, which increases the kernel size at each iteration.
+
+Its development is still in progress, so extra testing is still required to ensure exact orthogonality.
+
+## SLL:
+
+SLL is a method that allows constructing small residual blocks with ReLU activations. We kept most of the original
+implementation, and added `SLLxAOCLipschitzResBlock`, which constructs a down-sampling residual block by fusing SLL with
+AOC.
+
+## More layers are coming soon!
+
+# 🛠 Install the library:
+
+The library will soon be available on pip; in the meantime, you can clone the repository and run the following command
+to install it locally:
+```
+pip install -e .
+```
+
+## Use the layer:
+
+```python
+from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS
+
+# use OrthoConv2d with the same params as torch.nn.Conv2d
+
+conv = AdaptiveOrthoConv2d(
+  kernel_size=kernel_size,
+  in_channels=256,
+  out_channels=256,
+  stride=2,
+  groups=16,
+  bias=True,
+  padding=(kernel_size // 2, kernel_size // 2),
+  padding_mode="circular",
+  ortho_params=DEFAULT_ORTHO_PARAMS
+)
+```
+
+# 🎯 Model Zoo
+
+Stay tuned, a model zoo will be available soon!
+
+
+
+# 🔥 Disclaimer
+
+Given the great quality of the original implementations, orthogonium does not focus on reproducing exactly the results of
+the original papers, but rather on providing a more efficient implementation. Some degradations in the final provable
+accuracy may be observed when reproducing the results of the original papers; we consider this acceptable if the gain
+in terms of scalability is worth it. This library aims to provide more scalable and versatile implementations for people who seek to use orthogonal layers
+in a larger scale setting.
+
+# 📚 Resources
+
+## 1 Lipschitz CNNs and orthogonal CNNs
+
+- 1-Lipschitz Layers Compared: [github](https://github.com/berndprach/1LipschitzLayersCompared) and [paper](https://berndprach.github.io/publication/1LipschitzLayersCompared)
+- BCOP: [github](https://github.com/ColinQiyangLi/LConvNet) and [paper](https://arxiv.org/abs/1911.00937)
+- SC-Fac: [paper](https://arxiv.org/abs/2106.09121)
+- ECO: [paper](https://openreview.net/forum?id=Zr5W2LSRhD)
+- Cayley: [github](https://github.com/locuslab/orthogonal-convolutions) and [paper](https://arxiv.org/abs/2104.07167)
+- LOT: [github](https://github.com/AI-secure/Layerwise-Orthogonal-Training) and [paper](https://arxiv.org/abs/2210.11620)
+- ProjUNN-T: [github](https://github.com/facebookresearch/projUNN) and [paper](https://arxiv.org/abs/2203.05483)
+- SLL: [github](https://github.com/araujoalexandre/Lipschitz-SLL-Networks) and [paper](https://arxiv.org/abs/2303.03169)
+- Sandwish: [github](https://github.com/acfr/LBDN) and [paper](https://arxiv.org/abs/2301.11526)
+- AOL: [github](https://github.com/berndprach/AOL) and [paper](https://arxiv.org/abs/2208.03160)
+- SOC: [github](https://github.com/singlasahil14/SOC) and [paper 1](https://arxiv.org/abs/2105.11417), [paper 2](https://arxiv.org/abs/2211.08453)
+
+## Lipschitz constant evaluation
+
+- [Spectral Norm of Convolutional Layers with Circular and Zero Paddings](https://arxiv.org/abs/2402.00240) 
+- [Efficient Bound of Lipschitz Constant for Convolutional Layers by Gram Iteration](https://arxiv.org/abs/2305.16173)
+- [github of the two papers](https://github.com/blaisedelattre/lip4conv/tree/main)
+
+# 💻 Contributing
+
+This library is still in a very early stage, so expect some bugs and missing features. Also, before the version 1.0.0,
+the API may change and no backward compatibility will be ensured, this will allow a rapid integration of new features.
+In order to prioritize the development, we will focus on the most used layers and models. If you have a specific need,
+please open an issue, and we will try to address it as soon as possible.
+
+Also, if you have a model that you would like to share, please open a PR with the model and the training script. We will
+be happy to include it in the zoo.
+
+If you want to contribute, please open a PR with the new feature or bug fix. We will review it as soon as possible.
+
+## Ongoing developments
+
+Layers:
+- SOC:
+  - remove channels padding to handle ci != co efficiently
+  - enable groups
+  - enable support for native stride, transposition and dilation
+- AOL:
+  - torch implementation of AOL
+- Sandwish:
+  - import code
+  - plug AOC into Sandwish conv
+
+ZOO:
+- models from the paper
diff --git a/docs/js/custom.js b/docs/js/custom.js
new file mode 100644
index 0000000..7fc3c5b
--- /dev/null
+++ b/docs/js/custom.js
@@ -0,0 +1,12 @@
+window.MathJax = {
+  tex: {
+    inlineMath: [["\\(", "\\)"]],
+    displayMath: [["\\[", "\\]"]],
+    processEscapes: true,
+    processEnvironments: true,
+  },
+  options: {
+    ignoreHtmlClass: ".*|",
+    processHtmlClass: "arithmatex"
+  },
+};
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 7df1aa3..4545550 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,18 +1,23 @@
-site_name: lipdp
+site_name: orthogonium
 
 # Set navigation here
 nav:
   - Home: index.md
   - API Reference:
-      - deel.lipdp.layers module: api/layers.md
+      - convolutions: api/conv.md
+      - linear layers: api/linear.md
+      - reparametrizers: api/reparametrizers.md
+#      - layers.conv.AOC module: api/aoc.md
+#      - layers.conv.adaptiveSOC module: api/adaptiveSOC.md
+#      - layers.conv.SLL module: api/sll.md
 #  - Tutorials:
 #    - "Demo 0: How to use notebook in documentation": notebooks/demo_fake.ipynb
   - Contributing: CONTRIBUTING.md
 
 theme:
   name: "material"
-  logo: assets/logo.png
-  favicon: assets/logo.png
+  logo: assets/banner.png
+  favicon: assets/banner.png
   palette:
     - scheme: default
       primary: dark
@@ -50,8 +55,8 @@ markdown_extensions:
       custom_checkbox: true
       clickable_checkbox: true
   - pymdownx.emoji:
-      emoji_index: !!python/name:materialx.emoji.twemoji
-      emoji_generator: !!python/name:materialx.emoji.to_svg
+      emoji_index: !!python/name:material.extensions.emoji.twemoji
+      emoji_generator: !!python/name:material.extensions.emoji.to_svg
 
 extra_css:
   - css/custom.css
@@ -61,5 +66,5 @@ extra_javascript:
   - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
   - js/custom.js
 
-repo_name: "deel-ai/"
-repo_url: "https://github.com/deel-ai/"
+repo_name: "thib-s/orthogonium"
+repo_url: "https://github.com/thib-s/orthogonium"
diff --git a/orthogonium/__init__.py b/orthogonium/__init__.py
index e69de29..228fa2a 100644
--- a/orthogonium/__init__.py
+++ b/orthogonium/__init__.py
@@ -0,0 +1 @@
+from .model_factory import ClassParam
diff --git a/orthogonium/layers/__init__.py b/orthogonium/layers/__init__.py
index 88cab1c..4236109 100644
--- a/orthogonium/layers/__init__.py
+++ b/orthogonium/layers/__init__.py
@@ -9,9 +9,3 @@
 from .channel_shuffle import ChannelShuffle
 from orthogonium.layers.conv.AOC.ortho_conv import AdaptiveOrthoConv2d
 from orthogonium.layers.conv.AOC.ortho_conv import AdaptiveOrthoConvTranspose2d
-from orthogonium.layers.linear.reparametrizers import OrthoParams
-from orthogonium.layers.linear.reparametrizers import (
-    DEFAULT_ORTHO_PARAMS,
-    EXP_ORTHO_PARAMS,
-    CHOLESKY_ORTHO_PARAMS,
-)
diff --git a/orthogonium/layers/channel_shuffle.py b/orthogonium/layers/channel_shuffle.py
index 4fad75c..c4f623d 100644
--- a/orthogonium/layers/channel_shuffle.py
+++ b/orthogonium/layers/channel_shuffle.py
@@ -24,20 +24,3 @@ def forward(self, x):
 
     def extra_repr(self):
         return f"group_in={self.group_in}, group_out={self.group_out}"
-
-
-if __name__ == "__main__":
-    x = torch.randn(2, 6)
-    # takes two groups of size 3, and return 3 groups of size 2
-    gm = ChannelShuffle(2, 3)
-    print(f"in: {x}")
-    y = gm(x)
-    print(f"out: {y}")
-    x2 = torch.randn(2, 6, 32, 32)
-    gm = ChannelShuffle(2, 3)
-    y2 = gm(x2)
-    print(f"in shape: {x2.shape}, out shape: {y2.shape}")
-    gp = ChannelShuffle(3, 2)
-    x2b = gp(y2)
-    assert torch.allclose(x2, x2b), "ChannelShuffle is not invertible"
-    print("ChannelShuffle is invertible")
diff --git a/orthogonium/layers/conv/AOC/__init__.py b/orthogonium/layers/conv/AOC/__init__.py
index 2bb128e..4e3bfbd 100644
--- a/orthogonium/layers/conv/AOC/__init__.py
+++ b/orthogonium/layers/conv/AOC/__init__.py
@@ -1,2 +1,8 @@
 from .ortho_conv import AdaptiveOrthoConv2d
 from .ortho_conv import AdaptiveOrthoConvTranspose2d
+from .fast_block_ortho_conv import FastBlockConv2d
+from .fast_block_ortho_conv import FastBlockConvTranspose2D
+from .bcop_x_rko_conv import BcopRkoConv2d
+from .bcop_x_rko_conv import BcopRkoConvTranspose2d
+from .rko_conv import RKOConv2d
+from .rko_conv import RkoConvTranspose2d
diff --git a/orthogonium/layers/conv/AOC/bcop_x_rko_conv.py b/orthogonium/layers/conv/AOC/bcop_x_rko_conv.py
index 8ebacfe..8cb4942 100644
--- a/orthogonium/layers/conv/AOC/bcop_x_rko_conv.py
+++ b/orthogonium/layers/conv/AOC/bcop_x_rko_conv.py
@@ -2,19 +2,17 @@
 from typing import Union
 
 import numpy as np
-import torch
 from torch import nn as nn
 from torch.nn.common_types import _size_2_t
 from torch.nn.utils import parametrize as parametrize
 
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import (
     attach_bcop_weight,
-    transpose_kernel,
 )
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import conv_singular_values_numpy
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import fast_matrix_conv
 from orthogonium.layers.conv.AOC.rko_conv import attach_rko_weight
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 class BcopRkoConv2d(nn.Conv2d):
@@ -53,27 +51,25 @@ def __init__(
             bias,
             padding_mode,
         )
-        if (self.dilation[0] != 1 or self.dilation[1] != 1) and (
-            self.stride[0] != 1 or self.stride[1] != 1
-        ):
-            raise RuntimeError(
-                "dilation must be 1 when stride is not 1. The set of orthonal convolutions is empty in this setting."
-            )
         # raise runtime error if kernel size >= stride
         if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
-        if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
-            )
         if (
             ((max(in_channels, out_channels) // groups) < 2)
             and (self.kernel_size[0] != self.stride[0])
             and (self.kernel_size[1] != self.stride[1])
         ):
-            raise RuntimeError("inner conv must have at least 2 channels")
+            raise ValueError("inner conv must have at least 2 channels")
+        if (
+            (self.out_channels >= self.in_channels)
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
+        ):
+            raise ValueError(
+                "dilation must be 1 when stride is not 1. The set of orthonal convolutions is empty in this setting."
+            )
         self.intermediate_channels = max(
             in_channels, out_channels // (self.stride[0] * self.stride[1])
         )
@@ -161,7 +157,9 @@ def singular_values(self):
         )
         sv_min = sv_min * svs_2.min()
         sv_max = sv_max * svs_2.max()
-        stable_rank = 0.5 * stable_rank + 0.5 * (np.mean(svs_2) / (svs_2.max() ** 2))
+        stable_rank = 0.5 * stable_rank + 0.5 * (
+            np.mean(svs_2) ** 2 / (svs_2.max() ** 2)
+        )
         return sv_min, sv_max, stable_rank
 
     def forward(self, X):
@@ -197,35 +195,36 @@ def __init__(
             out_channels,
             kernel_size,
             stride,
-            padding,
+            padding if padding_mode == "zeros" else 0,
             output_padding,
             groups,
             bias,
             dilation,
-            padding_mode,
+            "zeros",
         )
+        self.real_padding_mode = padding_mode
+        if padding == "same":
+            padding = self._calculate_same_padding()
+        self.real_padding = self._standardize_padding(padding)
 
-        # raise runtime error if kernel size >= stride
-        if (self.dilation[0] != 1 or self.dilation[1] != 1) and (
-            self.stride[0] != 1 or self.stride[1] != 1
+        if (
+            (self.out_channels <= self.in_channels)
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
         ):
-            raise RuntimeError(
+            raise ValueError(
                 "dilation must be 1 when stride is not 1. The set of orthonal convolutions is empty in this setting."
             )
         if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
-        if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
-            )
         if (
             ((max(in_channels, out_channels) // groups) < 2)
             and (self.kernel_size[0] != self.stride[0])
             and (self.kernel_size[1] != self.stride[1])
         ):
-            raise RuntimeError("inner conv must have at least 2 channels")
+            raise ValueError("inner conv must have at least 2 channels")
         if out_channels * (self.stride[0] * self.stride[1]) >= in_channels:
             self.intermediate_channels = max(
                 in_channels // (self.stride[0] * self.stride[1]), out_channels
@@ -234,12 +233,12 @@ def __init__(
             self.intermediate_channels = out_channels
             # raise warning because this configuration don't yield orthogonal
             # convolutions
-            warnings.warn(
-                "This configuration does not yield orthogonal convolutions due to "
-                "padding issues: pytorch does not implement circular padding for "
-                "transposed convolutions",
-                RuntimeWarning,
-            )
+            # warnings.warn(
+            #     "This configuration does not yield orthogonal convolutions due to "
+            #     "padding issues: pytorch does not implement circular padding for "
+            #     "transposed convolutions",
+            #     RuntimeWarning,
+            # )
         del self.weight
         attach_bcop_weight(
             self,
@@ -267,8 +266,47 @@ def __init__(
             ortho_params=ortho_params,
         )
 
+    def _calculate_same_padding(self) -> tuple:
+        """Calculate padding for 'same' mode."""
+        return (
+            int(
+                np.ceil(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.ceil(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+        )
+
+    def _standardize_padding(self, padding: _size_2_t) -> tuple:
+        """Ensure padding is always a tuple."""
+        if isinstance(padding, int):
+            padding = (padding, padding)
+        if isinstance(padding, tuple):
+            if len(padding) == 2:
+                padding = (padding[0], padding[0], padding[1], padding[1])
+            return padding
+        raise ValueError(f"padding must be int or tuple, got {type(padding)} instead")
+
     def singular_values(self):
-        if self.padding_mode != "circular":
+        if self.real_padding_mode != "circular":
             print(
                 f"padding {self.padding} not supported, return min and max"
                 f"singular values as if it was 'circular' padding "
@@ -301,7 +339,9 @@ def singular_values(self):
         )
         sv_min = sv_min * svs_2.min()
         sv_max = sv_max * svs_2.max()
-        stable_rank = 0.5 * stable_rank + 0.5 * (np.mean(svs_2) / (svs_2.max() ** 2))
+        stable_rank = 0.5 * stable_rank + 0.5 * (
+            np.mean(svs_2) ** 2 / (svs_2.max() ** 2)
+        )
         return sv_min, sv_max, stable_rank
 
     @property
@@ -315,4 +355,29 @@ def weight(self):
 
     def forward(self, X):
         self._input_shape = X.shape[2:]
-        return super(BcopRkoConvTranspose2d, self).forward(X)
+        if self.real_padding_mode != "zeros":
+            X = nn.functional.pad(X, self.real_padding, self.real_padding_mode)
+            y = nn.functional.conv_transpose2d(
+                X,
+                self.weight,
+                self.bias,
+                self.stride,
+                (
+                    (
+                        -self.stride[0]
+                        + self.dilation[0] * (self.kernel_size[0] - 1)
+                        + 1
+                    ),
+                    (
+                        -self.stride[1]
+                        + self.dilation[1] * (self.kernel_size[1] - 1)
+                        + 1
+                    ),
+                ),
+                self.output_padding,
+                self.groups,
+                dilation=self.dilation,
+            )
+            return y
+        else:
+            return super(BcopRkoConvTranspose2d, self).forward(X)
diff --git a/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py b/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py
index befd4a9..277ba5b 100644
--- a/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py
+++ b/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py
@@ -7,8 +7,8 @@
 import torch.nn.utils.parametrize as parametrize
 from torch.nn.common_types import _size_2_t
 
-from orthogonium.layers.linear.reparametrizers import L2Normalize
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import L2Normalize
+from orthogonium.reparametrizers import OrthoParams
 
 
 def conv_singular_values_numpy(kernel, input_shape):
@@ -22,7 +22,7 @@ def conv_singular_values_numpy(kernel, input_shape):
         svs = np.linalg.svd(
             transforms, compute_uv=False, full_matrices=False
         )  # g, k1, k2, min(ci, co)
-        stable_rank = np.mean(svs) / svs.max()
+        stable_rank = (np.mean(svs) ** 2) / svs.max()
         return svs.min(), svs.max(), stable_rank
     except np.linalg.LinAlgError:
         print("numerical error in svd, returning only largest singular value")
@@ -350,7 +350,7 @@ def attach_bcop_weight(
     return weight
 
 
-class FlashBCOP(nn.Conv2d):
+class FastBlockConv2d(nn.Conv2d):
     def __init__(
         self,
         in_channels: int,
@@ -377,7 +377,7 @@ def __init__(
         Striding is not supported when out_channels > in_channels. Real striding is supported in BcopRkoConv2d. The use of
         OrthogonalConv2d is recommended.
         """
-        super(FlashBCOP, self).__init__(
+        super(FastBlockConv2d, self).__init__(
             in_channels,
             out_channels,
             kernel_size,
@@ -390,22 +390,26 @@ def __init__(
         )
 
         # raise runtime error if kernel size >= stride
+        if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
+            raise ValueError(
+                "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
+            )
         if (
             (self.stride[0] > 1 or self.stride[1] > 1) and (out_channels > in_channels)
         ) or (
             self.stride[0] > self.kernel_size[0] or self.stride[1] > self.kernel_size[1]
         ):
-            raise RuntimeError(
+            raise ValueError(
                 "stride > 1 is not supported when out_channels > in_channels, "
                 "use TODO layer instead"
             )
-        if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
-            raise RuntimeError(
-                "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
-            )
-        if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
+        if (
+            (self.out_channels >= self.in_channels)
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
+        ):
+            raise ValueError(
+                "dilation must be 1 when stride is not 1. The set of orthonal convolutions is empty in this setting."
             )
         del self.weight
         attach_bcop_weight(
@@ -455,10 +459,10 @@ def singular_values(self):
 
     def forward(self, X):
         self._input_shape = X.shape[2:]
-        return super(FlashBCOP, self).forward(X)
+        return super(FastBlockConv2d, self).forward(X)
 
 
-class BCOPTranspose(nn.ConvTranspose2d):
+class FastBlockConvTranspose2D(nn.ConvTranspose2d):
     def __init__(
         self,
         in_channels: int,
@@ -478,42 +482,41 @@ def __init__(
         uses the same algorithm as the FlashBCOP layer, but the layer acts as a transposed
         convolutional layer.
         """
-        super(BCOPTranspose, self).__init__(
+        super(FastBlockConvTranspose2D, self).__init__(
             in_channels,
             out_channels,
             kernel_size,
             stride,
-            padding,
+            padding if padding_mode == "zeros" else 0,
             output_padding,
             groups,
             bias,
             dilation,
-            padding_mode,
+            "zeros",
         )
-        # raise runtime error if kernel size >= stride
+        self.real_padding_mode = padding_mode
+        if padding == "same":
+            padding = self._calculate_same_padding()
+        self.real_padding = self._standardize_padding(padding)
+
+        if (
+            (self.out_channels <= self.in_channels)
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
+        ):
+            raise ValueError(
+                "dilation must be 1 when stride is not 1. The set of orthonal convolutions is empty in this setting."
+            )
         if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
-        if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
-            )
         if (
             ((max(in_channels, out_channels) // groups) < 2)
             and (self.kernel_size[0] != self.stride[0])
             and (self.kernel_size[1] != self.stride[1])
         ):
-            raise RuntimeError("inner conv must have at least 2 channels")
-        if out_channels * (self.stride[0] * self.stride[1]) < in_channels:
-            # raise warning because this configuration don't yield orthogonal
-            # convolutions
-            warnings.warn(
-                "This configuration does not yield orthogonal convolutions due to "
-                "padding issues: pytorch does not implement circular padding for "
-                "transposed convolutions",
-                RuntimeWarning,
-            )
+            raise ValueError("inner conv must have at least 2 channels")
         del self.weight
         attach_bcop_weight(
             self,
@@ -528,6 +531,45 @@ def __init__(
             ortho_params=ortho_params,
         )
 
+    def _calculate_same_padding(self) -> tuple:
+        """Calculate padding for 'same' mode."""
+        return (
+            int(
+                np.ceil(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.ceil(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+        )
+
+    def _standardize_padding(self, padding: _size_2_t) -> tuple:
+        """Ensure padding is always a tuple."""
+        if isinstance(padding, int):
+            padding = (padding, padding)
+        if isinstance(padding, tuple):
+            if len(padding) == 2:
+                padding = (padding[0], padding[0], padding[1], padding[1])
+            return padding
+        raise ValueError(f"padding must be int or tuple, got {type(padding)} instead")
+
     def singular_values(self):
         if self.padding_mode != "circular":
             print(
@@ -552,4 +594,29 @@ def singular_values(self):
 
     def forward(self, X):
         self._input_shape = X.shape[2:]
-        return super(BCOPTranspose, self).forward(X)
+        if self.real_padding_mode != "zeros":
+            X = nn.functional.pad(X, self.real_padding, self.real_padding_mode)
+            y = nn.functional.conv_transpose2d(
+                X,
+                self.weight,
+                self.bias,
+                self.stride,
+                (
+                    (
+                        -self.stride[0]
+                        + self.dilation[0] * (self.kernel_size[0] - 1)
+                        + 1
+                    ),
+                    (
+                        -self.stride[1]
+                        + self.dilation[1] * (self.kernel_size[1] - 1)
+                        + 1
+                    ),
+                ),
+                self.output_padding,
+                self.groups,
+                dilation=self.dilation,
+            )
+            return y
+        else:
+            return super(FastBlockConvTranspose2D, self).forward(X)
diff --git a/orthogonium/layers/conv/AOC/ortho_conv.py b/orthogonium/layers/conv/AOC/ortho_conv.py
index 98e220c..5b4eb21 100644
--- a/orthogonium/layers/conv/AOC/ortho_conv.py
+++ b/orthogonium/layers/conv/AOC/ortho_conv.py
@@ -5,11 +5,11 @@
 
 from orthogonium.layers.conv.AOC.bcop_x_rko_conv import BcopRkoConv2d
 from orthogonium.layers.conv.AOC.bcop_x_rko_conv import BcopRkoConvTranspose2d
-from orthogonium.layers.conv.AOC.fast_block_ortho_conv import BCOPTranspose
-from orthogonium.layers.conv.AOC.fast_block_ortho_conv import FlashBCOP
+from orthogonium.layers.conv.AOC.fast_block_ortho_conv import FastBlockConvTranspose2D
+from orthogonium.layers.conv.AOC.fast_block_ortho_conv import FastBlockConv2d
 from orthogonium.layers.conv.AOC.rko_conv import RKOConv2d
 from orthogonium.layers.conv.AOC.rko_conv import RkoConvTranspose2d
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 def AdaptiveOrthoConv2d(
@@ -25,33 +25,56 @@ def AdaptiveOrthoConv2d(
     ortho_params: OrthoParams = OrthoParams(),
 ) -> nn.Conv2d:
     """
-    factory function to create an Orthogonal Convolutional layer
-    choosing the appropriate class depending on the kernel size and stride.
+    Factory function to create an orthogonal convolutional layer, selecting the appropriate class based on kernel size and stride.
 
-    When kernel_size == stride, the layer is a RKOConv2d.
-    When stride == 1, the layer is a FlashBCOP.
-    Otherwise, the layer is a BcopRkoConv2d.
+    **Key Features:**
+    - Enforces orthogonality, preserving gradient norms.
+    - Supports native striding, dilation, grouped convolutions, and flexible padding.
+
+    **Behavior:**
+    - When kernel_size == stride, the layer is an `RKOConv2d`.
+    - When stride == 1, the layer is a `FastBlockConv2d`.
+    - Otherwise, the layer is a `BcopRkoConv2d`.
+
+    **Arguments:**
+    - `in_channels` (int): Number of input channels.
+    - `out_channels` (int): Number of output channels.
+    - `kernel_size` (_size_2_t): Size of the convolution kernel.
+    - `stride` (_size_2_t, optional): Stride of the convolution. Default is 1.
+    - `padding` (str or _size_2_t, optional): Padding mode or size. Default is "same".
+    - `dilation` (_size_2_t, optional): Dilation rate. Default is 1.
+    - `groups` (int, optional): Number of blocked connections from input to output channels. Default is 1.
+    - `bias` (bool, optional): Whether to include a learnable bias. Default is True.
+    - `padding_mode` (str, optional): Padding mode. Default is "circular".
+    - `ortho_params` (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`.
+
+    **Returns:**
+    - A configured instance of `nn.Conv2d` (one of `RKOConv2d`, `FastBlockConv2d`, or `BcopRkoConv2d`).
+
+    **Raises:**
+    - `ValueError`: If kernel_size < stride, as orthogonality cannot be enforced.
     """
+
     if kernel_size < stride:
-        raise RuntimeError(
+        raise ValueError(
             "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
         )
     if kernel_size == stride:
         convclass = RKOConv2d
     elif (stride == 1) or (in_channels >= out_channels):
-        convclass = FlashBCOP
+        convclass = FastBlockConv2d
     else:
         convclass = BcopRkoConv2d
     return convclass(
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride,
-        padding,
-        dilation,
-        groups,
-        bias,
-        padding_mode,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+        bias=bias,
+        padding_mode=padding_mode,
         ortho_params=ortho_params,
     )
 
@@ -70,37 +93,57 @@ def AdaptiveOrthoConvTranspose2d(
     ortho_params: OrthoParams = OrthoParams(),
 ) -> nn.ConvTranspose2d:
     """
-    factory function to create an Orthogonal Convolutional Transpose layer
-    choosing the appropriate class depending on the kernel size and stride.
+    Factory function to create an orthogonal convolutional transpose layer, adapting based on kernel size and stride.
 
-    As we handle native striding with explicit kernel. It unlocks
-    the possibility to use the same parametrization for transposed convolutions.
-    This class uses the same interface as the ConvTranspose2d class.
+    **Key Features:**
+    - Ensures orthogonality in transpose convolutions for stable gradient propagation.
+    - Supports dilation, grouped operations, and efficient kernel construction.
 
-    Unfortunately, circular padding is not supported for the transposed convolution.
-    But unit testing have shown that the convolution is still orthogonal when
-        `out_channels * (stride**2) > in_channels`.
+    **Behavior:**
+    - When kernel_size == stride, the layer is an `RkoConvTranspose2d`.
+    - When stride == 1, the layer is a `FastBlockConvTranspose2D`.
+    - Otherwise, the layer is a `BcopRkoConvTranspose2d`.
+
+    **Arguments:**
+    - `in_channels` (int): Number of input channels.
+    - `out_channels` (int): Number of output channels.
+    - `kernel_size` (_size_2_t): Size of the convolution kernel.
+    - `stride` (_size_2_t, optional): Stride of the transpose convolution. Default is 1.
+    - `padding` (_size_2_t, optional): Padding size. Default is 0.
+    - `output_padding` (_size_2_t, optional): Additional size for output. Default is 0.
+    - `groups` (int, optional): Number of groups. Default is 1.
+    - `bias` (bool, optional): Whether to include a learnable bias. Default is True.
+    - `dilation` (_size_2_t, optional): Dilation rate. Default is 1.
+    - `padding_mode` (str, optional): Padding mode. Default is "zeros".
+    - `ortho_params` (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`.
+
+    **Returns:**
+    - A configured instance of `nn.ConvTranspose2d` (one of `RkoConvTranspose2d`, `FastBlockConvTranspose2D`, or `BcopRkoConvTranspose2d`).
+
+    **Raises:**
+    - `ValueError`: If kernel_size < stride, as orthogonality cannot be enforced.
     """
+
     if kernel_size < stride:
-        raise RuntimeError(
+        raise ValueError(
             "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
         )
     if kernel_size == stride:
         convclass = RkoConvTranspose2d
     elif stride == 1:
-        convclass = BCOPTranspose
+        convclass = FastBlockConvTranspose2D
     else:
         convclass = BcopRkoConvTranspose2d
     return convclass(
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride,
-        padding,
-        output_padding,
-        groups,
-        bias,
-        dilation,
-        padding_mode,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+        dilation=dilation,
+        padding_mode=padding_mode,
         ortho_params=ortho_params,
     )
diff --git a/orthogonium/layers/conv/AOC/rko_conv.py b/orthogonium/layers/conv/AOC/rko_conv.py
index f4e2a38..8fdc214 100644
--- a/orthogonium/layers/conv/AOC/rko_conv.py
+++ b/orthogonium/layers/conv/AOC/rko_conv.py
@@ -8,7 +8,7 @@
 from torch.nn.common_types import _size_2_t
 
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import conv_singular_values_numpy
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 class RKOParametrizer(nn.Module):
@@ -54,7 +54,9 @@ def forward(self, X):
         )
         X = self.pi(X)
         X = self.bjorck(X)
-        X = X.view(self.out_channels, self.in_channels // self.groups, self.k1, self.k2)
+        X = X.reshape(
+            self.out_channels, self.in_channels // self.groups, self.k1, self.k2
+        )
         return X * self.scale
 
     def right_inverse(self, X):
@@ -120,15 +122,18 @@ def __init__(
             bias,
             padding_mode,
         )
-        if self.dilation[0] > 1 or self.dilation[1] > 1:
-            raise RuntimeWarning(
-                "Dilation must be 1 in the RKO convolution."
-                "Use RkoConvTranspose2d instead."
+        if (
+            True  # (self.out_channels >= self.in_channels) # investigate why it doesn't work
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
+        ):
+            raise ValueError(
+                "dilation must be 1 when stride is not 1. The set of orthogonal convolutions is empty in this setting."
             )
-        # torch.nn.init.orthogonal_(self.weight)
+        torch.nn.init.orthogonal_(self.weight)
         self.scale = 1 / math.sqrt(
-            math.ceil(self.dilation[0] * self.kernel_size[0] / self.stride[0])
-            * math.ceil(self.dilation[1] * self.kernel_size[1] / self.stride[1])
+            math.ceil(self.kernel_size[0] / self.stride[0])
+            * math.ceil(self.kernel_size[1] / self.stride[1])
         )
         parametrize.register_parametrization(
             self,
@@ -168,11 +173,11 @@ def singular_values(self):
             )
             sv_min = svs.min()
             sv_max = svs.max()
-            stable_rank = np.mean(svs) / (svs.max() ** 2)
+            stable_rank = (np.mean(svs) ** 2) / (svs.max() ** 2)
             return sv_min, sv_max, stable_rank
         elif self.stride[0] > 1 or self.stride[1] > 1:
             raise RuntimeError(
-                "Not able to compute singular values for this " "configuration"
+                "Not able to compute singular values for this configuration"
             )
         # Implements interface required by LipschitzModuleL2
         sv_min, sv_max, stable_rank = conv_singular_values_numpy(
@@ -211,24 +216,29 @@ def __init__(
             out_channels,
             kernel_size,
             stride,
-            padding,
+            padding if padding_mode == "zeros" else 0,
             output_padding,
             groups,
             bias,
             dilation,
-            padding_mode,
+            "zeros",
         )
-
-        # raise runtime error if kernel size >= stride
-        if self.kernel_size[0] > self.stride[0] or self.kernel_size[1] > self.stride[1]:
-            raise RuntimeError(
-                "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
+        self.real_padding_mode = padding_mode
+        if padding == "same":
+            padding = self._calculate_same_padding()
+        self.real_padding = self._standardize_padding(padding)
+        if self.kernel_size[0] < self.stride[0] or self.kernel_size[1] < self.stride[1]:
+            raise ValueError(
+                "kernel size must be at least as large as the stride. The set of orthogonal convolutions is empty in this setting."
             )
-        if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
+        if (
+            (self.out_channels <= self.in_channels)
+            and (((self.dilation[0] % self.stride[0]) == 0) and (self.stride[0] > 1))
+            and (((self.dilation[1] % self.stride[1]) == 0) and (self.stride[1] > 1))
+        ):
+            raise ValueError(
+                "dilation must be 1 when stride is not 1. The set of orthogonal convolutions is empty in this setting."
             )
-
         if (
             self.stride[0] != self.kernel_size[0]
             or self.stride[1] != self.kernel_size[1]
@@ -249,6 +259,45 @@ def __init__(
             ortho_params=ortho_params,
         )
 
+    def _calculate_same_padding(self) -> tuple:
+        """Calculate padding for 'same' mode."""
+        return (
+            int(
+                np.ceil(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0])
+                    / 2
+                )
+            ),
+            int(
+                np.ceil(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+            int(
+                np.floor(
+                    (self.dilation[1] * (self.kernel_size[1] - 1) + 1 - self.stride[1])
+                    / 2
+                )
+            ),
+        )
+
+    def _standardize_padding(self, padding: _size_2_t) -> tuple:
+        """Ensure padding is always a tuple."""
+        if isinstance(padding, int):
+            padding = (padding, padding)
+        if isinstance(padding, tuple):
+            if len(padding) == 2:
+                padding = (padding[0], padding[0], padding[1], padding[1])
+            return padding
+        raise ValueError(f"padding must be int or tuple, got {type(padding)} instead")
+
     def singular_values(self):
         if (self.stride[0] == self.kernel_size[0]) and (
             self.stride[1] == self.kernel_size[1]
@@ -267,11 +316,11 @@ def singular_values(self):
             )
             sv_min = svs.min()
             sv_max = svs.max()
-            stable_rank = np.mean(svs) / (svs.max() ** 2)
+            stable_rank = (np.mean(svs) ** 2) / (svs.max() ** 2)
             return sv_min, sv_max, stable_rank
         elif self.stride[0] > 1 or self.stride[1] > 1:
             raise RuntimeError(
-                "Not able to compute singular values for this " "configuration"
+                "Not able to compute singular values for this configuration"
             )
         # Implements interface required by LipschitzModuleL2
         sv_min, sv_max, stable_rank = conv_singular_values_numpy(
@@ -291,4 +340,29 @@ def singular_values(self):
 
     def forward(self, X):
         self._input_shape = X.shape[2:]
-        return super(RkoConvTranspose2d, self).forward(X)
+        if self.real_padding_mode != "zeros":
+            X = nn.functional.pad(X, self.real_padding, self.real_padding_mode)
+            y = nn.functional.conv_transpose2d(
+                X,
+                self.weight,
+                self.bias,
+                self.stride,
+                (
+                    (
+                        -self.stride[0]
+                        + self.dilation[0] * (self.kernel_size[0] - 1)
+                        + 1
+                    ),
+                    (
+                        -self.stride[1]
+                        + self.dilation[1] * (self.kernel_size[1] - 1)
+                        + 1
+                    ),
+                ),
+                self.output_padding,
+                self.groups,
+                dilation=self.dilation,
+            )
+            return y
+        else:
+            return super(RkoConvTranspose2d, self).forward(X)
diff --git a/orthogonium/layers/conv/SLL/sll_layer.py b/orthogonium/layers/conv/SLL/sll_layer.py
index c553342..e8716ca 100644
--- a/orthogonium/layers/conv/SLL/sll_layer.py
+++ b/orthogonium/layers/conv/SLL/sll_layer.py
@@ -1,3 +1,52 @@
+"""
+# SSL derived 1-Lipschitz Layers
+
+This module implements several 1-Lipschitz residual blocks, inspired by and extending
+the SDP-based Lipschitz Layers (SLL) from [1]. Specifically:
+
+- **`SDPBasedLipschitzResBlock`**  
+  The original version of the 1-Lipschitz convolutional residual block. It enforces Lipschitz
+  constraints by rescaling activation outputs according to an estimate of the operator norm.
+
+- **`SLLxAOCLipschitzResBlock`**  
+  An extended version of the SLL approach described in [1], combined with additional orthogonal
+  convolutions to handle stride, kernel-size, or channel-dimension changes. It fuses multiple
+  convolutions via the block convolution, thereby preserving the 1-Lipschitz property while enabling
+  strided downsampling or modifying input/output channels.
+
+- **`AOCLipschitzResBlock`**  
+  A variant of the original Lipschitz block where the core convolution is replaced by an
+  `AdaptiveOrthoConv2d`. It maintains the 1-Lipschitz property with orthogonal weight
+  parameterization while providing efficient convolution implementations.
+
+## References
+
+[1] Alexandre Araujo, Aaron J Havens, Blaise Delattre, Alexandre Allauzen, and Bin Hu. A unified alge-
+braic perspective on lipschitz neural networks. In The Eleventh International Conference on Learning
+Representations, 2023
+[2] Thibaut Boissin, Franck Mamalet, Thomas Fel, Agustin Martin Picard, Thomas Massena, Mathieu Serrurier,
+An Adaptive Orthogonal Convolution Scheme for Efficient and Flexible CNN Architectures
+
+## Notes on the SLL approach
+
+In [1], the SLL layer for convolutions is a 1-Lipschitz residual operation defined approximately as:
+
+$$
+y = x - \\mathbf{K}^T \\star (\\sigma(\\mathbf{K} \\star x + b)),
+$$
+
+where $\\mathbf{K}$ represents a Toeplitz (convolution) matrix with suitable norm constraints.
+
+By default, the SLL formulation does **not** allow strides or changes in the number of channels.  
+To address these issues, `SLLxAOCLipschitzResBlock` adds extra orthogonal convolutions before and/or
+after the main SLL operation. These additional convolutions can be merged via block convolution
+(Proposition 1 in [2]) to maintain 1-Lipschitz behavior while enabling stride and/or channel changes.
+
+When $\\mathbf{K}$, $\\mathbf{K}_{pre}$, and $\\mathbf{K}_{post}$ each correspond to 2×2 convolutions,
+the resulting block effectively contains two 3×3 convolutions in one branch and a single 4×4 stride-2
+convolution in the skip branch—quite similar to typical ResNet blocks.
+"""
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -7,7 +56,7 @@
 from orthogonium.layers import AdaptiveOrthoConv2d
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import fast_matrix_conv
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import transpose_kernel
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 def safe_inv(x):
@@ -17,8 +66,32 @@ def safe_inv(x):
     return x_inv
 
 
-class SDPBasedLipschitzResBlock(nn.Module):
+class SLLxAOCLipschitzResBlock(nn.Module):
     def __init__(self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, **kwargs):
+        """
+        Extended SLL-based convolutional residual block. Supports arbitrary kernel sizes,
+        strides, and changes in the number of channels by integrating additional
+        orthogonal convolutions *and* fusing them via block convolution (Proposition 1 in [2]).
+
+        The forward pass follows:
+
+        $$
+        \\text{out} = x - 2\\, K^{\\top} \\star \\Bigl(
+          t \\,\\sigma\\bigl( K \\star x + b \\bigr)
+        \\Bigr),
+        $$
+
+        where the kernel $K$ may effectively be expanded by pre/post AOC layers to
+        handle stride and channel changes. This approach is described in "Improving
+        SDP-based Lipschitz Layers" of [1].
+
+        **Args**:
+          - `cin` (int): Number of input channels.
+          - `inner_dim_factor` (float): Multiplier for the internal channel dimension.
+          - `kernel_size` (int, optional): Base kernel size for the SLL portion. Default is 3.
+          - `**kwargs`: Additional options (unused).
+        """
         super().__init__()
         inner_kernel_size = kernel_size - (stride - 1)
         self.skip_kernel_size = stride + (stride // 2)
@@ -80,8 +153,29 @@ def forward(self, x):
         return out
 
 
-class SLLxAOCLipschitzResBlock(nn.Module):
+class SDPBasedLipschitzResBlock(nn.Module):
     def __init__(self, cin, inner_dim_factor, kernel_size=3, **kwargs):
+        """
+         Original 1-Lipschitz convolutional residual block, based on the SDP-based Lipschitz
+        layer (SLL) approach [1]. It has a structure akin to:
+
+        out = x - 2 * ConvTranspose( t * ReLU(Conv(x) + bias) )
+
+        where `t` is a channel-wise scaling factor ensuring a Lipschitz constant ≤ 1.
+
+        !!! note
+            By default, `SDPBasedLipschitzResBlock` assumes `cin == cout` and does **not** handle
+            stride changes outside the skip connection (i.e., typically used when stride=1 or 2
+            for downsampling in a standard residual architecture).
+
+        **Args**:
+          - `cin` (int): Number of input channels.
+          - `cout` (int): Number of output channels.
+          - `inner_dim_factor` (float): Multiplier for the intermediate dimensionality.
+          - `kernel_size` (int, optional): Size of the convolution kernel. Default is 3.
+          - `stride` (int, optional): Stride for the skip connection. Default is 2.
+          - `**kwargs`: Additional keyword arguments (unused).
+        """
         super().__init__()
 
         inner_dim = int(cin * inner_dim_factor)
@@ -122,6 +216,22 @@ def forward(self, x):
 
 class SDPBasedLipschitzDense(nn.Module):
     def __init__(self, in_features, out_features, inner_dim, **kwargs):
+        """
+        A 1-Lipschitz fully-connected layer (dense version). Similar to the convolutional
+        SLL approach, but operates on vectors:
+
+        $$
+        \\displaystyle
+        \\text{out} = x - 2\\, W^{\\top} \\Bigl(
+          t \\,\\sigma\\bigl(W\\,x + b\\bigr)
+        \\Bigr).
+        $$
+
+        **Args**:
+          - `in_features` (int): Input size.
+          - `out_features` (int): Output size (must match `in_features` to remain 1-Lipschitz).
+          - `inner_dim` (int): The internal dimension used for the transform.
+        """
         super().__init__()
 
         inner_dim = inner_dim if inner_dim != -1 else in_features
@@ -167,6 +277,28 @@ def __init__(
         padding_mode: str = "circular",
         ortho_params: OrthoParams = OrthoParams(),
     ):
+        """
+        A Lipschitz residual block in which the main convolution is replaced by
+        `AdaptiveOrthoConv2d` (AOC). This preserves 1-Lipschitz (or lower) behavior through
+        an orthogonal parameterization, without explicitly computing a scaling factor `t`.
+
+        $$
+        \\text{out} = x - 2\\, K^{\\top} \\star \\Bigl(
+          \\sigma\\bigl( K \\star x \\bigr)
+        \\Bigr).
+        $$
+
+        **Args**:
+          - `in_channels` (int): Number of input channels.
+          - `inner_dim_factor` (int): Multiplier for internal representation size.
+          - `kernel_size` (_size_2_t): Convolution kernel size.
+          - `dilation` (_size_2_t, optional): Default is 1.
+          - `groups` (int, optional): Default is 1.
+          - `bias` (bool, optional): If True, adds a learnable bias. Default is True.
+          - `padding_mode` (str, optional): `'circular'` or `'zeros'`. Default is `'circular'`.
+          - `ortho_params` (OrthoParams, optional): Orthogonal parameterization settings. Default is `OrthoParams()`.
+        """
         super().__init__()
 
         inner_dim = int(in_channels * inner_dim_factor)
@@ -200,19 +332,20 @@ def __init__(
     def forward(self, x):
         kernel = self.in_conv.weight
         # conv
+        res = x
         if self.padding_mode == "circular":
-            x = F.pad(
-                x,
-                (
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                ),
+            res = F.pad(
+                res,
+                (self.padding,) * 4,
                 mode="circular",
+                value=0,
             )
         res = F.conv2d(
-            x, kernel, bias=self.in_conv.bias, padding=self.padding, groups=self.groups
+            res,
+            kernel,
+            bias=self.in_conv.bias,
+            padding=0,
+            groups=self.groups,
         )
         # activation
         res = self.activation(res)
@@ -220,17 +353,11 @@ def forward(self, x):
         if self.padding_mode == "circular":
             res = F.pad(
                 res,
-                (
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                    self.kernel_size // 2,
-                ),
+                (self.padding,) * 4,
                 mode="circular",
+                value=0,
             )
-        res = 2 * F.conv_transpose2d(
-            res, kernel, padding=self.padding, groups=self.groups
-        )
+        res = 2 * F.conv_transpose2d(res, kernel, padding=0, groups=self.groups)
         # residual
         out = x - res
         return out
diff --git a/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py b/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py
index 078d42e..8ba1041 100644
--- a/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py
+++ b/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py
@@ -230,17 +230,17 @@ def __init__(
 
         # raise runtime error if kernel size >= stride
         if ((stride > 1) and (out_channels > in_channels)) or (stride > kernel_size):
-            raise RuntimeError(
+            raise ValueError(
                 "stride > 1 is not supported when out_channels > in_channels, "
                 "use TODO layer instead"
             )
         if kernel_size < stride:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
         if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
+            raise ValueError(
+                "in_channels and out_channels must be divisible by groups"
             )
         self.padding = padding
         self.stride = stride
@@ -340,15 +340,15 @@ def __init__(
 
         # raise runtime error if kernel size >= stride
         if kernel_size < stride:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
         if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
+            raise ValueError(
+                "in_channels and out_channels must be divisible by groups"
            )
         if ((self.max_channels // groups) < 2) and (kernel_size != stride):
-            raise RuntimeError("inner conv must have at least 2 channels")
+            raise ValueError("inner conv must have at least 2 channels")
         if out_channels * (stride**2) < in_channels:
             # raise warning because this configuration don't yield orthogonal
             # convolutions
diff --git a/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py b/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py
index 3d49897..92e471a 100644
--- a/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py
+++ b/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py
@@ -11,7 +11,7 @@
     SOCRkoConv2d,
     SOCRkoConvTranspose2d,
 )
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 def AdaptiveSOCConv2d(
@@ -35,7 +35,7 @@ def AdaptiveSOCConv2d(
     Otherwise, the layer is a BcopRkoConv2d.
     """
     if kernel_size < stride:
-        raise RuntimeError(
+        raise ValueError(
             "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
         )
     if kernel_size == stride:
@@ -84,7 +84,7 @@ def AdaptiveSOCConvTranspose2d(
         `out_channels * (stride**2) > in_channels`.
     """
     if kernel_size < stride:
-        raise RuntimeError(
+        raise ValueError(
             "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
         )
     if kernel_size == stride:
diff --git a/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py b/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py
index 9d52ce6..d2cdb34 100644
--- a/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py
+++ b/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py
@@ -13,7 +13,7 @@
     attach_soc_weight,
     ExpParams,
 )
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import OrthoParams
 
 
 class SOCRkoConv2d(nn.Conv2d):
@@ -64,15 +64,15 @@ def __init__(
 
         # raise runtime error if kernel size >= stride
         if kernel_size < stride:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
         if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
-                "in_channels and out_channels must be divisible by groups"
+            raise ValueError(
+                "in_channels and out_channels must be divisible by groups"
             )
         if ((self.max_channels // groups) < 2) and (kernel_size != stride):
-            raise RuntimeError("inner conv must have at least 2 channels")
+            raise ValueError("inner conv must have at least 2 channels")
         self.padding = padding
         self.stride = stride
         self.kernel_size = kernel_size
@@ -162,7 +160,7 @@ def singular_values(self):
         )
         sv_min = sv_min * svs_2.min()
         sv_max = sv_max * svs_2.max()
-        stable_rank = 0.5 * stable_rank + 0.5 * (np.mean(svs_2) / (svs_2.max() ** 2))
+        stable_rank = 0.5 * stable_rank + 0.5 * ((np.mean(svs_2) ** 2) / (svs_2.max() ** 2))
         return sv_min, sv_max, stable_rank
 
     def forward(self, X):
@@ -194,8 +192,6 @@ def __init__(
         But unit testing have shown that the convolution is still orthogonal when
          `out_channels * (stride**2) > in_channels`.
         """
-        if dilation != 1:
-            raise RuntimeError("dilation not supported")
         super(SOCRkoConvTranspose2d, self).__init__(
             in_channels,
             out_channels,
@@ -219,15 +215,15 @@ def __init__(
 
         # raise runtime error if kernel size >= stride
         if kernel_size < stride:
-            raise RuntimeError(
+            raise ValueError(
                 "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
             )
         if (in_channels % groups != 0) and (out_channels % groups != 0):
-            raise RuntimeError(
+            raise ValueError(
                 "in_channels and out_channels must be divisible by groups"
             )
         if ((self.max_channels // groups) < 2) and (kernel_size != stride):
-            raise RuntimeError("inner conv must have at least 2 channels")
+            raise ValueError("inner conv must have at least 2 channels")
         self.padding = padding
         self.stride = stride
         self.kernel_size = kernel_size
@@ -238,12 +232,12 @@ def __init__(
             self.intermediate_channels = out_channels
             # raise warning because this configuration don't yield orthogonal
             # convolutions
-            warnings.warn(
-                "This configuration does not yield orthogonal convolutions due to "
-                "padding issues: pytorch does not implement circular padding for "
-                "transposed convolutions",
-                RuntimeWarning,
-            )
+            # warnings.warn(
+            #     "This configuration does not yield orthogonal convolutions due to "
+            #     "padding issues: pytorch does not implement circular padding for "
+            #     "transposed convolutions",
+            #     RuntimeWarning,
+            # )
         del self.weight
         attach_soc_weight(
             self,
@@ -305,7 +299,7 @@ def singular_values(self):
         )
         sv_min = sv_min * svs_2.min()
         sv_max = sv_max * svs_2.max()
-        stable_rank = 0.5 * stable_rank + 0.5 * (np.mean(svs_2) / (svs_2.max() ** 2))
+        stable_rank = 0.5 * stable_rank + 0.5 * ((np.mean(svs_2) ** 2) / (svs_2.max() ** 2))
         return sv_min, sv_max, stable_rank
 
     @property
diff --git a/orthogonium/layers/custom_activations.py b/orthogonium/layers/custom_activations.py
index 040f9bd..0204e1a 100644
--- a/orthogonium/layers/custom_activations.py
+++ b/orthogonium/layers/custom_activations.py
@@ -5,6 +5,9 @@
 from torch.autograd import Function
 
 
+SQRT_2 = np.sqrt(2)
+
+
 class Abs(nn.Module):
     def __init__(self):
         super(Abs, self).__init__()
@@ -25,18 +28,19 @@ def forward(self, z):
 
 
 class HouseHolder(nn.Module):
-    def __init__(self, channels):
+    def __init__(self, channels, axis=1):
         super(HouseHolder, self).__init__()
         assert (channels % 2) == 0
         eff_channels = channels // 2
 
         self.theta = nn.Parameter(
-            0.5 * np.pi * torch.ones(1, eff_channels, 1, 1).cuda(), requires_grad=True
+            0.5 * np.pi * torch.ones(1, eff_channels, 1, 1), requires_grad=True
         )
+        self.axis = axis
 
-    def forward(self, z, axis=1):
+    def forward(self, z):
         theta = self.theta
-        x, y = z.split(z.shape[axis] // 2, axis)
+        x, y = z.split(z.shape[self.axis] // 2, self.axis)
 
         selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))
 
@@ -45,29 +49,30 @@ def forward(self, z, axis=1):
 
         a = x * (selector <= 0) + a_2 * (selector > 0)
         b = y * (selector <= 0) + b_2 * (selector > 0)
-        return torch.cat([a, b], dim=axis)
+        return torch.cat([a, b], dim=self.axis) / SQRT_2
 
 
 class HouseHolder_Order_2(nn.Module):
-    def __init__(self, channels):
+    def __init__(self, channels, axis=1):
         super(HouseHolder_Order_2, self).__init__()
         assert (channels % 2) == 0
         self.num_groups = channels // 2
+        self.axis = axis
 
         self.theta0 = nn.Parameter(
-            (np.pi * torch.rand(self.num_groups)).cuda(), requires_grad=True
+            (np.pi * torch.rand(self.num_groups)), requires_grad=True
         )
         self.theta1 = nn.Parameter(
-            (np.pi * torch.rand(self.num_groups)).cuda(), requires_grad=True
+            (np.pi * torch.rand(self.num_groups)), requires_grad=True
         )
         self.theta2 = nn.Parameter(
-            (np.pi * torch.rand(self.num_groups)).cuda(), requires_grad=True
+            (np.pi * torch.rand(self.num_groups)), requires_grad=True
         )
 
-    def forward(self, z, axis=1):
+    def forward(self, z):
         theta0 = torch.clamp(self.theta0.view(1, -1, 1, 1), 0.0, 2 * np.pi)
 
-        x, y = z.split(z.shape[axis] // 2, axis)
+        x, y = z.split(z.shape[self.axis] // 2, self.axis)
         z_theta = (torch.atan2(y, x) - (0.5 * theta0)) % (2 * np.pi)
 
         theta1 = torch.clamp(self.theta1.view(1, -1, 1, 1), 0.0, 2 * np.pi)
@@ -100,5 +105,5 @@ def forward(self, z, axis=1):
         a = (a1 * select1) + (a2 * select2) + (a3 * select3) + (a4 * select4)
         b = (b1 * select1) + (b2 * select2) + (b3 * select3) + (b4 * select4)
 
-        z = torch.cat([a, b], dim=axis)
+        z = torch.cat([a, b], dim=self.axis) / SQRT_2
         return z
diff --git a/orthogonium/layers/linear/ortho_linear.py b/orthogonium/layers/linear/ortho_linear.py
index 18c9590..14c48d6 100644
--- a/orthogonium/layers/linear/ortho_linear.py
+++ b/orthogonium/layers/linear/ortho_linear.py
@@ -3,8 +3,8 @@
 from torch import nn as nn
 from torch.nn.utils import parametrize as parametrize
 
-from orthogonium.layers.linear.reparametrizers import L2Normalize
-from orthogonium.layers.linear.reparametrizers import OrthoParams
+from orthogonium.reparametrizers import L2Normalize
+from orthogonium.reparametrizers import OrthoParams
 
 
 class OrthoLinear(nn.Linear):
@@ -32,7 +32,7 @@ def singular_values(self):
         svs = np.linalg.svd(
             self.weight.detach().cpu().numpy(), full_matrices=False, compute_uv=False
         )
-        stable_rank = np.sum(np.mean(svs)) / (svs.max() ** 2)
+        stable_rank = np.sum((np.mean(svs) ** 2)) / (svs.max() ** 2)
         return svs.min(), svs.max(), stable_rank
 
 
@@ -55,5 +55,5 @@ def singular_values(self):
         svs = np.linalg.svd(
             self.weight.detach().cpu().numpy(), full_matrices=False, compute_uv=False
         )
-        stable_rank = np.sum(np.mean(svs)) / (svs.max() ** 2)
+        stable_rank = np.sum(np.mean(svs) ** 2) / (svs.max() ** 2)
         return svs.min(), svs.max(), stable_rank
diff --git a/orthogonium/layers/linear/reparametrizers.py b/orthogonium/layers/linear/reparametrizers.py
deleted file mode 100644
index 4cd2d44..0000000
--- a/orthogonium/layers/linear/reparametrizers.py
+++ /dev/null
@@ -1,275 +0,0 @@
-from dataclasses import dataclass
-from typing import Callable
-from typing import Tuple
-import torch
-import torch.nn.utils.parametrize as parametrize
-from torch import nn as nn
-from orthogonium.classparam import ClassParam
-
-
-class L2Normalize(nn.Module):
-    def __init__(self, dtype, dim=None):
-        super(L2Normalize, self).__init__()
-        self.dim = dim
-        self.dtype = dtype
-
-    def forward(self, x):
-        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)
-
-    def right_inverse(self, x):
-        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)
-
-
-class BatchedPowerIteration(nn.Module):
-    def __init__(self, weight_shape, power_it_niter=3, eps=1e-12):
-        """
-        This module is a batched version of the Power Iteration algorithm.
-        It is used to normalize the kernel of a convolutional layer.
-
-        Args:
-            weight_shape (tuple): shape of the kernel, the last dimension will be normalized.
-            power_it_niter (int, optional): number of iterations. Defaults to 3.
-            eps (float, optional): small value to avoid division by zero. Defaults to 1e-12.
-        """
-        super(BatchedPowerIteration, self).__init__()
-        self.weight_shape = weight_shape
-        self.power_it_niter = power_it_niter
-        self.eps = eps
-        # init u
-        # u will be weight_shape[:-2] + (weight_shape[:-2], 1)
-        # v will be weight_shape[:-2] + (weight_shape[:-1], 1,)
-        self.u = nn.Parameter(
-            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-2], 1)),
-            requires_grad=False,
-        )
-        self.v = nn.Parameter(
-            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-1], 1)),
-            requires_grad=False,
-        )
-        parametrize.register_parametrization(
-            self, "u", L2Normalize(dtype=self.u.dtype, dim=(-2))
-        )
-        parametrize.register_parametrization(
-            self, "v", L2Normalize(dtype=self.v.dtype, dim=(-2))
-        )
-
-    def forward(self, X, init_u=None, n_iters=3, return_uv=True):
-        for _ in range(n_iters):
-            self.v = X.transpose(-1, -2) @ self.u
-            self.u = X @ self.v
-        # stop gradient on u and v
-        u = self.u.detach()
-        v = self.v.detach()
-        # but keep gradient on s
-        s = u.transpose(-1, -2) @ X @ v
-        return X / (s + self.eps)
-
-    def right_inverse(self, normalized_kernel):
-        # we assume that the kernel is normalized
-        return normalized_kernel.to(self.u.dtype)
-
-
-class BatchedIdentity(nn.Module):
-    def __init__(self, weight_shape):
-        super(BatchedIdentity, self).__init__()
-
-    def forward(self, w):
-        return w
-
-    def right_inverse(self, w):
-        return w
-
-
-class BatchedBjorckOrthogonalization(nn.Module):
-    def __init__(self, weight_shape, beta=0.5, niters=7):
-        self.weight_shape = weight_shape
-        self.beta = beta
-        self.niters = niters
-        if weight_shape[-2] < weight_shape[-1]:
-            self.wwtw_op = BatchedBjorckOrthogonalization.wwt_w_op
-        else:
-            self.wwtw_op = BatchedBjorckOrthogonalization.w_wtw_op
-        super(BatchedBjorckOrthogonalization, self).__init__()
-
-    @staticmethod
-    def w_wtw_op(w):
-        return w @ (w.transpose(-1, -2) @ w)
-
-    @staticmethod
-    def wwt_w_op(w):
-        return (w @ w.transpose(-1, -2)) @ w
-
-    def forward(self, w):
-        for _ in range(self.niters):
-            w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
-        return w
-
-    def right_inverse(self, w):
-        return w
-
-
-def orth(X):
-    S = X @ X.mT
-    eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3).detach()
-    eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
-    S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
-    L = torch.linalg.cholesky(S)
-    W = torch.linalg.solve_triangular(L, X, upper=False)
-    return W
-
-
-class CholeskyOrthfn(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, X):
-        S = X @ X.mT
-        eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3)
-        eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
-        S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
-        L = torch.linalg.cholesky(S)
-        W = torch.linalg.solve_triangular(L, X, upper=False)
-        ctx.save_for_backward(W, L)
-        return W
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        W, L = ctx.saved_tensors
-        LmT = L.mT.contiguous()
-        gB = torch.linalg.solve_triangular(LmT, grad_output, upper=True)
-        gA = (-gB @ W.mT).tril()
-        gS = (LmT @ gA).tril()
-        gS = gS + gS.tril(-1).mT
-        gS = torch.linalg.solve_triangular(LmT, gS, upper=True)
-        gX = gS @ W + gB
-        return gX
-
-
-class CholeskyOrthfn_stable(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, X):
-        S = X @ X.mT
-        eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3)
-        eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
-        S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
-        L = torch.linalg.cholesky(S)
-        W = torch.linalg.solve_triangular(L, X, upper=False)
-        ctx.save_for_backward(X, W, L)
-        return W
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        X, W, L = ctx.saved_tensors
-        gB = torch.linalg.solve_triangular(L.mT, grad_output, upper=True)
-        gA = (-gB @ W.mT).tril()
-        gS = (L.mT @ gA).tril()
-        gS = gS + gS.tril(-1).mT
-        gS = torch.linalg.solve_triangular(L.mT, gS, upper=True)
-        gS = torch.linalg.solve_triangular(L, gS, upper=False, left=False)
-        gX = gS @ X + gB
-        return gX
-
-
-CholeskyOrth = CholeskyOrthfn.apply
-
-
-class BatchedCholeskyOrthogonalization(nn.Module):
-    def __init__(self, weight_shape):
-        self.weight_shape = weight_shape
-        super(BatchedCholeskyOrthogonalization, self).__init__()
-
-    def forward(self, w):
-        return CholeskyOrth(w)
-
-    def right_inverse(self, w):
-        return w
-
-
-class BatchedExponentialOrthogonalization(nn.Module):
-    def __init__(self, weight_shape, niters=7):
-        self.weight_shape = weight_shape
-        self.max_dim = max(weight_shape[-2:])
-        self.niters = niters
-        super(BatchedExponentialOrthogonalization, self).__init__()
-
-    def forward(self, w):
-        # fill w with zero to have a square matrix over the last two dimensions
-        # if ((self.max_dim - w.shape[-1]) != 0) and ((self.max_dim - w.shape[-2]) != 0):
-        w = torch.nn.functional.pad(
-            w, (0, self.max_dim - w.shape[-1], 0, self.max_dim - w.shape[-2])
-        )
-        # makes w skew symmetric
-        w = (w - w.transpose(-1, -2)) / 2
-        acc = w
-        res = torch.eye(acc.shape[-2], acc.shape[-1], device=w.device) + acc
-        for i in range(2, self.niters):
-            acc = torch.einsum("...ij,...jk->...ik", acc, w) / i
-            res = res + acc
-        # if transpose:
-        #     res = res.transpose(-1, -2)
-        res = res[..., : self.weight_shape[-2], : self.weight_shape[-1]]
-        return res.contiguous()
-
-    def right_inverse(self, w):
-        return w
-
-
-class BatchedQROrthogonalization(nn.Module):
-    def __init__(self, weight_shape):
-        super(BatchedQROrthogonalization, self).__init__()
-
-    def forward(self, w):
-        transpose = w.shape[-2] < w.shape[-1]
-        if transpose:
-            w = w.transpose(-1, -2)
-        q, r = torch.linalg.qr(w, mode="reduced")
-        # compute the sign of the diagonal of d
-        diag_sign = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
-        # multiply the sign with the diagonal of r
-        q = q * diag_sign
-        if transpose:
-            q = q.transpose(-1, -2)
-        return q
-
-    def right_inverse(self, w):
-        return w
-
-
-@dataclass
-class OrthoParams:
-    # spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = BatchedIdentity
-    spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
-        BatchedPowerIteration, power_it_niter=3, eps=1e-6
-    )
-    orthogonalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
-        BatchedBjorckOrthogonalization,
-        beta=0.5,
-        niters=12,
-        # ClassParam(BatchedExponentialOrthogonalization, niters=12)
-        # BatchedCholeskyOrthogonalization,
-        # BatchedQROrthogonalization,
-    )
-    contiguous_optimization: bool = False
-
-
-DEFAULT_ORTHO_PARAMS = OrthoParams()
-DEFAULT_TEST_ORTHO_PARAMS = OrthoParams(
-    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-6),  # type: ignore
-    orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=25),
-    # orthogonalizer=ClassParam(BatchedQROrthogonalization),
-    # orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12),  # type: ignore
-    contiguous_optimization=False,
-)
-EXP_ORTHO_PARAMS = OrthoParams(
-    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-6),  # type: ignore
-    orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12),  # type: ignore
-    contiguous_optimization=False,
-)
-QR_ORTHO_PARAMS = OrthoParams(
-    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-3),  # type: ignore
-    orthogonalizer=ClassParam(BatchedQROrthogonalization),  # type: ignore
-    contiguous_optimization=False,
-)
-CHOLESKY_ORTHO_PARAMS = OrthoParams(
-    spectral_normalizer=BatchedIdentity,  # type: ignore
-    orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization),  # type: ignore
-    contiguous_optimization=False,
-)
diff --git a/orthogonium/layers/legacy/__init__.py b/orthogonium/legacy/__init__.py
similarity index 100%
rename from orthogonium/layers/legacy/__init__.py
rename to orthogonium/legacy/__init__.py
diff --git a/orthogonium/layers/legacy/block_ortho_conv.py b/orthogonium/legacy/block_ortho_conv.py
similarity index 100%
rename from orthogonium/layers/legacy/block_ortho_conv.py
rename to orthogonium/legacy/block_ortho_conv.py
diff --git a/orthogonium/layers/legacy/cayley_ortho_conv.py b/orthogonium/legacy/cayley_ortho_conv.py
similarity index 100%
rename from orthogonium/layers/legacy/cayley_ortho_conv.py
rename to orthogonium/legacy/cayley_ortho_conv.py
diff --git a/orthogonium/layers/legacy/skew_ortho_conv.py b/orthogonium/legacy/skew_ortho_conv.py
similarity index 100%
rename from orthogonium/layers/legacy/skew_ortho_conv.py
rename to orthogonium/legacy/skew_ortho_conv.py
diff --git a/orthogonium/model_factory/__init__.py b/orthogonium/model_factory/__init__.py
new file mode 100644
index 0000000..434484e
--- /dev/null
+++ b/orthogonium/model_factory/__init__.py
@@ -0,0 +1 @@
+from .classparam import ClassParam
diff --git a/orthogonium/classparam.py b/orthogonium/model_factory/classparam.py
similarity index 100%
rename from orthogonium/classparam.py
rename to orthogonium/model_factory/classparam.py
diff --git a/orthogonium/models_factory.py b/orthogonium/model_factory/models_factory.py
similarity index 99%
rename from orthogonium/models_factory.py
rename to orthogonium/model_factory/models_factory.py
index 868dd0c..f03a36f 100644
--- a/orthogonium/models_factory.py
+++ b/orthogonium/model_factory/models_factory.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from torch.nn import AvgPool2d
 
-from orthogonium.classparam import ClassParam
+from orthogonium.model_factory.classparam import ClassParam
 from orthogonium.layers import AdaptiveOrthoConv2d
 from orthogonium.layers import BatchCentering2D
 from orthogonium.layers import LayerCentering2D
@@ -11,7 +11,7 @@
 from orthogonium.layers import UnitNormLinear
 from orthogonium.layers.conv.SLL.sll_layer import SLLxAOCLipschitzResBlock
 from orthogonium.layers.conv.SLL.sll_layer import SDPBasedLipschitzResBlock
-from orthogonium.layers.linear.reparametrizers import DEFAULT_ORTHO_PARAMS
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS
 
 
 def SLLxBCOPResNet50(
diff --git a/orthogonium/reparametrizers.py b/orthogonium/reparametrizers.py
new file mode 100644
index 0000000..29278ff
--- /dev/null
+++ b/orthogonium/reparametrizers.py
@@ -0,0 +1,469 @@
+from dataclasses import dataclass
+from typing import Callable
+from typing import Tuple
+import torch
+import torch.nn.utils.parametrize as parametrize
+from torch import nn as nn
+from orthogonium.model_factory.classparam import ClassParam
+
+
+class L2Normalize(nn.Module):
+    def __init__(self, dtype, dim=None):
+        """
+        A class that performs L2 normalization for the given input tensor.
+
+        L2 normalization is a process that normalizes the input over a specified
+        dimension such that the sum of squares of the elements along that
+        dimension equals 1. It ensures that the resulting tensor has a unit norm.
+        This operation is widely used in machine learning and deep learning
+        applications to standardize feature representations.
+
+        Attributes:
+            dim (Optional[int]): The specific dimension along which normalization
+                is performed. If None, normalization is done over all dimensions.
+            dtype (Any): The data type of the tensor to be normalized.
+
+        Parameters:
+            dtype: The data type of the tensor to be normalized.
+            dim: An optional integer specifying the dimension along which to
+                normalize. If not provided, the input will be normalized globally
+                across all dimensions.
+        """
+        super(L2Normalize, self).__init__()
+        self.dim = dim
+        self.dtype = dtype
+
+    def forward(self, x):
+        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)
+
+    def right_inverse(self, x):
+        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)
+
+
+class BatchedPowerIteration(nn.Module):
+    def __init__(self, weight_shape, power_it_niter=3, eps=1e-12):
+        """
+        BatchedPowerIteration is a class that performs spectral normalization on weights
+        using the power iteration method in a batched manner. It initializes singular
+        vectors 'u' and 'v', which are used to approximate the largest singular value
+        of the associated weight matrix during training. The L2 normalization is applied
+        to stabilize these singular vector parameters.
+
+        Attributes:
+            weight_shape: tuple
+                Shape of the weight tensor. Normalization is applied to the last two dimensions.
+            power_it_niter: int
+                Number of iterations to perform for the power iteration method.
+            eps: float
+                A small constant to ensure numerical stability during calculations. Used in the power iteration
+                method to avoid dividing by zero.
+        """
+        super(BatchedPowerIteration, self).__init__()
+        self.weight_shape = weight_shape
+        self.power_it_niter = power_it_niter
+        self.eps = eps
+        # init u
+        # u will be weight_shape[:-2] + (weight_shape[-2], 1)
+        # v will be weight_shape[:-2] + (weight_shape[:-1], 1,)
+        self.u = nn.Parameter(
+            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-2], 1)),
+            requires_grad=False,
+        )
+        self.v = nn.Parameter(
+            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-1], 1)),
+            requires_grad=False,
+        )
+        parametrize.register_parametrization(
+            self, "u", L2Normalize(dtype=self.u.dtype, dim=(-2))
+        )
+        parametrize.register_parametrization(
+            self, "v", L2Normalize(dtype=self.v.dtype, dim=(-2))
+        )
+
+    def forward(self, X, init_u=None, n_iters=3, return_uv=True):
+        for _ in range(n_iters):
+            self.v = X.transpose(-1, -2) @ self.u
+            self.u = X @ self.v
+        # stop gradient on u and v
+        u = self.u.detach()
+        v = self.v.detach()
+        # but keep gradient on s
+        s = u.transpose(-1, -2) @ X @ v
+        return X / (s + self.eps)
+
+    def right_inverse(self, normalized_kernel):
+        # we assume that the kernel is normalized
+        return normalized_kernel.to(self.u.dtype)
+
+
+class BatchedIdentity(nn.Module):
+    def __init__(self, weight_shape):
+        """
+        Class representing a batched identity matrix with a specific weight shape. The
+        matrix is initialized based on the provided shape of the weights. It is a
+        convenient utility for applications where identity-like operations are
+        required in a batched manner.
+
+        Attributes:
+            weight_shape (Tuple[int, int]): A tuple representing the shape of the
+            weight matrix for each batch. (unused)
+
+        Args:
+            weight_shape: A tuple specifying the shape of the individual weight matrix.
+        """
+        super(BatchedIdentity, self).__init__()
+
+    def forward(self, w):
+        return w
+
+    def right_inverse(self, w):
+        return w
+
+
+class BatchedBjorckOrthogonalization(nn.Module):
+    def __init__(self, weight_shape, beta=0.5, niters=12, pass_through=False):
+        """
+        Initialize the BatchedBjorckOrthogonalization module.
+
+        This module implements the Björck orthogonalization method, which iteratively refines
+        a weight matrix towards orthogonality. The method is especially effective when the
+        weight matrix columns are nearly orthonormal. It balances computational efficiency
+        with convergence speed through a user-defined `beta` parameter and iteration count.
+
+        Args:
+            weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
+            beta (float): Coefficient controlling the convergence of the orthogonalization process.
+                Default is 0.5.
+            niters (int): Number of iterations for the orthogonalization algorithm. Default is 12.
+            pass_through (bool): If True, most iterations are performed without gradient computation,
+                which can improve efficiency.
+        """
+        self.weight_shape = weight_shape
+        self.beta = beta
+        self.niters = niters
+        self.pass_through = pass_through
+        if weight_shape[-2] < weight_shape[-1]:
+            self.wwtw_op = BatchedBjorckOrthogonalization.wwt_w_op
+        else:
+            self.wwtw_op = BatchedBjorckOrthogonalization.w_wtw_op
+        super(BatchedBjorckOrthogonalization, self).__init__()
+
+    @staticmethod
+    def w_wtw_op(w):
+        return w @ (w.transpose(-1, -2) @ w)
+
+    @staticmethod
+    def wwt_w_op(w):
+        return (w @ w.transpose(-1, -2)) @ w
+
+    def forward(self, w):
+        """
+        Apply the BjΓΆrck orthogonalization process to the weight matrix.
+
+        The algorithm adjusts the input matrix to approximate the closest orthogonal matrix
+        by iteratively applying transformations based on the Björck algorithm.
+
+        Args:
+            w (torch.Tensor): The weight matrix to be orthogonalized.
+
+        Returns:
+            torch.Tensor: The orthogonalized weight matrix.
+        """
+        if self.pass_through:
+            with torch.no_grad():
+                for _ in range(self.niters):
+                    w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
+            # Final iteration without no_grad, using parameters:
+            w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
+        else:
+            for _ in range(self.niters):
+                w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
+        return w
+
+    def right_inverse(self, w):
+        return w
+
+
+class BatchedCholeskyOrthogonalization(nn.Module):
+    def __init__(self, weight_shape, stable=False):
+        """
+        Initialize the BatchedCholeskyOrthogonalization module.
+
+        This module orthogonalizes a weight matrix using the Cholesky decomposition method.
+        It first computes the positive definite matrix \( V V^T \), then performs a Cholesky
+        decomposition to obtain a lower triangular matrix. Solving the resulting triangular
+        system yields an orthogonal matrix. This method is efficient and numerically stable,
+        making it suitable for a wide range of applications.
+
+        Args:
+            weight_shape (tuple): The shape of the weight matrix.
+            stable (bool): Whether to use the stable version of the Cholesky-based orthogonalization
+                function, which adds a small positive diagonal element to ensure numerical stability.
+                Default is False.
+        """
+        self.weight_shape = weight_shape
+        super(BatchedCholeskyOrthogonalization, self).__init__()
+        if stable:
+            self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn_stable.apply
+        else:
+            self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn.apply
+
+    # @staticmethod
+    # def orth(X):
+    #     S = X @ X.mT
+    #     eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3).detach()
+    #     eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
+    #     S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
+    #     L = torch.linalg.cholesky(S)
+    #     W = torch.linalg.solve_triangular(L, X, upper=False)
+    #     return W
+
+    class CholeskyOrthfn(torch.autograd.Function):
+        @staticmethod
+        # def forward(ctx, X):
+        #     S = X @ X.mT
+        #     eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3)
+        #     eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
+        #     S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
+        #     L = torch.linalg.cholesky(S)
+        #     W = torch.linalg.solve_triangular(L, X, upper=False)
+        #     ctx.save_for_backward(W, L)
+        #     return W
+        def forward(ctx, X):
+            S = X @ X.mT
+            eps = 1e-5  # A common stable choice
+            S = S + eps * torch.eye(
+                S.size(-1), dtype=S.dtype, device=S.device
+            ).unsqueeze(0)
+            L = torch.linalg.cholesky(S)
+            W = torch.linalg.solve_triangular(L, X, upper=False)
+            ctx.save_for_backward(W, L)
+            return W
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            W, L = ctx.saved_tensors
+            LmT = L.mT.contiguous()
+            gB = torch.linalg.solve_triangular(LmT, grad_output, upper=True)
+            gA = (-gB @ W.mT).tril()
+            gS = (LmT @ gA).tril()
+            gS = gS + gS.tril(-1).mT
+            gS = torch.linalg.solve_triangular(LmT, gS, upper=True)
+            gX = gS @ W + gB
+            return gX
+
+    class CholeskyOrthfn_stable(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, X):
+            S = X @ X.mT
+            eps = 1e-5  # A common stable choice
+            S = S + eps * torch.eye(
+                S.size(-1), dtype=S.dtype, device=S.device
+            ).unsqueeze(0)
+            L = torch.linalg.cholesky(S)
+            W = torch.linalg.solve_triangular(L, X, upper=False)
+            ctx.save_for_backward(X, W, L)
+            return W
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            X, W, L = ctx.saved_tensors
+            gB = torch.linalg.solve_triangular(L.mT, grad_output, upper=True)
+            gA = (-gB @ W.mT).tril()
+            gS = (L.mT @ gA).tril()
+            gS = gS + gS.tril(-1).mT
+            gS = torch.linalg.solve_triangular(L.mT, gS, upper=True)
+            gS = torch.linalg.solve_triangular(L, gS, upper=False, left=False)
+            gX = gS @ X + gB
+            return gX
+
+    def forward(self, w):
+        """
+        Apply Cholesky-based orthogonalization to the weight matrix.
+
+        This method constructs a symmetric positive definite matrix from the input weight
+        matrix, performs Cholesky decomposition, and solves the triangular system to produce
+        an orthogonal matrix. It mimics the results of the Gram-Schmidt process but with
+        improved numerical stability.
+
+        Args:
+            w (torch.Tensor): The weight matrix to be orthogonalized.
+
+        Returns:
+            torch.Tensor: The orthogonalized weight matrix.
+        """
+        return self.orth(w).view(*self.weight_shape)
+
+    def right_inverse(self, w):
+        return w
+
+
+class BatchedExponentialOrthogonalization(nn.Module):
+    def __init__(self, weight_shape, niters=7):
+        """
+        Initialize the BatchedExponentialOrthogonalization module.
+
+        This module orthogonalizes a weight matrix using the exponential map of a skew-symmetric
+        matrix. By converting the matrix into a skew-symmetric form and applying the matrix
+        exponential, it produces an orthogonal matrix. This approach is particularly useful
+        in contexts where smooth transitions between matrices are required.
+
+        Non-square matrices are zero-padded to a square shape before the
+        exponential map is applied, then cropped back to the original shape.
+        Args:
+            weight_shape (tuple): The shape of the weight matrix.
+            niters (int): Number of iterations for the series expansion approximation of the
+                matrix exponential. Default is 7.
+        """
+        self.weight_shape = weight_shape
+        self.max_dim = max(weight_shape[-2:])
+        self.niters = niters
+        super(BatchedExponentialOrthogonalization, self).__init__()
+
+    def forward(self, w):
+        # fill w with zero to have a square matrix over the last two dimensions
+        # if ((self.max_dim - w.shape[-1]) != 0) and ((self.max_dim - w.shape[-2]) != 0):
+        w = torch.nn.functional.pad(
+            w, (0, self.max_dim - w.shape[-1], 0, self.max_dim - w.shape[-2])
+        )
+        # makes w skew symmetric
+        w = (w - w.transpose(-1, -2)) / 2
+        acc = w
+        res = torch.eye(acc.shape[-2], acc.shape[-1], device=w.device) + acc
+        for i in range(2, self.niters):
+            acc = torch.einsum("...ij,...jk->...ik", acc, w) / i
+            res = res + acc
+        # if transpose:
+        #     res = res.transpose(-1, -2)
+        res = res[..., : self.weight_shape[-2], : self.weight_shape[-1]]
+        return res.contiguous()
+
+    def right_inverse(self, w):
+        return w
+
+
+class BatchedQROrthogonalization(nn.Module):
+    def __init__(self, weight_shape):
+        """
+        Initialize the BatchedQROrthogonalization module.
+
+        This module uses QR decomposition to orthogonalize a weight matrix in a batched manner.
+        It computes the orthogonal component (`Q`) from the decomposition, ensuring that the
+        output satisfies orthogonality constraints.
+
+        Args:
+            weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
+        """
+        super(BatchedQROrthogonalization, self).__init__()
+
+    def forward(self, w):
+        """
+        Perform QR decomposition to compute the orthogonalized weight matrix.
+
+        The QR decomposition splits the input matrix into an orthogonal matrix (`Q`) and
+        an upper triangular matrix (`R`). This module returns the orthogonal component.
+
+        Args:
+            w (torch.Tensor): The weight matrix to be orthogonalized.
+
+        Returns:
+            torch.Tensor: The orthogonalized weight matrix (`Q` from the QR decomposition).
+        """
+        transpose = w.shape[-2] < w.shape[-1]
+        if transpose:
+            w = w.transpose(-1, -2)
+        q, r = torch.linalg.qr(w, mode="reduced")
+        # compute the sign of the diagonal of r
+        diag_sign = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
+        # flip the sign of q's columns so that the diagonal of r is positive
+        q = q * diag_sign
+        if transpose:
+            q = q.transpose(-1, -2)
+        return q.contiguous()
+
+    def right_inverse(self, w):
+        return w
+
+
+@dataclass
+class OrthoParams:
+    """
+    Represents the parameters and configurations used for orthogonalization
+    and spectral normalization.
+
+    This class encapsulates the necessary modules and settings required
+    for performing spectral normalization and orthogonalization of tensors
+    in a parameterized way. It accommodates various implementations of
+    normalizers and orthogonalization techniques to provide flexibility
+    in their application. This way we can easily switch between different
+    normalization techniques inside our layers, even though each normalization
+    technique has different parameters.
+
+    Attributes:
+        spectral_normalizer (Callable[Tuple[int, ...], nn.Module]): A callable
+            that produces a module for spectral normalization. Default is
+            configured to use BatchedPowerIteration with specific parameters.
+            This callable can be provided either as a `functools.partial` or as a
+            `orthogonium.ClassParam`. It will receive the shape of the weight tensor as its
+            argument.
+        orthogonalizer (Callable[Tuple[int, ...], nn.Module]): A callable
+            that produces a module for orthogonalization. Default is
+            configured to use BatchedBjorckOrthogonalization with specific
+            parameters. This callable can be provided either as a `functools.partial` or as a
+            `orthogonium.ClassParam`. It will receive the shape of the weight tensor as its argument.
+        contiguous_optimization (bool): Determines whether to perform
+            optimization ensuring contiguous operations. Default is False.
+    """
+
+    # spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = BatchedIdentity
+    spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
+        BatchedPowerIteration, power_it_niter=3, eps=1e-6
+    )
+    orthogonalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
+        BatchedBjorckOrthogonalization,
+        beta=0.5,
+        niters=12,
+        pass_through=False,
+        # ClassParam(BatchedExponentialOrthogonalization, niters=12)
+        # BatchedCholeskyOrthogonalization,
+        # BatchedQROrthogonalization,
+    )
+    contiguous_optimization: bool = False
+
+
+DEFAULT_ORTHO_PARAMS = OrthoParams()
+BJORCK_PASS_THROUGH_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-4),  # type: ignore
+    orthogonalizer=ClassParam(
+        BatchedBjorckOrthogonalization, beta=0.5, niters=12, pass_through=True
+    ),
+    contiguous_optimization=False,
+)
+DEFAULT_TEST_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=4, eps=1e-4),  # type: ignore
+    orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=25),
+    # orthogonalizer=ClassParam(BatchedQROrthogonalization),
+    # orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12),  # type: ignore
+    contiguous_optimization=False,
+)
+EXP_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-6),  # type: ignore
+    orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12),  # type: ignore
+    contiguous_optimization=False,
+)
+QR_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-3),  # type: ignore
+    orthogonalizer=ClassParam(BatchedQROrthogonalization),  # type: ignore
+    contiguous_optimization=False,
+)
+CHOLESKY_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=BatchedIdentity,  # type: ignore
+    orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization),  # type: ignore
+    contiguous_optimization=False,
+)
+
+CHOLESKY_STABLE_ORTHO_PARAMS = OrthoParams(
+    spectral_normalizer=BatchedIdentity,
+    orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization, stable=True),
+    contiguous_optimization=False,
+)
diff --git a/scripts/benchmark/bench_archs.py b/scripts/benchmark/bench_archs.py
index 6e4a1f7..bab6289 100644
--- a/scripts/benchmark/bench_archs.py
+++ b/scripts/benchmark/bench_archs.py
@@ -4,8 +4,6 @@
 import pandas as pd
 import pytorch_lightning
 import torch
-from batch_times import evaluate_all_model_time_statistics
-from memory_usage import get_model_memory
 from torch.nn import Conv2d
 from torch.utils.data import DataLoader
 from torchvision.datasets import ImageFolder
@@ -17,19 +15,27 @@
 from torchvision.transforms import Resize
 from torchvision.transforms import ToTensor
 
-from orthogonium.classparam import ClassParam
+from batch_times import evaluate_all_model_time_statistics
+from memory_usage import get_model_memory
 from orthogonium.layers import AdaptiveOrthoConv2d as BCOP_new
-from orthogonium.layers.legacy.block_ortho_conv import BCOP as BCOP_old
-from orthogonium.layers.legacy.cayley_ortho_conv import Cayley
-from orthogonium.layers.legacy.skew_ortho_conv import SOC
-from orthogonium.models_factory import LipResNet
+from orthogonium.legacy import BCOP as BCOP_old
+from orthogonium.legacy.cayley_ortho_conv import Cayley
+from orthogonium.legacy.skew_ortho_conv import SOC
+from orthogonium.model_factory.classparam import ClassParam
+from orthogonium.model_factory.models_factory import LipResNet
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS, QR_ORTHO_PARAMS, EXP_ORTHO_PARAMS, CHOLESKY_ORTHO_PARAMS, \
+    CHOLESKY_STABLE_ORTHO_PARAMS
 
 # from orthogonium.layers.conv.reparametrizers import BjorckParams
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 layers = [
-    ("BCOP_new", BCOP_new),
+    ("BCOP_qr", ClassParam(BCOP_new, ortho_params=QR_ORTHO_PARAMS)),
+    ("BCOP_bb", ClassParam(BCOP_new, ortho_params=DEFAULT_ORTHO_PARAMS)),
+    ("BCOP_cho", ClassParam(BCOP_new, ortho_params=CHOLESKY_ORTHO_PARAMS)),
+    ("BCOP_cho_stab", ClassParam(BCOP_new, ortho_params=CHOLESKY_STABLE_ORTHO_PARAMS)),
+    ("BCOP_exp", ClassParam(BCOP_new, ortho_params=EXP_ORTHO_PARAMS)),
     ("BCOP_old", BCOP_old),
     (
         "SOC",
diff --git a/scripts/benchmark/bench_bcop.py b/scripts/benchmark/bench_bcop.py
index 547f64a..7b79797 100644
--- a/scripts/benchmark/bench_bcop.py
+++ b/scripts/benchmark/bench_bcop.py
@@ -12,7 +12,7 @@
 from torch.utils.data import Dataset
 
 from orthogonium.layers import AdaptiveOrthoConv2d as BCOP_new
-from orthogonium.layers.legacy.block_ortho_conv import BCOP as BCOP_old
+from orthogonium.legacy import BCOP as BCOP_old
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
diff --git a/scripts/imagenet.py b/scripts/imagenet.py
index 151f5c4..ec6c199 100644
--- a/scripts/imagenet.py
+++ b/scripts/imagenet.py
@@ -1,9 +1,11 @@
 import math
 import os
 
-import pytorch_lightning
+from lightning.pytorch import callbacks as pl_callbacks
+from lightning.pytorch import Trainer
+from lightning.pytorch import LightningModule
+from lightning.pytorch import LightningDataModule
 import schedulefree
-import torch.nn as nn
 import torch.utils.data
 import torchmetrics
 from lightning.pytorch.loggers import WandbLogger
@@ -18,16 +20,16 @@
 from torchvision.transforms import Resize
 from torchvision.transforms import ToTensor
 
-from orthogonium.classparam import ClassParam
+from orthogonium.model_factory.classparam import ClassParam
 from orthogonium.layers import UnitNormLinear
-from orthogonium.layers.conv import AdaptiveOrthoConv2d
+from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
 from orthogonium.layers.custom_activations import MaxMin
-from orthogonium.layers.linear.reparametrizers import DEFAULT_ORTHO_PARAMS
+from orthogonium.reparametrizers import DEFAULT_ORTHO_PARAMS, QR_ORTHO_PARAMS
 from orthogonium.losses import check_last_linear_layer_type
 from orthogonium.losses import LossXent
 from orthogonium.losses import VRA
-from orthogonium.models_factory import AOCNetV1
-from orthogonium.models_factory import Residual
+from orthogonium.model_factory.models_factory import AOCNetV1
+from orthogonium.model_factory.models_factory import Residual
 
 torch.backends.cudnn.benchmark = True
 torch.set_float32_matmul_precision("medium")
@@ -39,7 +41,7 @@
 MAX_EPOCHS = 300  # as done in resnet strikes back
 
 
-class ImagenetDataModule(pytorch_lightning.LightningDataModule):
+class ImagenetDataModule(LightningDataModule):
     # Dataset configuration
     _DATA_PATH = f"/local_data/imagenet_cache/ILSVRC/Data/CLS-LOC/"
     _BATCH_SIZE = 256
@@ -121,8 +123,8 @@ def val_dataloader(self):
         )
 
 
-class ClassificationLightningModule(pytorch_lightning.LightningModule):
-    def __init__(self, num_classes=10):
+class ClassificationLightningModule(LightningModule):
+    def __init__(self, num_classes=1000):
         super().__init__()
         self.num_classes = num_classes
         self.model = AOCNetV1(
@@ -140,15 +142,15 @@ def __init__(self, num_classes=10):
             ),
             conv=ClassParam(
                 AdaptiveOrthoConv2d,
-                bias=True,
+                bias=False,
                 padding="same",
                 padding_mode="zeros",
-                ortho_params=DEFAULT_ORTHO_PARAMS,
+                ortho_params=QR_ORTHO_PARAMS,
             ),
             act=ClassParam(MaxMin),
-            lin=ClassParam(UnitNormLinear, bias=True),
+            lin=ClassParam(UnitNormLinear, bias=False),
             norm=None,
-            pool=ClassParam(nn.LPPool2d, norm_type=2),
+            pool=None,  # ClassParam(nn.LPPool2d, norm_type=2),
         )
         # self.criteria = CosineLoss()
         # self.criteria = LossXent(num_classes, offset=0, temperature=0.5 * 0.125)
@@ -167,7 +169,10 @@ def forward(self, x):
 
     def training_step(self, batch, batch_idx):
         self.model.train()
-        self.opt.train()
+        opt = self.optimizers()
+        # opt.zero_grad()
+        if hasattr(opt, "train"):
+            opt.train()
         img, label = batch
         y_hat = self.model(img)
         loss = self.criteria(y_hat, label)
@@ -176,11 +181,11 @@ def training_step(self, batch, batch_idx):
             VRA(
                 y_hat,
                 label,
-                L=1 / max(ImagenetDataModule._PREPROCESSING_PARAMS["img_std"]),
+                L=1 / min(ImagenetDataModule._PREPROCESSING_PARAMS["img_std"]),
                 eps=36 / 255,
                 last_layer_type=check_last_linear_layer_type(self.model),
             )
-        )  # L is 1 / max std of imagenet
+        )  # L is 1 / min std of imagenet
         self.log(
             "loss",
             loss,
@@ -209,7 +214,9 @@ def training_step(self, batch, batch_idx):
 
     def validation_step(self, batch, batch_idx):
         self.model.eval()
-        self.opt.eval()
+        opt = self.optimizers()
+        if hasattr(opt, "eval"):
+            opt.eval()
         img, label = batch
         y_hat = self.model(img)
         loss = self.criteria(y_hat, label)
@@ -218,11 +225,11 @@ def validation_step(self, batch, batch_idx):
             VRA(
                 y_hat,
                 label,
-                L=1 / max(ImagenetDataModule._PREPROCESSING_PARAMS["img_std"]),
+                L=1 / min(ImagenetDataModule._PREPROCESSING_PARAMS["img_std"]),
                 eps=36 / 255,
                 last_layer_type=check_last_linear_layer_type(self.model),
             )
-        )  # L is 1 / max std of imagenet
+        )  # L is 1 / min std of imagenet
         self.log(
             "val_loss",
             loss,
@@ -249,6 +256,33 @@ def validation_step(self, batch, batch_idx):
         )
         return loss
 
+    def on_fit_start(self) -> None:
+        self.optimizers().train()
+
+    def on_predict_start(self) -> None:
+        self.optimizers().eval()
+
+    def on_validation_model_eval(self) -> None:
+        self.model.eval()
+        self.optimizers().eval()
+
+    def on_validation_model_train(self) -> None:
+        self.model.train()
+        self.optimizers().train()
+
+    def on_test_model_eval(self) -> None:
+        self.model.eval()
+        self.optimizers().eval()
+
+    def on_test_model_train(self) -> None:
+        self.model.train()
+        self.optimizers().train()
+
+    def on_predict_model_eval(self) -> None:  # redundant with on_predict_start()
+        self.model.eval()
+        self.optimizers().eval()
+
+
     def configure_optimizers(self):
         """
         Setup the Adam optimizer. Note, that this function also can return a lr scheduler, which is
@@ -256,9 +290,10 @@ def configure_optimizers(self):
         """
         # return torch.optim.AdamW(self.parameters(), lr=0.0001, weight_decay=1e-5)
         optimizer = schedulefree.AdamWScheduleFree(
-            self.parameters(), lr=1e-3, weight_decay=0
+            self.parameters(), lr=5e-3, weight_decay=0
         )
-        self.opt = optimizer
+        optimizer.train()
+        self.hparams["lr"] = optimizer.param_groups[0]["lr"]
         return optimizer
 
 
@@ -266,7 +301,16 @@ def train():
     classification_module = ClassificationLightningModule(num_classes=1000)
     data_module = ImagenetDataModule()
     wandb_logger = WandbLogger(project="lipschitz-robust-imagenet", log_model=True)
-    trainer = pytorch_lightning.Trainer(
+    # wandb_logger.experiment.config["ortho_method"] = method
+    # wandb_logger.experiment.config["criteria"] = criteria
+    checkpoint_callback = pl_callbacks.ModelCheckpoint(
+        monitor="loss",
+        mode="min",
+        save_top_k=1,
+        save_last=True,
+        dirpath=f"./checkpoints/{wandb_logger.experiment.dir}",
+    )
+    trainer = Trainer(
         accelerator="gpu",
         devices=-1,  # GPUs per node
         num_nodes=1,  # Number of nodes
@@ -276,6 +320,11 @@ def train():
         max_epochs=MAX_EPOCHS,
         enable_model_summary=True,
         logger=[wandb_logger],
+        # logger=False,
+        callbacks=[
+            # pl_callbacks.LearningRateFinder(max_lr=0.05),
+            checkpoint_callback,
+        ]
     )
     summary(classification_module, input_size=(1, 3, 224, 224))
 
diff --git a/scripts/train_cifar.py b/scripts/train_cifar.py
index 0346655..d82f76f 100644
--- a/scripts/train_cifar.py
+++ b/scripts/train_cifar.py
@@ -18,13 +18,13 @@
 from torchvision.transforms import RandomResizedCrop
 from torchvision.transforms import ToTensor
 
-from orthogonium.classparam import ClassParam
+from orthogonium.model_factory.classparam import ClassParam
 from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
 from orthogonium.layers.linear import OrthoLinear
 from orthogonium.layers.custom_activations import MaxMin
 from orthogonium.losses import LossXent
 from orthogonium.losses import VRA
-from orthogonium.models_factory import (
+from orthogonium.model_factory.models_factory import (
     StagedCNN,
 )
 
diff --git a/setup.cfg b/setup.cfg
index 84a9214..1bd73fc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,25 +11,19 @@ warn_unused_configs = True
 namespace_packages = True
 ignore_missing_imports = True
 
-[mypy-deel.datasets.*]
-ignore_missing_imports = True
-
 [mypy-PIL]
 ignore_missing_imports = True
 
 [mypy-tdqm]
 ignore_missing_imports = True
 
-[mypy-mnist]
-ignore_missing_imports = True
-
 [mypy-scipy]
 ignore_missing_imports = True
 
 [tox:tox]
-envlist = py36,py37,py38,py36-lint
+envlist = py39,py310,py311,py312,py312-lint
 
-[testenv:py36-lint]
+[testenv:py312-lint]
 deps =
     black
     flake8
@@ -37,6 +31,6 @@ deps =
     mypy
     git+https://github.com/numpy/numpy-stubs.git
 commands =
-    black --check --diff setup.py deel tests
-    flake8 deel tests
-    mypy --namespace-packages deel tests
+    black --check --diff setup.py tests
+    flake8 tests
+    mypy --namespace-packages tests
diff --git a/setup.py b/setup.py
index aead7c1..def2d9d 100644
--- a/setup.py
+++ b/setup.py
@@ -32,11 +32,11 @@
 req_dev_path = os.path.join(this_directory, "requirements_dev.txt")
 
 install_requires = []
-
 if os.path.exists(req_path):
     with open(req_path) as fp:
         install_requires = [line.strip() for line in fp]
 
+install_dev_requires = []
 if os.path.exists(req_dev_path):
     with open(req_dev_path) as fp:
         install_dev_requires = [line.strip() for line in fp]
@@ -57,7 +57,9 @@
     version=version,
     # Find the package automatically (include everything):
     packages=find_packages(include=["orthogonium"]),
-    package_data={},
+    package_data={
+        "orthogonium": ["VERSION"],  # Add the VERSION file
+    },
     # Author information:
     author="Thibaut Boissin",
     author_email="thibaut.boissin@gmail.com",
diff --git a/tests/test_block_conv.py b/tests/test_block_conv.py
index d0f5134..ac1649b 100644
--- a/tests/test_block_conv.py
+++ b/tests/test_block_conv.py
@@ -4,7 +4,7 @@
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import fast_batched_matrix_conv
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import fast_matrix_conv
 
-THRESHOLD = 1e-4
+THRESHOLD = 5e-4
 
 
 # note that only square kernels are tested here
@@ -13,9 +13,9 @@
 # implementation withing this test
 @pytest.mark.parametrize("kernel_size_1", [3, 5])
 @pytest.mark.parametrize("kernel_size_2", [3, 5])
-@pytest.mark.parametrize("channels_1", [8, 16])
-@pytest.mark.parametrize("channels_2", [8, 16])
-@pytest.mark.parametrize("channels_3", [8, 16])
+@pytest.mark.parametrize("channels_1", [4, 8])
+@pytest.mark.parametrize("channels_2", [4, 8])
+@pytest.mark.parametrize("channels_3", [4, 8])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("padding", ["valid-circular", "same-circular"])
 @pytest.mark.parametrize("groups", [1, 4])
@@ -128,4 +128,4 @@ def test_batched_conv2d_operations(
         dim=0,
     )
     res2 = fast_batched_matrix_conv(kernel_1, kernel_2, groups=groups)
-    torch.testing.assert_allclose(res1, res2, rtol=1e-5, atol=1e-5)
+    assert torch.mean(torch.square(res1 - res2)) < THRESHOLD
diff --git a/tests/test_channel_shuffle.py b/tests/test_channel_shuffle.py
new file mode 100644
index 0000000..d25238f
--- /dev/null
+++ b/tests/test_channel_shuffle.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from orthogonium.layers.channel_shuffle import ChannelShuffle
+
+
+def test_forward_output_shapes():
+    # Test case 1: 2 groups of size 3 -> 3 groups of size 2
+    x = torch.randn(2, 6)
+    gm = ChannelShuffle(2, 3)
+    y = gm(x)
+
+    assert y.size() == torch.Size([2, 6])  # Output should have the same shape as input
+
+    # Test case 2: 2 groups of size 3 (4D tensor) -> 3 groups of size 2
+    x2 = torch.randn(2, 6, 32, 32)
+    gm = ChannelShuffle(2, 3)
+    y2 = gm(x2)
+
+    assert y2.shape == (2, 6, 32, 32)  # Check that the shape remains unchanged
+
+
+def test_channel_shuffle_invertibility():
+    # Test that ChannelShuffle is invertible
+    x2 = torch.randn(2, 6, 32, 32)
+    gm = ChannelShuffle(2, 3)
+    y2 = gm(x2)
+    gp = ChannelShuffle(3, 2)  # Invert the shuffle step
+    x2b = gp(y2)
+
+    assert torch.allclose(x2, x2b), "ChannelShuffle is not invertible"
+
+
+def test_invalid_tensor_size():
+    # Test case where input tensor size doesn't match group_in * group_out
+    x_invalid = torch.randn(2, 5)  # 5 does not match 2 * 3
+    gm = ChannelShuffle(2, 3)
+    with pytest.raises(AssertionError):
+        gm(x_invalid)
+
+
+def test_invalid_dim():
+    # Test case where dim is not equal to 1
+    with pytest.raises(AssertionError):
+        ChannelShuffle(2, 3, dim=2)
+
+
+def test_extra_repr():
+    # Test the extra_repr output
+    gm = ChannelShuffle(2, 3)
+    assert gm.extra_repr() == "group_in=2, group_out=3"
+
+
+def test_channel_shuffle_1_lipschitz():
+    # Initialize the ChannelShuffle layer
+    group_in, group_out = 2, 3
+    gm = ChannelShuffle(group_in, group_out)
+
+    # Input tensor (requires gradient for Jacobian computation)
+    x = torch.randn(2, group_in * group_out, requires_grad=True)
+
+    # Forward pass
+    y = gm(x)
+
+    # Compute the Jacobian matrix using autograd
+    jacobian = []
+    for i in range(y.numel()):  # Iterate over each output element
+        grad_output = torch.zeros_like(y)
+        grad_output.view(-1)[i] = 1  # Set gradient w.r.t. one output element
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=x,
+            grad_outputs=grad_output,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        jacobian.append(gradients.view(-1).detach().cpu().numpy())
+    jacobian = torch.tensor(jacobian).view(y.numel(), x.numel())  # Assemble Jacobian
+
+    # Compute the spectral norm of the Jacobian
+    singular_values = torch.linalg.svdvals(jacobian)
+    spectral_norm = singular_values.max()
+    min_singular_value = singular_values.min().item()
+
+    # Check that the spectral norm is <= 1
+    assert spectral_norm <= 1 + 1e-4, "ChannelShuffle is not 1-Lipschitz"
+    assert (
+        pytest.approx(min_singular_value, rel=1e-6) == 1
+    ), "ChannelShuffle is not orthogonal"
diff --git a/tests/test_classparam.py b/tests/test_classparam.py
new file mode 100644
index 0000000..dd95449
--- /dev/null
+++ b/tests/test_classparam.py
@@ -0,0 +1,116 @@
+import pytest
+from orthogonium.model_factory import ClassParam
+
+
+# Dummy functions for testing
+def function_1(a, b, c=3):
+    return a, b, c
+
+
+def function_with_kwargs(a, b=2, c=3):
+    return a, b, c
+
+
+def function_simple(a):
+    return a
+
+
+def test_classparam_init():
+    """Test initialization of ClassParam."""
+    cp = ClassParam(function_1, 1, 2, c=3)
+    assert cp.fct == function_1
+    assert cp.args == (1, 2)
+    assert cp.kwargs == {"c": 3}
+
+
+def test_classparam_call():
+    """Test calling ClassParam with default args and overwrites."""
+    # Initialize with defaults
+    cp = ClassParam(function_1, 1, 2, c=3)
+
+    # Test default behavior
+    assert cp() == (1, 2, 3)
+
+    # Test positional overwrite
+    assert cp(4) == (4, 2, 3)
+
+    # Test keyword overwrite
+    assert cp(4, c=5) == (4, 2, 5)
+
+    # Test both positional and keyword overwrites
+    assert cp(6, 7, c=8) == (6, 7, 8)
+
+
+def test_classparam_call_with_kwargs_only():
+    """Test keyword-based overrides when no positional args are provided."""
+    cp = ClassParam(function_with_kwargs, 1, b=5)
+
+    # Default behavior
+    assert cp() == (1, 5, 3)
+
+    # Override keyword-only arg
+    assert cp(c=10) == (1, 5, 10)
+
+    # Override positional and keyword args
+    assert cp(6, b=7, c=9) == (6, 7, 9)
+
+
+def test_classparam_call_no_function():
+    """Test calling a ClassParam instance without a function."""
+    cp = ClassParam()
+    assert callable(cp)
+    identity_function = cp()
+    assert callable(identity_function)
+    assert identity_function(42) == 42  # Should behave like an identity function
+
+
+def test_classparam_str():
+    """Test string representation of ClassParam."""
+    cp = ClassParam(function_1, 1, 2, c=3)
+    # Check string interpolation
+    assert str(cp) == "function_1(1,2,c=3)"
+
+
+def test_classparam_get_config():
+    """Test get_config generates the correct dict representation."""
+    cp = ClassParam(function_1, 1, 2, c=3)
+
+    # Nested config without flattening
+    config = cp.get_config(flatten=False)
+    assert config == {
+        "fct": "function_1",
+        "args": {"args_0": "1", "args_1": "2"},
+        "kwargs": {"c": "3"},
+    }
+
+    # Flatten the config
+    flat_config = cp.get_config(flatten=True)
+    assert flat_config == {
+        "fct": "function_1",
+        "args_0": "1",
+        "args_1": "2",
+        "c": "3",
+    }
+
+
+def test_classparam_overwrite_restrictions():
+    """Test that positional args cannot be overwritten by kwargs."""
+    cp = ClassParam(function_1, 1, 2, c=3)
+
+    with pytest.raises(TypeError):
+        cp(b="new_value")  # Overwriting positional args with kwargs should fail
+
+
+def test_flatten_config():
+    """Test the flatten_config utility method."""
+    child_cp = ClassParam(function_simple, 5)
+    parent_cp = ClassParam(function_1, child_cp, 10, c=7)
+
+    flattened = parent_cp.get_config(flatten=True)
+    assert flattened == {
+        "fct": "function_1",
+        "args_0/fct": "function_simple",
+        "args_0/args_0": "5",
+        "args_1": "10",
+        "c": "7",
+    }
diff --git a/tests/test_custom_activations.py b/tests/test_custom_activations.py
new file mode 100644
index 0000000..c55d07a
--- /dev/null
+++ b/tests/test_custom_activations.py
@@ -0,0 +1,146 @@
+import pytest
+import torch
+from orthogonium.layers.custom_activations import (
+    Abs,
+    MaxMin,
+    HouseHolder,
+    HouseHolder_Order_2,
+)
+
+
+# ------------------------- Abs Tests -------------------------
+def test_abs_output():
+    abs_layer = Abs()
+    x = torch.tensor([-1.0, 0.0, 1.0, -2.0])
+    y = abs_layer(x)
+    assert torch.allclose(
+        y, torch.tensor([1.0, 0.0, 1.0, 2.0])
+    ), "Abs function is incorrect"
+
+
+def test_abs_shapes():
+    abs_layer = Abs()
+    x = torch.randn(2, 3, 4, 5)  # Random 4D tensor
+    y = abs_layer(x)
+    assert y.shape == x.shape, "Output shape mismatch for Abs function"
+
+
+# ------------------------- MaxMin Tests -------------------------
+def test_maxmin_output_shapes():
+    maxmin_layer = MaxMin(axis=1)
+    x = torch.randn(2, 6, 4, 4)
+    y = maxmin_layer(x)
+    assert y.shape == x.shape, "Output shape mismatch for MaxMin"
+
+
+def test_maxmin_absorbant():
+    maxmin_layer = MaxMin(axis=1)
+    x = torch.randn(2, 6, 4, 4)
+    y = maxmin_layer(x)
+    z = maxmin_layer(y)
+    assert torch.allclose(
+        y.cpu(), z.cpu()
+    ), "MaxMin layer is not idempotent (output of MaxMin is not input)"
+
+
+def test_maxmin_invalid_input():
+    maxmin_layer = MaxMin(axis=1)
+    x = torch.randn(2, 5)  # Odd-sized dimension
+    with pytest.raises(ValueError):
+        maxmin_layer(x)
+
+
+# ------------------------- HouseHolder Tests -------------------------
+def test_householder_shapes():
+    layer = HouseHolder(channels=4)
+    x = torch.randn(2, 4, 8, 8)  # Input channels divisible by 2
+    y = layer(x)
+    assert y.shape == x.shape, "Output shape mismatch for HouseHolder"
+
+
+def test_householder_gradients():
+    layer = HouseHolder(channels=4)
+    x = torch.randn(2, 4, 8, 8, requires_grad=True)
+    y = layer(x)
+    y.sum().backward()
+    assert x.grad is not None, "Gradients not computed for HouseHolder"
+
+
+def test_householder_invalid_channels():
+    with pytest.raises(AssertionError):
+        HouseHolder(channels=3)  # Channels not divisible by 2
+
+
+# ------------------------- HouseHolder_Order_2 Tests -------------------------
+def test_householder_order2_shapes():
+    layer = HouseHolder_Order_2(channels=4)
+    x = torch.randn(2, 4, 8, 8)
+    y = layer(x)
+    assert y.shape == x.shape, "Output shape mismatch for HouseHolder_Order_2"
+
+
+def test_householder_order2_gradients():
+    layer = HouseHolder_Order_2(channels=6)
+    x = torch.randn(2, 6, 16, 16, requires_grad=True)
+    y = layer(x)
+    y.sum().backward()
+    assert x.grad is not None, "Gradients not computed for HouseHolder_Order_2"
+
+
+def test_householder_order2_invalid_channels():
+    with pytest.raises(AssertionError):
+        HouseHolder_Order_2(channels=5)  # Odd number of channels
+
+
+# ------------------------- Lipschitz Property Tests -------------------------
+@pytest.mark.parametrize(
+    "layer_fn, channels",
+    [
+        (Abs, None),
+        (lambda: MaxMin(axis=1), None),
+        (lambda: HouseHolder(channels=4), 4),
+        (lambda: HouseHolder_Order_2(channels=4), 4),
+    ],
+)
+def test_lipschitz_property(layer_fn, channels):
+    """
+    Tests if the layer satisfies the 1-Lipschitz property.
+    """
+    if callable(layer_fn):
+        layer = layer_fn()
+    else:
+        layer = layer_fn
+
+    channels = channels or 4  # Default to 4 channels if unspecified
+    x = torch.randn(2, channels, requires_grad=True)
+    y = layer(x)
+
+    # Calculate Jacobian
+    jacobian = []
+    for i in range(y.numel()):
+        grad_output = torch.zeros_like(y)
+        grad_output.view(-1)[i] = 1
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=x,
+            grad_outputs=grad_output,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        jacobian.append(gradients.view(-1).detach().cpu().numpy())
+    jacobian = torch.tensor(jacobian).view(y.numel(), x.numel())  # Construct Jacobian
+
+    # Compute singular values and check Lipschitz property
+    singular_values = torch.linalg.svdvals(jacobian)
+    assert (
+        singular_values.max() <= 1 + 1e-4
+    ), f"{layer.__class__.__name__} is not 1-Lipschitz"
+    assert (
+        singular_values.min() >= 1 - 1e-4
+    ), f"{layer.__class__.__name__} is not orthogonal"
+
+
+# ------------------------- Run All Tests -------------------------
+if __name__ == "__main__":
+    pytest.main()
diff --git a/tests/test_normalization_layers.py b/tests/test_normalization_layers.py
new file mode 100644
index 0000000..d7807b0
--- /dev/null
+++ b/tests/test_normalization_layers.py
@@ -0,0 +1,70 @@
+import pytest
+import torch
+
+from orthogonium.layers import LayerCentering2D, BatchCentering2D
+
+
+@pytest.mark.parametrize(
+    "layer_fn, num_features, orthogonal",
+    [
+        (lambda: LayerCentering2D(num_features=4), 4, False),
+        (lambda: BatchCentering2D(num_features=4), 4, True),
+    ],
+)
+@pytest.mark.parametrize(
+    "mean, std",
+    [
+        (0, 1),  # Standard Normal Distribution
+        (5, 2),  # Higher mean and variance
+        (-3, 1),  # Shifted mean
+        (0, 0.1),  # Low variance
+        (10, 5),  # Very high variance
+    ],
+)
+def test_lipschitz_constant_with_various_distributions(
+    layer_fn, num_features, orthogonal, mean, std
+):
+    """
+    Test if the layer satisfies the Lipschitz property when the input
+    comes from distributions with different means and variances.
+    """
+    layer = layer_fn()
+    layer.train()  # Set layer to training mode
+
+    batch_size, h, w = 8, 8, 8  # Input dimensions
+
+    # Generate input tensor from a specific distribution
+    x = torch.randn(batch_size, num_features, h, w) * std + mean
+    x.requires_grad_(True)  # Enable gradient tracking
+
+    # Forward pass through the layer
+    y = layer(x)
+
+    # Calculate Jacobian
+    jacobian = []
+    for i in range(y.numel()):
+        grad_output = torch.zeros_like(y)
+        grad_output.view(-1)[i] = 1
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=x,
+            grad_outputs=grad_output,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        jacobian.append(gradients.view(-1).detach().cpu().numpy())
+    jacobian = torch.tensor(jacobian).view(
+        y.numel(), x.numel()
+    )  # Construct Jacobian matrix
+
+    # Validate Lipschitz constant
+    singular_values = torch.linalg.svdvals(jacobian)
+    assert singular_values.max() <= 1 + 1e-4, (
+        f"Lipschitz constraint violated for input distribution with mean={mean}, std={std}; "
+        f"max singular value: {singular_values.max()}"
+    )
+    if orthogonal:
+        assert (
+            singular_values.min() >= 1 - 1e-4
+        ), f"Orthogonality constraint violated for input distribution with mean={mean}, std={std}"
diff --git a/tests/test_ortho_linear.py b/tests/test_ortho_linear.py
new file mode 100644
index 0000000..411fa7e
--- /dev/null
+++ b/tests/test_ortho_linear.py
@@ -0,0 +1,273 @@
+import pytest
+import torch
+import numpy as np
+
+from orthogonium.layers import UnitNormLinear
+from orthogonium.layers.linear import OrthoLinear
+from orthogonium.reparametrizers import (
+    DEFAULT_TEST_ORTHO_PARAMS,
+    CHOLESKY_ORTHO_PARAMS,
+    CHOLESKY_STABLE_ORTHO_PARAMS,
+    QR_ORTHO_PARAMS,
+)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def _compute_sv_weight_matrix(layer):
+    """
+    Computes the singular values of the weight matrix of a linear layer.
+    """
+    with torch.no_grad():
+        try:
+            svs = torch.linalg.svdvals(layer.weight)
+            return svs.min().item(), svs.max().item(), (svs.mean() / svs.max()).item()
+        except (np.linalg.LinAlgError, torch.linalg.LinAlgError):
+            pytest.fail("SVD computation failed.")
+
+
+@pytest.mark.parametrize("input_features", [16, 32, 64])
+@pytest.mark.parametrize("output_features", [16, 32, 64])
+@pytest.mark.parametrize("bias", [True, False])
+def test_ortho_linear_instantiation(input_features, output_features, bias):
+    """
+    Test that OrthoLinear can be instantiated and has the correct weight properties.
+    """
+    try:
+        layer = OrthoLinear(input_features, output_features, bias=bias).to(device)
+        assert layer.weight.shape == (output_features, input_features), (
+            f"Weight shape mismatch: {layer.weight.shape} "
+            f"!= {(output_features, input_features)}"
+        )
+    except Exception as e:
+        pytest.fail(f"OrthoLinear instantiation failed: {e}")
+
+
+@pytest.mark.parametrize(
+    "input_features, output_features", [(16, 16), (16, 32), (32, 32)]
+)
+@pytest.mark.parametrize("bias", [True, False])
+def test_ortho_linear_singular_values(input_features, output_features, bias):
+    """
+    Test the singular values of the weight matrix in OrthoLinear.
+    """
+    layer = OrthoLinear(input_features, output_features, bias=bias).to(device)
+    sigma_min, sigma_max, stable_rank = layer.singular_values()
+    assert sigma_max <= 1.01, "Maximum singular value exceeds 1.01"
+    assert 0.95 <= sigma_min <= 1.05, "Minimum singular value not close to 1"
+    assert 0.98 <= stable_rank <= 1.02, "Stable rank not close to 1"
+
+
+@pytest.mark.parametrize(
+    "input_features, output_features", [(16, 16), (16, 32), (32, 32)]
+)
+def test_ortho_linear_norm_preservation(input_features, output_features):
+    """
+    Test that the OrthoLinear layer preserves norm as expected.
+    """
+    layer = OrthoLinear(input_features, output_features, bias=False).to(device)
+    inp = torch.randn(8, input_features).to(device)  # Batch of 8
+    with torch.no_grad():
+        output = layer(inp)
+        inp_norm = torch.norm(inp, dim=-1).mean().item()
+        out_norm = torch.norm(output, dim=-1).mean().item()
+        assert (
+            abs(inp_norm - out_norm) < 1e-2
+        ), f"Norm preservation failed: input norm {inp_norm}, output norm {out_norm}"
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+def test_ortho_linear_training(input_features, output_features):
+    """
+    Test backpropagation and training of OrthoLinear.
+    """
+    layer = OrthoLinear(input_features, output_features, bias=True).to(device)
+    layer.train()
+    optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
+    for _ in range(10):
+        optimizer.zero_grad()
+        inp = torch.randn(8, input_features).to(device)
+        output = layer(inp)
+        loss = -output.mean()
+        loss.backward()
+        optimizer.step()
+
+    try:
+        layer.eval()
+        with torch.no_grad():
+            svs_after_training = torch.linalg.svdvals(layer.weight).cpu().numpy()
+            assert (
+                svs_after_training.max() <= 1.05
+            ), f"Max singular value after training too high: {svs_after_training.max()}"
+    except Exception as e:
+        pytest.fail(f"Training or SVD computation failed: {e}")
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+def test_ortho_linear_impulse_response(input_features, output_features):
+    """
+    Compare singular values from impulse response and weight matrix SVD.
+    """
+    layer = OrthoLinear(input_features, output_features, bias=False).to(device)
+    sigma_min_wr, sigma_max_wr, stable_rank_wr = layer.singular_values()
+    sigma_min_ir, sigma_max_ir, stable_rank_ir = _compute_sv_weight_matrix(layer)
+
+    tol = 1e-4
+    assert abs(sigma_min_wr - sigma_min_ir) < tol, (
+        f"Impulse response min singular value mismatch: "
+        f"{sigma_min_wr} vs {sigma_min_ir}"
+    )
+    assert abs(sigma_max_wr - sigma_max_ir) < tol, (
+        f"Impulse response max singular value mismatch: "
+        f"{sigma_max_wr} vs {sigma_max_ir}"
+    )
+    assert (
+        abs(stable_rank_wr - stable_rank_ir) < tol
+    ), f"Impulse response stable rank mismatch: {stable_rank_wr} vs {stable_rank_ir}"
+
+
+@pytest.mark.parametrize("input_features", [16, 32])  # Adjust input sizes as needed
+@pytest.mark.parametrize("output_features", [16, 32])
+@pytest.mark.parametrize(
+    "orthparams_name, orthparams",
+    [
+        ("default", DEFAULT_TEST_ORTHO_PARAMS),
+        ("cholesky", CHOLESKY_ORTHO_PARAMS),
+        ("cholesky_stable", CHOLESKY_STABLE_ORTHO_PARAMS),
+        ("qr", QR_ORTHO_PARAMS),
+    ],
+)
+def test_ortho_linear_with_orthparams(
+    input_features, output_features, orthparams_name, orthparams
+):
+    """
+    Test OrthoLinear under different orthparams settings.
+    """
+    try:
+        layer = OrthoLinear(
+            input_features, output_features, bias=True, ortho_params=orthparams
+        ).to(device)
+
+        # Validate weight shape
+        assert layer.weight.shape == (output_features, input_features), (
+            f"Weight shape mismatch for {orthparams_name}: "
+            f"{layer.weight.shape} != {(output_features, input_features)}"
+        )
+
+        # Validate singular values
+        sigma_min, sigma_max, stable_rank = layer.singular_values()
+        # Add precision tolerances for different orthparams
+        tol = 1e-2 if orthparams_name.startswith("cholesky") else 1e-3
+        assert (
+            sigma_max <= 1 + tol
+        ), f"Max singular value exceeds tolerance for {orthparams_name}"
+        assert (
+            1 - tol <= sigma_min <= 1 + tol
+        ), f"Min singular value out of tolerance for {orthparams_name}"
+        assert (
+            0.98 <= stable_rank <= 1.02
+        ), f"Stable rank out of bounds for {orthparams_name}"
+
+    except Exception as e:
+        pytest.fail(f"Test failed for orthparams '{orthparams_name}': {e}")
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+@pytest.mark.parametrize("bias", [True, False])
+def test_unitnorm_linear_instantiation(input_features, output_features, bias):
+    """
+    Test that UnitNormLinear can be instantiated and has the correct weight properties.
+    """
+    try:
+        layer = UnitNormLinear(input_features, output_features, bias=bias).to(device)
+        assert layer.weight.shape == (
+            output_features,
+            input_features,
+        ), f"Weight shape mismatch: {layer.weight.shape} != {(output_features, input_features)}"
+    except Exception as e:
+        pytest.fail(f"UnitNormLinear instantiation failed: {e}")
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+def test_unitnorm_linear_weight_normalization(input_features, output_features):
+    """
+    Test that the weight rows of UnitNormLinear are unit-normalized.
+    """
+    layer = UnitNormLinear(input_features, output_features).to(device)
+    with torch.no_grad():
+        frobenius_norms = torch.linalg.norm(layer.weight, dim=1)
+        assert torch.allclose(
+            frobenius_norms, torch.ones_like(frobenius_norms), atol=1e-4
+        ), f"Row norms are not equal to 1: {frobenius_norms}"
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+@pytest.mark.parametrize("batch_size", [8, 16])
+def test_unitnorm_linear_lipschitz_property(
+    input_features, output_features, batch_size
+):
+    """
+    Test if each output of UnitNormLinear satisfies the 1-Lipschitz property.
+    """
+    layer = UnitNormLinear(input_features, output_features).to(device)
+    layer.eval()
+
+    x = torch.randn(batch_size, input_features, requires_grad=True).to(device)
+    y = layer(x)
+
+    # Calculate Jacobian
+    jacobian = []
+    for i in range(y.numel()):  # Loop over each output feature
+        grad_output = torch.zeros_like(y)
+        grad_output.view(-1)[i] = 1  # Assign 1 to the i-th output for derivative calc
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=x,
+            grad_outputs=grad_output,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        jacobian.append(gradients.view(-1).detach().cpu().numpy())
+
+    jacobian = torch.tensor(jacobian).view(y.numel(), x.numel())  # Construct Jacobian
+
+    # Validate Lipschitz norm per row
+    row_norms = torch.linalg.norm(jacobian, dim=1)
+    assert torch.all(
+        row_norms <= 1.0 + 1e-4
+    ), f"Some rows do not satisfy 1-Lipschitz: {row_norms}"
+    assert torch.all(
+        row_norms >= 1.0 - 1e-4
+    ), f"Some rows are not norm preserving: {row_norms}"
+
+
+@pytest.mark.parametrize("input_features", [16, 32])
+@pytest.mark.parametrize("output_features", [16, 32])
+def test_unitnorm_linear_training(input_features, output_features):
+    """
+    Test backpropagation and training of UnitNormLinear.
+    """
+    layer = UnitNormLinear(input_features, output_features, bias=True).to(device)
+    layer.train()
+    optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
+
+    for _ in range(10):
+        optimizer.zero_grad()
+        inp = torch.randn(8, input_features).to(device)  # Batch size: 8
+        output = layer(inp)
+        loss = -output.mean()
+        loss.backward()
+        optimizer.step()
+
+    with torch.no_grad():
+        row_norms = torch.linalg.norm(layer.weight, dim=1)
+        # Ensure row norms are still ~1 after training
+        assert torch.allclose(
+            row_norms, torch.ones_like(row_norms), atol=1e-4
+        ), f"Row norms after training not equal to 1: {row_norms}"
diff --git a/tests/test_orthogonality_conv.py b/tests/test_orthogonality_conv.py
index 8bdeef6..bef7a92 100644
--- a/tests/test_orthogonality_conv.py
+++ b/tests/test_orthogonality_conv.py
@@ -1,28 +1,33 @@
 import numpy as np
 import pytest
 import torch
-
-# from orthogonium.layers.conv.AOC.bcop_x_rko_conv import (
-#     BcopRkoConv2d as AdaptiveOrthoConv2d,
-# )
 from orthogonium.layers import AdaptiveOrthoConv2d
-from orthogonium.layers.linear.reparametrizers import DEFAULT_TEST_ORTHO_PARAMS
+from orthogonium.layers.conv.AOC import BcopRkoConv2d, FastBlockConv2d, RKOConv2d
+from orthogonium.reparametrizers import (
+    DEFAULT_TEST_ORTHO_PARAMS,
+    EXP_ORTHO_PARAMS,
+    CHOLESKY_ORTHO_PARAMS,
+    QR_ORTHO_PARAMS,
+    CHOLESKY_STABLE_ORTHO_PARAMS,
+)
 
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = "cpu"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def _compute_sv_impulse_response_layer(layer, img_shape):
     with torch.no_grad():
+        layer = layer.to(device)
         inputs = torch.eye(img_shape[0] * img_shape[1] * img_shape[2]).view(
             img_shape[0] * img_shape[1] * img_shape[2],
             img_shape[0],
             img_shape[1],
             img_shape[2],
-        )
+        ).to(device)
         outputs = layer(inputs)
         try:
             svs = torch.linalg.svdvals(outputs.view(outputs.shape[0], -1))
+            svs = svs.cpu()
             return svs.min(), svs.max(), svs.mean() / svs.max()
         except np.linalg.LinAlgError:
             print("SVD failed returning only largest singular value")
@@ -36,11 +41,13 @@ def check_orthogonal_layer(
     kernel_size,
     output_channels,
     expected_kernel_shape,
+    tol=5e-4,
+    sigma_min_requirement=0.95,
 ):
     imsize = 8
     # Test backpropagation and weight update
     try:
-        orthoconv.to(device)
+        orthoconv = orthoconv.to(device)
         orthoconv.train()
         opt = torch.optim.SGD(orthoconv.parameters(), lr=0.001)
         for i in range(25):
@@ -58,18 +65,12 @@ def check_orthogonal_layer(
         pytest.fail(
             f"BCOP weight has incorrect shape: {orthoconv.weight.shape} vs {(output_channels, input_channels // groups, kernel_size, kernel_size)}"
         )
-    # check that the layer is norm preserving
-    inp_norm = torch.sqrt(torch.square(inp).sum(dim=(-3, -2, -1))).float().item()
-    out_norm = torch.sqrt(torch.square(output).sum(dim=(-3, -2, -1))).float().item()
-    # if inp_norm <= out_norm - 1e-3:
-    #     pytest.fail(
-    #         f"BCOP is not norm preserving: {inp_norm} vs {out_norm} with rel error {abs(inp_norm - out_norm) / inp_norm}"
-    #     )
     # Test singular_values function
     sigma_min, sigma_max, stable_rank = orthoconv.singular_values()  # try:
     sigma_min_ir, sigma_max_ir, stable_rank_ir = _compute_sv_impulse_response_layer(
         orthoconv, (input_channels, imsize, imsize)
     )
+    print(f"input_shape = {inp.shape}, output_shape = {output.shape}")
     print(
         f"({input_channels}->{output_channels}, g{groups}, k{kernel_size}), "
         f"sigma_max:"
@@ -78,11 +79,10 @@ def check_orthogonal_layer(
         f" {sigma_min:.3f}/{sigma_min_ir:.3f}, "
         f"stable_rank: {stable_rank:.3f}/{stable_rank_ir:.3f}"
     )
-    tol = 1e-4
     # check that the singular values are close to 1
     assert sigma_max_ir < (1 + tol), "sigma_max is not less than 1"
     assert (sigma_min_ir < (1 + tol)) and (
-        sigma_min_ir > 0.95
+        sigma_min_ir > sigma_min_requirement
     ), "sigma_min is not close to 1"
     assert abs(stable_rank_ir - 1) < tol, "stable_rank is not close to 1"
     # check that the singular values are close to the impulse response values
@@ -191,14 +191,17 @@ def test_dilation(kernel_size, input_channels, output_channels, stride, groups):
     )
 
 
-@pytest.mark.parametrize("kernel_size", [3, 4, 5])
+@pytest.mark.parametrize("kernel_size", [2, 3, 4, 5])
 @pytest.mark.parametrize("input_channels", [8, 16])
 @pytest.mark.parametrize(
-    "output_channels", [8]
+    "output_channels", [8, 16]
 )  # dilated+strided convolutions are not supported for output_channels < input_channels
 @pytest.mark.parametrize("stride", [2])
+@pytest.mark.parametrize("dilation", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 4])
-def test_dilation_strided(kernel_size, input_channels, output_channels, stride, groups):
+def test_dilation_strided(
+    kernel_size, input_channels, output_channels, stride, dilation, groups
+):
     """
     test combinations of kernel size, input channels, output channels, stride and groups
     """
@@ -209,21 +212,25 @@ def test_dilation_strided(kernel_size, input_channels, output_channels, stride,
             in_channels=input_channels,
             out_channels=output_channels,
             stride=stride,
-            dilation=2,
+            dilation=dilation,
             groups=groups,
             bias=False,
             padding=(
-                int(np.ceil((2 * (kernel_size - 1)) / 2)),
-                int(np.floor((2 * (kernel_size - 1)) / 2)),
+                int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2)),
+                int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2)),
             ),
             padding_mode="circular",
             ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
         )
     except Exception as e:
-        if kernel_size < stride:
-            # we expect this configuration to raise a RuntimeError
+        if (output_channels >= input_channels) and (
+            ((dilation % stride) == 0) and (stride > 1)
+        ):
+            # we expect this configuration to raise a ValueError
             # pytest.skip(f"BCOP instantiation failed with: {e}")
             return
+        if (kernel_size == stride) and (((dilation % stride) == 0) and (stride > 1)):
+            return
         else:
             pytest.fail(f"BCOP instantiation failed with: {e}")
     check_orthogonal_layer(
@@ -242,8 +249,8 @@ def test_dilation_strided(kernel_size, input_channels, output_channels, stride,
 
 
 @pytest.mark.parametrize("kernel_size", [3, 4, 5])
-@pytest.mark.parametrize("input_channels", [2, 4, 8, 16, 32])
-@pytest.mark.parametrize("output_channels", [2, 4, 8, 16, 32])
+@pytest.mark.parametrize("input_channels", [2, 4, 32])
+@pytest.mark.parametrize("output_channels", [2, 4, 32])
 @pytest.mark.parametrize("stride", [2, 4])
 @pytest.mark.parametrize("groups", [1])
 def test_strided(kernel_size, input_channels, output_channels, stride, groups):
@@ -333,8 +340,8 @@ def test_even_kernels(kernel_size, input_channels, output_channels, stride, grou
 
 
 @pytest.mark.parametrize("kernel_size", [1, 2])
-@pytest.mark.parametrize("input_channels", [4, 8, 16, 32, 64])
-@pytest.mark.parametrize("output_channels", [4, 8, 16, 32, 64])
+@pytest.mark.parametrize("input_channels", [4, 8, 32])
+@pytest.mark.parametrize("output_channels", [4, 8, 32])
 @pytest.mark.parametrize("groups", [1, 2])
 def test_rko(kernel_size, input_channels, output_channels, groups):
     """
@@ -416,3 +423,146 @@ def test_depthwise(kernel_size, input_channels, output_channels, stride, groups)
             kernel_size,
         ),
     )
+
+
+def test_invalid_kernel_smaller_than_stride():
+    """
+    A test to ensure that kernel_size < stride raises an expected ValueError
+    """
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        AdaptiveOrthoConv2d(
+            in_channels=8,
+            out_channels=4,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+            padding=0,
+        )
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        BcopRkoConv2d(
+            in_channels=8,
+            out_channels=4,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+            padding=0,
+        )
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        FastBlockConv2d(
+            in_channels=8,
+            out_channels=4,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+            padding=0,
+        )
+
+
+def test_invalid_dilation_with_stride():
+    """
+    A test to ensure dilation > 1 while stride > 1 raises an expected ValueError
+    """
+    with pytest.raises(
+        ValueError,
+        match=r"dilation must be 1 when stride is not 1",
+    ):
+        AdaptiveOrthoConv2d(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            dilation=2,  # Invalid: dilation > 1 while stride > 1
+            groups=1,
+            padding=0,
+        )
+    with pytest.raises(
+        ValueError,
+        match=r"dilation must be 1 when stride is not 1",
+    ):
+        BcopRkoConv2d(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            dilation=2,  # Invalid: dilation > 1 while stride > 1
+            groups=1,
+            padding=0,
+        )
+    with pytest.raises(
+        ValueError,
+        match=r"dilation must be 1 when stride is not 1",
+    ):
+        RKOConv2d(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=3,
+            stride=2,
+            dilation=2,  # Invalid: dilation > 1 while stride > 1
+            groups=1,
+            padding=0,
+        )
+
+
+@pytest.mark.parametrize("kernel_size", [1, 3])
+@pytest.mark.parametrize("input_channels", [8, 16])
+@pytest.mark.parametrize("output_channels", [8, 16])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("groups", [1, 2])
+@pytest.mark.parametrize(
+    "ortho_params",
+    [
+        "default_bb",
+        "exp",
+        "qr",
+        "cholesky",
+        "cholesky_stable",
+    ],
+)
+def test_parametrizers_standard_configs(
+    kernel_size, input_channels, output_channels, stride, groups, ortho_params
+):
+    """
+    test combinations of kernel size, input channels, output channels, stride and groups
+    """
+    ortho_params_dict = {
+        "default_bb": DEFAULT_TEST_ORTHO_PARAMS,
+        "exp": EXP_ORTHO_PARAMS,
+        "qr": QR_ORTHO_PARAMS,
+        "cholesky": CHOLESKY_ORTHO_PARAMS,
+        "cholesky_stable": CHOLESKY_STABLE_ORTHO_PARAMS,
+    }  # trick to have the actual method name displayed properly if test fails
+    # Test instantiation
+    try:
+        orthoconv = AdaptiveOrthoConv2d(
+            kernel_size=kernel_size,
+            in_channels=input_channels,
+            out_channels=output_channels,
+            stride=stride,
+            groups=groups,
+            bias=False,
+            padding=(kernel_size // 2, kernel_size // 2),
+            padding_mode="circular",
+            ortho_params=ortho_params_dict[ortho_params],
+        )
+    except Exception as e:
+        if kernel_size < stride:
+            # we expect this configuration to raise a RuntimeError
+            # pytest.skip(f"BCOP instantiation failed with: {e}")
+            return
+        else:
+            pytest.fail(f"BCOP instantiation failed with: {e}")
+    check_orthogonal_layer(
+        orthoconv,
+        groups,
+        input_channels,
+        kernel_size,
+        output_channels,
+        (
+            output_channels,
+            input_channels // groups,
+            kernel_size,
+            kernel_size,
+        ),
+        tol=5e-2 if ortho_params.startswith("cholesky") else 1e-3,
+        sigma_min_requirement=0.75 if ortho_params.startswith("cholesky") else 0.95,
+    )
diff --git a/tests/test_orthogonality_conv_transpose.py b/tests/test_orthogonality_conv_transpose.py
index f5f1c59..24058c6 100644
--- a/tests/test_orthogonality_conv_transpose.py
+++ b/tests/test_orthogonality_conv_transpose.py
@@ -1,15 +1,20 @@
 import numpy as np
 import pytest
 import torch
-
 from orthogonium.layers.conv.AOC.ortho_conv import AdaptiveOrthoConvTranspose2d
-from orthogonium.layers.linear.reparametrizers import DEFAULT_TEST_ORTHO_PARAMS
 from tests.test_orthogonality_conv import check_orthogonal_layer
-
-
-# from orthogonium.layers.conv.fast_block_ortho_conv import FlashBCOP
-
-# from orthogonium.layers.conv.ortho_conv import OrthoConv as FlashBCOP
+from orthogonium.layers.conv.AOC import (
+    BcopRkoConvTranspose2d,
+    FastBlockConvTranspose2D,
+    RkoConvTranspose2d,
+)
+from orthogonium.reparametrizers import (
+    DEFAULT_TEST_ORTHO_PARAMS,
+    EXP_ORTHO_PARAMS,
+    CHOLESKY_ORTHO_PARAMS,
+    QR_ORTHO_PARAMS,
+    CHOLESKY_STABLE_ORTHO_PARAMS,
+)
 
 
 def _compute_sv_impulse_response_layer(layer, img_shape):
@@ -30,8 +35,8 @@ def _compute_sv_impulse_response_layer(layer, img_shape):
 
 
 @pytest.mark.parametrize("kernel_size", [1, 2, 3])
-@pytest.mark.parametrize("input_channels", [4, 8, 16, 32])
-@pytest.mark.parametrize("output_channels", [4, 8, 16, 32])
+@pytest.mark.parametrize("input_channels", [4, 8, 32])
+@pytest.mark.parametrize("output_channels", [4, 8, 32])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("groups", [1, 2])
 def test_convtranspose(kernel_size, input_channels, output_channels, stride, groups):
@@ -39,35 +44,119 @@ def test_convtranspose(kernel_size, input_channels, output_channels, stride, gro
     padding = (0, 0)
     padding_mode = "zeros"
     try:
-        if (
-            kernel_size > 1
-            and kernel_size != stride
-            and output_channels * (stride**2) < input_channels
-        ):
-            with pytest.warns(RuntimeWarning):
-                orthoconvtranspose = AdaptiveOrthoConvTranspose2d(
-                    kernel_size=kernel_size,
-                    in_channels=input_channels,
-                    out_channels=output_channels,
-                    stride=stride,
-                    groups=groups,
-                    bias=False,
-                    padding=padding,
-                    padding_mode=padding_mode,
-                    ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
-                )
+
+        orthoconvtranspose = AdaptiveOrthoConvTranspose2d(
+            kernel_size=kernel_size,
+            in_channels=input_channels,
+            out_channels=output_channels,
+            stride=stride,
+            groups=groups,
+            bias=False,
+            padding=padding,
+            padding_mode=padding_mode,
+            ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
+        )
+    except Exception as e:
+        if kernel_size < stride:
+            # we expect this configuration to raise a RuntimeError
+            # pytest.skip(f"BCOP instantiation failed with: {e}")
+            return
         else:
-            orthoconvtranspose = AdaptiveOrthoConvTranspose2d(
-                kernel_size=kernel_size,
-                in_channels=input_channels,
-                out_channels=output_channels,
-                stride=stride,
-                groups=groups,
-                bias=False,
-                padding=padding,
-                padding_mode=padding_mode,
-                ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
-            )
+            pytest.fail(f"BCOP instantiation failed with: {e}")
+    if (
+        kernel_size > 1
+        and kernel_size != stride
+        and output_channels * (stride**2) < input_channels
+    ):
+        pytest.skip("this case is not handled yet")
+    check_orthogonal_layer(
+        orthoconvtranspose,
+        groups,
+        input_channels,
+        kernel_size,
+        output_channels,
+        (
+            input_channels,
+            output_channels // groups,
+            kernel_size,
+            kernel_size,
+        ),
+    )
+
+
+def test_invalid_kernel_smaller_than_stride():
+    """
+    A test to ensure that kernel_size < stride raises an expected ValueError
+    """
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        AdaptiveOrthoConvTranspose2d(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+        )
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        BcopRkoConvTranspose2d(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+        )
+    with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"):
+        FastBlockConvTranspose2D(
+            in_channels=8,
+            out_channels=16,
+            kernel_size=2,
+            stride=3,  # Invalid: kernel_size < stride
+            groups=1,
+        )
+
+
+@pytest.mark.parametrize("kernel_size", [1, 3])
+@pytest.mark.parametrize("input_channels", [8, 16])
+@pytest.mark.parametrize("output_channels", [8, 16])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("groups", [1, 2])
+@pytest.mark.parametrize(
+    "ortho_params",
+    [
+        "default_bb",
+        "exp",
+        "qr",
+        "cholesky",
+        "cholesky_stable",
+    ],
+)
+def test_parametrizers_standard_configs(
+    kernel_size, input_channels, output_channels, stride, groups, ortho_params
+):
+    """
+    test combinations of kernel size, input channels, output channels, stride and groups
+    """
+    ortho_params_dict = {
+        "default_bb": DEFAULT_TEST_ORTHO_PARAMS,
+        "exp": EXP_ORTHO_PARAMS,
+        "qr": QR_ORTHO_PARAMS,
+        "cholesky": CHOLESKY_ORTHO_PARAMS,
+        "cholesky_stable": CHOLESKY_STABLE_ORTHO_PARAMS,
+    }  # trick to have the actual method name displayed properly if test fails
+    # Test instantiation
+    padding = (0, 0)
+    padding_mode = "zeros"
+    try:
+        orthoconvtranspose = AdaptiveOrthoConvTranspose2d(
+            kernel_size=kernel_size,
+            in_channels=input_channels,
+            out_channels=output_channels,
+            stride=stride,
+            groups=groups,
+            bias=False,
+            padding=padding,
+            padding_mode=padding_mode,
+            ortho_params=ortho_params_dict[ortho_params],
+        )
     except Exception as e:
         if kernel_size < stride:
             # we expect this configuration to raise a RuntimeError
@@ -93,4 +182,6 @@ def test_convtranspose(kernel_size, input_channels, output_channels, stride, gro
             kernel_size,
             kernel_size,
         ),
+        tol=5e-2 if ortho_params.startswith("cholesky") else 1e-3,
+        sigma_min_requirement=0.75 if ortho_params.startswith("cholesky") else 0.95,
     )
diff --git a/tests/test_rko.py b/tests/test_rko.py
index c38c9b2..494ce9f 100644
--- a/tests/test_rko.py
+++ b/tests/test_rko.py
@@ -3,7 +3,7 @@
 import torch
 
 from orthogonium.layers.conv.AOC.rko_conv import RKOConv2d
-from orthogonium.layers.linear.reparametrizers import DEFAULT_TEST_ORTHO_PARAMS
+from orthogonium.reparametrizers import DEFAULT_TEST_ORTHO_PARAMS
 
 # from orthogonium.layers.conv.fast_block_ortho_conv import FlashBCOP
 
@@ -57,11 +57,6 @@ def check_orthogonal_layer(
     # check that the layer is norm preserving
     inp_norm = torch.sqrt(torch.square(inp).sum(dim=(-3, -2, -1))).float().item()
     out_norm = torch.sqrt(torch.square(output).sum(dim=(-3, -2, -1))).float().item()
-    if check_orthogonality:
-        if inp_norm <= out_norm - 1e-3:
-            pytest.fail(
-                f"RKO is not norm preserving: {inp_norm} vs {out_norm} with rel error {abs(inp_norm - out_norm) / inp_norm}"
-            )
     # Test singular_values function
     sigma_min_ir, sigma_max_ir, stable_rank_ir = _compute_sv_impulse_response_layer(
         orthoconv, (input_channels, imsize, imsize)
@@ -102,9 +97,10 @@ def check_orthogonal_layer(
     assert (
         abs(sigma_min - sigma_min_ir) < tol
     ), f"sigma_min is not close to its IR value: {sigma_min} vs {sigma_min_ir}"
-    assert (
-        abs(stable_rank - stable_rank_ir) < tol
-    ), f"stable_rank is not close to its IR value: {stable_rank} vs {stable_rank_ir}"
+    if check_orthogonality:
+        assert (
+            abs(stable_rank - stable_rank_ir) < tol
+        ), f"stable_rank is not close to its IR value: {stable_rank} vs {stable_rank_ir}"
 
 
 @pytest.mark.parametrize("kernel_size", [1, 3, 5])
@@ -117,6 +113,7 @@ def test_standard_configs(kernel_size, input_channels, output_channels, stride,
     test combinations of kernel size, input channels, output channels, stride and groups
     """
     # Test instantiation
+    padding = (0, 0) if (kernel_size == stride) else ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
     try:
         orthoconv = RKOConv2d(
             kernel_size=kernel_size,
@@ -125,7 +122,7 @@ def test_standard_configs(kernel_size, input_channels, output_channels, stride,
             stride=stride,
             groups=groups,
             bias=False,
-            padding=(kernel_size // 2, kernel_size // 2),
+            padding=padding,
             padding_mode="circular",
             ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
         )
@@ -153,8 +150,8 @@ def test_standard_configs(kernel_size, input_channels, output_channels, stride,
 
 
 @pytest.mark.parametrize("kernel_size", [3, 4, 5])
-@pytest.mark.parametrize("input_channels", [2, 4, 8, 16, 32])
-@pytest.mark.parametrize("output_channels", [2, 4, 8, 16, 32])
+@pytest.mark.parametrize("input_channels", [2, 4, 16])
+@pytest.mark.parametrize("output_channels", [2, 4, 16])
 @pytest.mark.parametrize("stride", [2, 4])
 @pytest.mark.parametrize("groups", [1])
 def test_strided(kernel_size, input_channels, output_channels, stride, groups):
@@ -165,6 +162,7 @@ def test_strided(kernel_size, input_channels, output_channels, stride, groups):
     that you actually increase overall dimension.
     """
     # Test instantiation
+    padding = (0, 0) if (kernel_size == stride) else ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
     try:
         orthoconv = RKOConv2d(
             kernel_size=kernel_size,
@@ -173,7 +171,7 @@ def test_strided(kernel_size, input_channels, output_channels, stride, groups):
             stride=stride,
             groups=groups,
             bias=False,
-            padding=((kernel_size - 1) // 2, (kernel_size - 1) // 2),
+            padding=padding,
             padding_mode="circular",
             ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
         )
@@ -246,8 +244,8 @@ def test_even_kernels(kernel_size, input_channels, output_channels, stride, grou
 
 
 @pytest.mark.parametrize("kernel_size", [1, 2])
-@pytest.mark.parametrize("input_channels", [4, 8, 16, 32, 64])
-@pytest.mark.parametrize("output_channels", [4, 8, 16, 32, 64])
+@pytest.mark.parametrize("input_channels", [4, 8, 32])
+@pytest.mark.parametrize("output_channels", [4, 8, 32])
 @pytest.mark.parametrize("groups", [1, 2])
 def test_rko(kernel_size, input_channels, output_channels, groups):
     """
@@ -280,7 +278,7 @@ def test_rko(kernel_size, input_channels, output_channels, groups):
             kernel_size,
             kernel_size,
         ),
-        check_orthogonality=(kernel_size == kernel_size),
+        check_orthogonality=True,
     )
 
 
@@ -294,6 +292,7 @@ def test_depthwise(kernel_size, input_channels, output_channels, stride, groups)
     test combinations of kernel size, input channels, output channels, stride and groups
     """
     # Test instantiation
+    padding = (0, 0) if (kernel_size == stride) else ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
     try:
         orthoconv = RKOConv2d(
             kernel_size=kernel_size,
@@ -302,7 +301,7 @@ def test_depthwise(kernel_size, input_channels, output_channels, stride, groups)
             stride=stride,
             groups=groups,
             bias=False,
-            padding=(kernel_size // 2, kernel_size // 2),
+            padding=padding,
             padding_mode="circular",
             ortho_params=DEFAULT_TEST_ORTHO_PARAMS,
         )
diff --git a/tests/test_sll.py b/tests/test_sll.py
new file mode 100644
index 0000000..fe7bc37
--- /dev/null
+++ b/tests/test_sll.py
@@ -0,0 +1,99 @@
+import pytest
+import torch
+from orthogonium.layers.conv.SLL import (
+    SLLxAOCLipschitzResBlock,
+    SDPBasedLipschitzResBlock,
+    SDPBasedLipschitzDense,
+    AOCLipschitzResBlock,
+)
+
+
+@pytest.mark.parametrize(
+    "layer_class, init_params, batch_shape",
+    [
+        (
+            SDPBasedLipschitzResBlock,
+            {"cin": 4, "inner_dim_factor": 2, "kernel_size": 3},
+            (8, 4, 8, 8),
+        ),
+        (
+            SLLxAOCLipschitzResBlock,
+            {"cin": 4, "cout": 4, "inner_dim_factor": 2, "kernel_size": 3},
+            (8, 4, 8, 8),
+        ),
+        (
+            SDPBasedLipschitzDense,
+            {"in_features": 64, "out_features": 64, "inner_dim": 64},
+            (8, 64),
+        ),
+        (
+            AOCLipschitzResBlock,
+            {"in_channels": 4, "inner_dim_factor": 2, "kernel_size": 3},
+            (8, 4, 8, 8),
+        ),
+    ],
+)
+def test_lipschitz_layers(layer_class, init_params, batch_shape):
+    """
+    Generalized test for layers in the SLLx module to check Lipschitz constraints.
+    """
+    # Initialize layer
+    layer = layer_class(**init_params)
+
+    # Define input and target tensors
+    x = torch.randn(*batch_shape, requires_grad=True)  # Input
+
+    # Pre-optimization Lipschitz constant (if applicable)
+    if hasattr(layer, "compute_t"):
+        pre_lipschitz_constant = compute_lipschitz_constant(layer, x)
+        print(f"{layer_class.__name__} | Before: {pre_lipschitz_constant:.6f}")
+        assert (
+            pre_lipschitz_constant <= 1 + 1e-4
+        ), "Pre-optimization Lipschitz constant violation."
+
+    # Define optimizer and loss function
+    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
+
+    # Perform a few optimization steps
+    for _ in range(10):  # Run 10 optimization steps
+        optimizer.zero_grad()
+        output = layer(x)
+        loss = -torch.sum(torch.square(output))
+        loss.backward()
+        optimizer.step()
+
+    # Post-optimization Lipschitz constant (if applicable)
+    if hasattr(layer, "compute_t"):
+        post_lipschitz_constant = compute_lipschitz_constant(layer, x)
+        print(f"{layer_class.__name__} | After: {post_lipschitz_constant:.6f}")
+        assert (
+            post_lipschitz_constant <= 1 + 1e-4
+        ), "Post-optimization Lipschitz constant violation."
+
+
+def compute_lipschitz_constant(layer, x):
+    """
+    Calculate the Lipschitz constant for a given layer by computing the
+    maximum singular value of the Jacobian.
+    """
+    y = layer(x)
+
+    # Compute Jacobian by autograd
+    jacobian = []
+    for i in range(y.numel()):
+        grad_output = torch.zeros_like(y)
+        grad_output.view(-1)[i] = 1
+        gradients = torch.autograd.grad(
+            outputs=y,
+            inputs=x,
+            grad_outputs=grad_output,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        jacobian.append(gradients.view(-1).detach().cpu().numpy())
+    jacobian = torch.tensor(jacobian).view(y.numel(), x.numel())  # Construct Jacobian
+
+    # Compute singular values and return the maximum value
+    singular_values = torch.linalg.svdvals(jacobian)
+    return singular_values.max().item()