
Commit 2576c12

Refactor EvalMonitor to use centralized history management
- Updated `EvalMonitor` to use custom_op.
- Updated `EvalMonitor` to use device-specific history storage.
- Formatted code.
1 parent 9cf444d commit 2576c12

File tree: 1 file changed, +121 −21 lines

src/evox/workflows/eval_monitor.py

Lines changed: 121 additions & 21 deletions
@@ -1,11 +1,13 @@
 import warnings
-from typing import Dict, List, Tuple
+import weakref
+from enum import IntEnum
+from typing import Dict, List, NamedTuple, Tuple

 import torch
-from torch._C._functorch import get_unwrapped, is_batchedtensor

 from evox.core import Monitor, Mutable
 from evox.operators.selection import non_dominate_rank
+from evox.utils import register_vmap_op

 try:
     from evox.vis_tools import plot
@@ -24,6 +26,57 @@ def unique(x: torch.Tensor, dim=0):
     return unique, inverse, counts, index


+class HistoryType(IntEnum):
+    """History type for the monitor."""
+
+    FITNESS = 0
+    SOLUTION = 1
+    AUXILIARY = 2
+
+
+class MonitorHisotry(NamedTuple):
+    fit_history: List[torch.Tensor]
+    sol_history: List[torch.Tensor]
+    aux_history: List[torch.Tensor]
+
+
+__monitor_history__: Dict[int, MonitorHisotry] = {}
+
+
+def _fake_data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
+    return token.new_empty(token.size())
+
+
+def _fake_vmap_data_sink(
+    monitor_id: int,
+    data: torch.Tensor,
+    data_type: int,
+    token: torch.Tensor,
+) -> torch.Tensor:
+    return token.new_empty(token.size())
+
+
+def _vmap_data_sink(
+    monitor_id: int,
+    data: torch.Tensor,
+    data_type: int,
+    token: torch.Tensor,
+) -> torch.Tensor:
+    __monitor_history__[monitor_id][data_type].append(data)
+    return token + 1
+
+
+@register_vmap_op(fake_fn=_fake_data_sink, vmap_fn=_vmap_data_sink, fake_vmap_fn=_fake_vmap_data_sink)
+def _data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
+    """Record the data into the monitor history log.
+
+    This function uses the provided token to establish data dependencies between
+    successive function calls, ensuring proper tracking and ordering of monitored values.
+    """
+    __monitor_history__[monitor_id][data_type].append(data)
+    return token + 1
+
+
 class EvalMonitor(Monitor):
     """Evaluation monitor.
     Used for both single-objective and multi-objective workflow.
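The docstring of `_data_sink` is terse, so here is a minimal sketch of the token-threading idea in plain Python. The helper and global below are illustrative stand-ins, not evox's actual objects (the real op is registered through `register_vmap_op`): each call appends to a log as a side effect and returns `token + 1`, so the next call is data-dependent on the previous one and a tracing compiler cannot drop or reorder the recording.

```python
import torch
from typing import Dict, List

# Illustrative stand-in for __monitor_history__.
_history: Dict[int, List[torch.Tensor]] = {}

def data_sink(monitor_id: int, data: torch.Tensor, token: torch.Tensor) -> torch.Tensor:
    # The append is a side effect invisible to a tracer; returning token + 1
    # turns call ordering into an explicit data dependency.
    _history.setdefault(monitor_id, []).append(data)
    return token + 1

token = torch.tensor(0)
for step in range(3):
    fitness = torch.rand(4)
    token = data_sink(0, fitness, token)  # each call consumes the previous token

assert len(_history[0]) == 3  # recorded in call order
```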
@@ -32,10 +85,6 @@ class EvalMonitor(Monitor):
     Moreover, it can also record the best solution or the pareto front on-the-fly.
     """

-    fitness_history: List[torch.Tensor]
-    solution_history: List[torch.Tensor]
-    auxiliary: List[Dict[str, torch.Tensor]]
-
     def __init__(
         self,
         multi_obj: bool = False,
@@ -44,6 +93,7 @@ def __init__(
         full_pop_history: bool = False,
         topk: int = 1,
         device: torch.device | None = None,
+        history_device: torch.device | None = None,
     ):
         """Initialize the monitor.

@@ -57,21 +107,67 @@ def __init__(
         """
         super().__init__()
         device = torch.get_default_device() if device is None else device
+        history_device = torch.device("cpu") if history_device is None else history_device
         self.multi_obj = multi_obj
         self.full_fit_history = full_fit_history
         self.full_sol_history = full_sol_history
         self.full_pop_history = full_pop_history
         self.opt_direction = 1
         self.topk = topk
         self.device = device
+        self.history_device = history_device
+        self.aux_keys = []
         # mutable
         self.latest_solution = Mutable(torch.empty(0, device=device))
         self.latest_fitness = Mutable(torch.empty(0, device=device))
         self.topk_solutions = Mutable(torch.empty(0, device=device))
         self.topk_fitness = Mutable(torch.empty(0, device=device))
-        self.fitness_history = []
-        self.solution_history = []
-        self.auxiliary = []
+        self._id_ = id(self)
+        self._token = Mutable(torch.tensor(0, device=device))
+        __monitor_history__[self._id_] = MonitorHisotry([], [], [])
+        weakref.finalize(
+            self,
+            __monitor_history__.pop,
+            self._id_,
+            None,
+        )
+
+    @property
+    def fitness_history(self) -> List[torch.Tensor]:
+        return __monitor_history__[self._id_][HistoryType.FITNESS]
+
+    @property
+    def fit_history(self) -> List[torch.Tensor]:
+        # alias for fitness_history
+        return self.fitness_history
+
+    @property
+    def solution_history(self) -> List[torch.Tensor]:
+        return __monitor_history__[self._id_][HistoryType.SOLUTION]
+
+    @property
+    def sol_history(self) -> List[torch.Tensor]:
+        # alias for solution_history
+        return self.solution_history
+
+    @property
+    def aux_history(self) -> Dict[str, List[torch.Tensor]]:
+        # alias for auxiliary_history
+        return self.auxiliary_history
+
+    @property
+    def auxiliary_history(self) -> Dict[str, List[torch.Tensor]]:
+        raw_aux_history = __monitor_history__[self._id_][HistoryType.AUXILIARY]
+        n_keys = len(self.aux_keys)
+        if n_keys == 0:
+            return {}
+
+        assert len(raw_aux_history) % n_keys == 0
+        aux_history = {}
+        for i, key in enumerate(self.aux_keys):
+            aux_history[key] = raw_aux_history[i::n_keys]
+
+        return aux_history

     def set_config(self, **config):
         if "multi_obj" in config:
@@ -88,7 +184,14 @@ def set_config(self, **config):

     def record_auxiliary(self, aux: Dict[str, torch.Tensor]):
         if self.full_pop_history:
-            self.auxiliary.append(aux)
+            if len(self.aux_keys) == 0:
+                self.aux_keys = list(aux.keys())
+
+            for key in self.aux_keys:
+                assert isinstance(aux[key], torch.Tensor)
+                self._token = _data_sink(
+                    self._id_, aux[key].to(self.history_device, non_blocking=True), HistoryType.AUXILIARY, self._token
+                )

     def post_ask(self, candidate_solution: torch.Tensor):
         self.latest_solution = candidate_solution
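Because `record_auxiliary` appends one tensor per key in a fixed key order each generation, the flat auxiliary log interleaves the keys, and the `auxiliary_history` property recovers each series with a stride. A small worked example of that slicing (the values are made up):

```python
import torch

aux_keys = ["fit", "pop"]  # captured from the first record_auxiliary call
# Three generations, appended key-by-key: [fit_0, pop_0, fit_1, pop_1, fit_2, pop_2]
raw = [torch.tensor(float(g * 10 + i)) for g in range(3) for i in range(2)]

# Same stride trick as auxiliary_history: raw[i::n_keys]
n_keys = len(aux_keys)
aux_history = {key: raw[i::n_keys] for i, key in enumerate(aux_keys)}

assert [t.item() for t in aux_history["fit"]] == [0.0, 10.0, 20.0]
assert [t.item() for t in aux_history["pop"]] == [1.0, 11.0, 21.0]
```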
@@ -123,18 +226,15 @@ def pre_tell(self, fitness: torch.Tensor):
         if self.full_fit_history or self.full_sol_history:
             self.record_history()

-    @torch.compiler.disable
     def record_history(self):
         if self.full_sol_history:
-            latest_solution = self.latest_solution.to(self.device)
-            if is_batchedtensor(self.latest_solution):
-                latest_solution = get_unwrapped(latest_solution)
-            self.solution_history.append(latest_solution)
+            latest_solution = self.latest_solution.to(self.history_device, non_blocking=True)
+            assert isinstance(latest_solution, torch.Tensor)
+            self._token = _data_sink(self._id_, latest_solution, HistoryType.SOLUTION, self._token)
         if self.full_fit_history:
-            latest_fitness = self.latest_fitness.to(self.device)
-            if is_batchedtensor(self.latest_fitness):
-                latest_fitness = get_unwrapped(latest_fitness)
-            self.fitness_history.append(latest_fitness)
+            latest_fitness = self.latest_fitness.to(self.history_device, non_blocking=True)
+            assert isinstance(latest_fitness, torch.Tensor)
+            self._token = _data_sink(self._id_, latest_fitness, HistoryType.FITNESS, self._token)

     def get_latest_fitness(self) -> torch.Tensor:
         """Get the fitness values from the latest iteration."""
@@ -232,7 +332,7 @@ def plot(self, problem_pf=None, source="eval", **kwargs):
         When "pop", the fitness from the population inside the algorithm will be plotted, representing what the algorithm sees.
         :param kwargs: Additional arguments for the plot.
         """
-        if not self.fitness_history and not self.auxiliary:
+        if not self.fitness_history and not self.aux_history:
             warnings.warn("No fitness history recorded, return None")
             return

@@ -241,7 +341,7 @@ def plot(self, problem_pf=None, source="eval", **kwargs):
             return

         if source == "pop":
-            fitness_history = [aux["fit"] for aux in self.auxiliary]
+            fitness_history = self.aux_history["fit"]
         elif source == "eval":
             fitness_history = self.get_fitness_history()
         else:
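Continuing the earlier `monitor` sketch, the two plot sources map onto the two histories touched by this hunk: "eval" reads the recorded fitness history, while "pop" reads the "fit" entries from the auxiliary log.

```python
# "eval": fitness as seen by the evaluator (get_fitness_history()).
monitor.plot(source="eval")
# "pop": fitness logged by the algorithm via record_auxiliary.
monitor.plot(source="pop")
```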
