 import warnings
-from typing import Dict, List, Tuple
+import weakref
+from enum import IntEnum
+from typing import Dict, List, NamedTuple, Tuple
 
 import torch
-from torch._C._functorch import get_unwrapped, is_batchedtensor
 
 from evox.core import Monitor, Mutable
 from evox.operators.selection import non_dominate_rank
+from evox.utils import register_vmap_op
 
 try:
     from evox.vis_tools import plot
@@ -24,6 +26,57 @@ def unique(x: torch.Tensor, dim=0):
     return unique, inverse, counts, index
 
 
+class HistoryType(IntEnum):
+    """History type for the monitor."""
+
+    FITNESS = 0
+    SOLUTION = 1
+    AUXILIARY = 2
+
+
+class MonitorHistory(NamedTuple):
+    fit_history: List[torch.Tensor]
+    sol_history: List[torch.Tensor]
+    aux_history: List[torch.Tensor]
+
+
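+# Module-level registry mapping id(monitor) -> its history lists. Keeping it outside the
+# monitor instance lets the _data_sink op below append to it as a side effect.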
+__monitor_history__: Dict[int, MonitorHistory] = {}
+
+
+def _fake_data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
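+    # Fake implementation: only describes the output shape for tracing; nothing is recorded.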
+    return token.new_empty(token.size())
+
+
+def _fake_vmap_data_sink(
+    monitor_id: int,
+    data: torch.Tensor,
+    data_type: int,
+    token: torch.Tensor,
+) -> torch.Tensor:
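+    # Fake vmap variant: shape-only, mirroring _fake_data_sink; nothing is recorded.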
+    return token.new_empty(token.size())
+
+
+def _vmap_data_sink(
+    monitor_id: int,
+    data: torch.Tensor,
+    data_type: int,
+    token: torch.Tensor,
+) -> torch.Tensor:
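+    # vmap variant: the batched tensor is stored as-is and the token advanced to preserve ordering.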
+    __monitor_history__[monitor_id][data_type].append(data)
+    return token + 1
+
+
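+# _data_sink is registered as a custom op with fake and vmap variants, which is why record_history
+# below no longer needs @torch.compiler.disable or the manual batched-tensor unwrapping.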
+@register_vmap_op(fake_fn=_fake_data_sink, vmap_fn=_vmap_data_sink, fake_vmap_fn=_fake_vmap_data_sink)
+def _data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
+    """Record the data into the monitor history log.
+
+    This function uses the provided token to establish data dependencies between
+    successive function calls, ensuring proper tracking and ordering of monitored values.
+    """
+    __monitor_history__[monitor_id][data_type].append(data)
+    return token + 1
+
+
 class EvalMonitor(Monitor):
     """Evaluation monitor.
     Used for both single-objective and multi-objective workflows.
@@ -32,10 +85,6 @@ class EvalMonitor(Monitor):
     Moreover, it can also record the best solution or the Pareto front on-the-fly.
     """
 
-    fitness_history: List[torch.Tensor]
-    solution_history: List[torch.Tensor]
-    auxiliary: List[Dict[str, torch.Tensor]]
-
     def __init__(
         self,
         multi_obj: bool = False,
@@ -44,6 +93,7 @@ def __init__(
         full_pop_history: bool = False,
         topk: int = 1,
         device: torch.device | None = None,
+        history_device: torch.device | None = None,
     ):
         """Initialize the monitor.
 
@@ -57,21 +107,67 @@ def __init__(
         """
         super().__init__()
         device = torch.get_default_device() if device is None else device
+        history_device = torch.device("cpu") if history_device is None else history_device
         self.multi_obj = multi_obj
         self.full_fit_history = full_fit_history
         self.full_sol_history = full_sol_history
         self.full_pop_history = full_pop_history
         self.opt_direction = 1
         self.topk = topk
         self.device = device
+        self.history_device = history_device
+        self.aux_keys = []
         # mutable
         self.latest_solution = Mutable(torch.empty(0, device=device))
         self.latest_fitness = Mutable(torch.empty(0, device=device))
         self.topk_solutions = Mutable(torch.empty(0, device=device))
         self.topk_fitness = Mutable(torch.empty(0, device=device))
-        self.fitness_history = []
-        self.solution_history = []
-        self.auxiliary = []
+        self._id_ = id(self)
+        self._token = Mutable(torch.tensor(0, device=device))
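+        # _token is threaded through every _data_sink call so the recording side effects stay ordered.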
+        __monitor_history__[self._id_] = MonitorHistory([], [], [])
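+        # Drop this monitor's history from the module-level registry once the monitor is garbage collected.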
+        weakref.finalize(
+            self,
+            __monitor_history__.pop,
+            self._id_,
+            None,
+        )
+
+    @property
+    def fitness_history(self) -> List[torch.Tensor]:
+        return __monitor_history__[self._id_][HistoryType.FITNESS]
+
+    @property
+    def fit_history(self) -> List[torch.Tensor]:
+        # alias for fitness_history
+        return self.fitness_history
+
+    @property
+    def solution_history(self) -> List[torch.Tensor]:
+        return __monitor_history__[self._id_][HistoryType.SOLUTION]
+
+    @property
+    def sol_history(self) -> List[torch.Tensor]:
+        # alias for solution_history
+        return self.solution_history
+
+    @property
+    def aux_history(self) -> Dict[str, List[torch.Tensor]]:
+        # alias for auxiliary_history
+        return self.auxiliary_history
+
+    @property
+    def auxiliary_history(self) -> Dict[str, List[torch.Tensor]]:
+        raw_aux_history = __monitor_history__[self._id_][HistoryType.AUXILIARY]
+        n_keys = len(self.aux_keys)
+        if n_keys == 0:
+            return {}
+
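+        # Aux tensors are appended key by key in self.aux_keys order, so every n_keys-th
+        # entry starting at offset i belongs to the i-th key.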
+        assert len(raw_aux_history) % n_keys == 0
+        aux_history = {}
+        for i, key in enumerate(self.aux_keys):
+            aux_history[key] = raw_aux_history[i::n_keys]
+
+        return aux_history
 
     def set_config(self, **config):
         if "multi_obj" in config:
@@ -88,7 +184,14 @@ def set_config(self, **config):
 
     def record_auxiliary(self, aux: Dict[str, torch.Tensor]):
         if self.full_pop_history:
-            self.auxiliary.append(aux)
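+            # Fix the auxiliary key order on the first call; later records are read in this order.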
+            if len(self.aux_keys) == 0:
+                self.aux_keys = list(aux.keys())
+
+            for key in self.aux_keys:
+                assert isinstance(aux[key], torch.Tensor)
+                self._token = _data_sink(
+                    self._id_, aux[key].to(self.history_device, non_blocking=True), HistoryType.AUXILIARY, self._token
+                )
 
     def post_ask(self, candidate_solution: torch.Tensor):
         self.latest_solution = candidate_solution
@@ -123,18 +226,15 @@ def pre_tell(self, fitness: torch.Tensor):
         if self.full_fit_history or self.full_sol_history:
             self.record_history()
 
-    @torch.compiler.disable
     def record_history(self):
         if self.full_sol_history:
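+            # Copy to history_device (CPU by default) and append through the _data_sink op;
+            # the returned token keeps successive records in order.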
-            latest_solution = self.latest_solution.to(self.device)
-            if is_batchedtensor(self.latest_solution):
-                latest_solution = get_unwrapped(latest_solution)
-            self.solution_history.append(latest_solution)
+            latest_solution = self.latest_solution.to(self.history_device, non_blocking=True)
+            assert isinstance(latest_solution, torch.Tensor)
+            self._token = _data_sink(self._id_, latest_solution, HistoryType.SOLUTION, self._token)
         if self.full_fit_history:
-            latest_fitness = self.latest_fitness.to(self.device)
-            if is_batchedtensor(self.latest_fitness):
-                latest_fitness = get_unwrapped(latest_fitness)
-            self.fitness_history.append(latest_fitness)
+            latest_fitness = self.latest_fitness.to(self.history_device, non_blocking=True)
+            assert isinstance(latest_fitness, torch.Tensor)
+            self._token = _data_sink(self._id_, latest_fitness, HistoryType.FITNESS, self._token)
 
     def get_latest_fitness(self) -> torch.Tensor:
         """Get the fitness values from the latest iteration."""
@@ -232,7 +332,7 @@ def plot(self, problem_pf=None, source="eval", **kwargs):
         When "pop", the fitness from the population inside the algorithm will be plotted, representing what the algorithm sees.
         :param kwargs: Additional arguments for the plot.
         """
-        if not self.fitness_history and not self.auxiliary:
+        if not self.fitness_history and not self.aux_history:
             warnings.warn("No fitness history recorded, return None")
             return
 
@@ -241,7 +341,7 @@ def plot(self, problem_pf=None, source="eval", **kwargs):
             return
 
         if source == "pop":
-            fitness_history = [aux["fit"] for aux in self.auxiliary]
+            fitness_history = self.aux_history["fit"]
         elif source == "eval":
             fitness_history = self.get_fitness_history()
         else: