Skip to content

Commit 45720a6

Browse files
committed
Put __all__ in complex modules to hide internal functions
1 parent 6427306 commit 45720a6

File tree

11 files changed

+93
-7
lines changed

11 files changed

+93
-7
lines changed

docs/source/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@
4848
]
4949

5050
autodoc2_module_all_regexes = [
51-
r"evox\..*",
51+
r"evox\.core\..*",
52+
r"evox\.problems\..*",
53+
r"evox\.workflows\..*",
5254
]
5355

5456
autodoc2_render_plugin = "myst"

src/evox/core/components.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
__all__ = [
2+
"Algorithm",
3+
"Problem",
4+
"Workflow",
5+
"Monitor",
6+
]
7+
8+
19
from abc import ABC
210
from typing import Any, Dict
311

src/evox/core/module.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
__all__ = [
2+
"Parameter",
3+
"Mutable",
4+
"ModuleBase",
5+
"TransformGetSetItemToIndex",
6+
"compile",
7+
"vmap",
8+
"use_state",
9+
]
10+
11+
112
from functools import wraps
213
from typing import Callable, Dict, Optional, TypeVar, Union
314

@@ -34,9 +45,7 @@ def Parameter(
3445
)
3546

3647

37-
def Mutable(
38-
value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None
39-
) -> torch.Tensor:
48+
def Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
4049
"""Wraps a value as a mutable tensor.
4150
This is often used to label a value in an algorithm as a mutable tensor that may changes during iteration(s).
4251
@@ -172,7 +181,7 @@ def use_state(stateful_func: Union[Callable, nn.Module]) -> Callable:
172181
def wrapper(params_and_buffers: Dict[str, torch.Tensor], *args, **kwargs):
173182
params_and_buffers = {("_inner_module." + k): v for k, v in params_and_buffers.items()}
174183
output = torch.func.functional_call(module, params_and_buffers, *args, **kwargs)
175-
params_and_buffers = {k[len("_inner_module."):]: v for k, v in params_and_buffers.items()}
184+
params_and_buffers = {k[len("_inner_module.") :]: v for k, v in params_and_buffers.items()}
176185
if output is None:
177186
return params_and_buffers
178187
else:

src/evox/problems/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["neuroevolution", "numerical"]
1+
__all__ = ["neuroevolution", "numerical", "hpo_wrapper"]
22

33

4-
from . import neuroevolution, numerical
4+
from . import neuroevolution, numerical, hpo_wrapper

src/evox/problems/hpo_wrapper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
__all__ = [
2+
"HPOMonitor",
3+
"HPOFitnessMonitor",
4+
"HPOProblemWrapper",
5+
"HPOData",
6+
]
7+
8+
19
import weakref
210
from abc import ABC
311
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
__all__ = [
2+
"brax",
3+
"mujoco_playground",
4+
"supervised_learning",
5+
"BraxProblem",
6+
"MujocoPlaygroundProblem",
7+
"SupervisedLearningProblem",
8+
]
9+
10+
from . import brax, mujoco_playground, supervised_learning
11+
from .brax import BraxProblem
12+
from .mujoco_playground import MujocoPlaygroundProblem
13+
from .supervised_learning import SupervisedLearningProblem

src/evox/problems/neuroevolution/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
__all__ = [
2+
"ModelStateForwardResult",
3+
"get_vmap_model_state_forward",
4+
]
5+
6+
17
import copy
28
from typing import Callable, Dict, NamedTuple, Tuple
39

src/evox/problems/numerical/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
__all__ = [
2+
"basic",
3+
"cec2022",
4+
"dtlz",
25
"Ackley",
36
"Griewank",
47
"Rastrigin",
@@ -23,6 +26,7 @@
2326
"ellipsoid_func",
2427
]
2528

29+
from . import basic, cec2022, dtlz
2630
from .basic import (
2731
Ackley,
2832
Ellipsoid,

src/evox/problems/numerical/basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
__all__ = [
2+
"ShiftAffineNumericalProblem",
3+
"Ackley",
4+
"Griewank",
5+
"Rastrigin",
6+
"Rosenbrock",
7+
"Schwefel",
8+
"Sphere",
9+
"Ellipsoid",
10+
"ackley_func",
11+
"griewank_func",
12+
"rastrigin_func",
13+
"rosenbrock_func",
14+
"schwefel_func",
15+
"sphere_func",
16+
"ellipsoid_func",
17+
]
18+
19+
120
import torch
221

322
from evox.core import Problem

src/evox/problems/numerical/cec2022.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
__all__ = [
2+
"CEC2022",
3+
]
4+
5+
16
import os
27
from math import ceil
38
from typing import List, Optional

0 commit comments

Comments
 (0)