Skip to content

Commit 18408fe

Browse files
committed
Manually split into separate tests
Fix typo
1 parent 9c3da77 commit 18408fe

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

unit_test/algorithms/test_moea.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,37 @@ def setUp(self):
4949
torch.compiler.reset()
5050
torch.manual_seed(42)
5151
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
52-
pop_size = 20
53-
dim = 10
54-
lb = -torch.ones(dim)
55-
ub = torch.ones(dim)
56-
self.algo = [
57-
NSGA2(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
58-
NSGA3(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
59-
RVEA(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
60-
MOEAD(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
61-
HypE(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
62-
RVEAa(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
63-
]
64-
65-
def test_moea_variants(self):
66-
for algo in self.algo:
67-
self.run_algorithm(algo)
68-
self.run_compiled_algorithm(algo)
69-
# self.run_vmap_algorithm(algo)
52+
self.pop_size = 20
53+
self.dim = 10
54+
self.lb = -torch.ones(self.dim)
55+
self.ub = torch.ones(self.dim)
56+
57+
def test_nsga2(self):
58+
algo = NSGA2(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
59+
self.run_algorithm(algo)
60+
self.run_compiled_algorithm(algo)
61+
62+
def test_nsga3(self):
63+
algo = NSGA3(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
64+
self.run_algorithm(algo)
65+
self.run_compiled_algorithm(algo)
66+
67+
def test_rvea(self):
68+
algo = RVEA(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
69+
self.run_algorithm(algo)
70+
self.run_compiled_algorithm(algo)
71+
72+
def test_moead(self):
73+
algo = MOEAD(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
74+
self.run_algorithm(algo)
75+
self.run_compiled_algorithm(algo)
76+
77+
def test_hype(self):
78+
algo = HypE(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
79+
self.run_algorithm(algo)
80+
self.run_compiled_algorithm(algo)
81+
82+
def test_rveaa(self):
83+
algo = RVEAa(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
84+
self.run_algorithm(algo)
85+
self.run_compiled_algorithm(algo)

0 commit comments

Comments
 (0)