@@ -5,8 +5,9 @@
 import argparse
 import importlib.util
 import os
+import subprocess
 import sys
-from time import sleep
+from time import sleep, time

 import ray
 import util
@@ -57,6 +58,9 @@
 PYTHON_EXEC = "./isaaclab.sh -p"
 WORKFLOW = "scripts/reinforcement_learning/rl_games/train.py"
 NUM_WORKERS_PER_NODE = 1  # needed for local parallelism
+PROCESS_RESPONSE_TIMEOUT = 200.0  # seconds to wait before killing the process when it stops responding
+MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS = 1000  # maximum number of lines to read from the training process logs
+MAX_LOG_EXTRACTION_ERRORS = 2  # maximum allowed LogExtractionErrors before we abort the whole training


 class IsaacLabTuneTrainable(tune.Trainable):
@@ -70,6 +74,7 @@ class IsaacLabTuneTrainable(tune.Trainable):
     def setup(self, config: dict) -> None:
         """Get the invocation command, return quick for easy scheduling."""
         self.data = None
+        self.time_since_last_proc_response = 0.0
         self.invoke_cmd = util.get_invocation_command_from_cfg(cfg=config, python_cmd=PYTHON_EXEC, workflow=WORKFLOW)
         print(f"[INFO]: Recovered invocation with {self.invoke_cmd}")
         self.experiment = None
@@ -84,12 +89,21 @@ def step(self) -> dict:
         # When including this as first step instead of setup, experiments get scheduled faster
         # Don't want to block the scheduler while the experiment spins up
         print(f"[INFO]: Invoking experiment as first step with {self.invoke_cmd}...")
-        experiment = util.execute_job(
-            self.invoke_cmd,
-            identifier_string="",
-            extract_experiment=True,
-            persistent_dir=BASE_DIR,
-        )
+        try:
+            experiment = util.execute_job(
+                self.invoke_cmd,
+                identifier_string="",
+                extract_experiment=True,  # Keep this as True to return a valid dictionary
+                persistent_dir=BASE_DIR,
+                max_lines_to_search_logs=MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS,
+                max_time_to_search_logs=PROCESS_RESPONSE_TIMEOUT,
+            )
+        except util.LogExtractionError:
+            self.data = {
+                "LOG_EXTRACTION_ERROR_STOPPER_FLAG": True,
+                "done": True,
+            }
+            return self.data
         self.experiment = experiment
         print(f"[INFO]: Tuner recovered experiment info {experiment}")
         self.proc = experiment["proc"]
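
Note on the new failure path: util.LogExtractionError is raised by util.execute_job when it cannot recover the experiment information from the training logs within the configured line/time budget, and it is defined in the companion util.py rather than in this file. A minimal sketch of what this file assumes about it (the real definition may carry more detail) is simply a custom exception type:

# Hypothetical sketch of the exception assumed to live in util.py; not part of this diff.
class LogExtractionError(Exception):
    """Raised when experiment info cannot be extracted from the training workflow logs."""

Returning a result dictionary with done=True plus the LOG_EXTRACTION_ERROR_STOPPER_FLAG sentinel lets the trial end cleanly while still signaling the failure to the LogExtractionErrorStopper defined further down.
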
@@ -109,11 +123,35 @@ def step(self) -> dict:

         while data is None:
             data = util.load_tensorboard_logs(self.tensorboard_logdir)
+            proc_status = self.proc.poll()
+            if proc_status is not None:
+                break
             sleep(2)  # Lazy report metrics to avoid performance overhead

         if self.data is not None:
-            while util._dicts_equal(data, self.data):
+            data_ = {k: v for k, v in data.items() if k != "done"}
+            self_data_ = {k: v for k, v in self.data.items() if k != "done"}
+            unresponsiveness_start_time = time()
+            while util._dicts_equal(data_, self_data_):
+                self.time_since_last_proc_response = time() - unresponsiveness_start_time
                 data = util.load_tensorboard_logs(self.tensorboard_logdir)
+                data_ = {k: v for k, v in data.items() if k != "done"}
+                proc_status = self.proc.poll()
+                if proc_status is not None:
+                    break
+                if self.time_since_last_proc_response > PROCESS_RESPONSE_TIMEOUT:
+                    self.time_since_last_proc_response = 0.0
+                    print("[WARNING]: Training workflow process is not responding, terminating...")
+                    self.proc.terminate()
+                    try:
+                        self.proc.wait(timeout=20)
+                    except subprocess.TimeoutExpired:
+                        print("[ERROR]: The process did not terminate within timeout duration.")
+                        self.proc.kill()
+                        self.proc.wait()
+                    self.data = data
+                    self.data["done"] = True
+                    return self.data
                 sleep(2)  # Lazy report metrics to avoid performance overhead

         self.data = data
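
The unresponsiveness handling above uses the standard terminate-then-kill escalation on the subprocess handle of the training workflow. As a standalone sketch of that pattern (helper name and grace period are illustrative, not part of the diff):

import subprocess


def stop_unresponsive(proc: subprocess.Popen, grace_period: float = 20.0) -> int:
    """Ask the child process to exit, escalating to a hard kill if it does not comply."""
    proc.terminate()  # polite shutdown request (SIGTERM on POSIX)
    try:
        proc.wait(timeout=grace_period)  # allow some time for cleanup
    except subprocess.TimeoutExpired:
        proc.kill()  # forceful termination (SIGKILL on POSIX)
        proc.wait()  # reap the process so it does not linger as a zombie
    return proc.returncode

Reporting the most recent TensorBoard scalars with done=True afterwards lets Ray Tune close out the trial instead of blocking indefinitely on a dead worker.
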
@@ -132,6 +170,39 @@ def default_resource_request(self):
         )


+class LogExtractionErrorStopper(tune.Stopper):
+    """Stopper that stops all trials if multiple LogExtractionErrors occur.
+
+    Args:
+        max_errors: The maximum number of LogExtractionErrors allowed before terminating the experiment.
+    """
+
+    def __init__(self, max_errors: int):
+        self.max_errors = max_errors
+        self.error_count = 0
+
+    def __call__(self, trial_id, result):
+        """Increments the error count if the trial has encountered a LogExtractionError.
+
+        It does not stop the trial based on the metrics, always returning False.
+        """
+        if result.get("LOG_EXTRACTION_ERROR_STOPPER_FLAG", False):
+            self.error_count += 1
+            print(
+                f"[ERROR]: Encountered LogExtractionError {self.error_count} times. "
+                f"Maximum allowed is {self.max_errors}."
+            )
+        return False
+
+    def stop_all(self):
+        """Returns true if the number of LogExtractionErrors exceeds the maximum allowed, terminating the experiment."""
+        if self.error_count > self.max_errors:
+            print("[FATAL]: Encountered LogExtractionError more than allowed, aborting entire tuning run...")
+            return True
+        else:
+            return False
+
+
 def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
     """Invoke an Isaac-Ray tuning run.

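
To make the stopper's contract concrete, here is a small usage sketch (trial ids, metric names, and values are made up). Ray Tune calls the stopper once per reported result and consults stop_all() to decide whether to end the whole run:

stopper = LogExtractionErrorStopper(max_errors=2)

stopper("trial_0", {"rewards/time": 1.0})  # ordinary result: returns False, error count stays 0
stopper("trial_1", {"LOG_EXTRACTION_ERROR_STOPPER_FLAG": True, "done": True})  # error count -> 1
stopper("trial_2", {"LOG_EXTRACTION_ERROR_STOPPER_FLAG": True, "done": True})  # error count -> 2
print(stopper.stop_all())  # False: two errors is still within max_errors

stopper("trial_3", {"LOG_EXTRACTION_ERROR_STOPPER_FLAG": True, "done": True})  # error count -> 3
print(stopper.stop_all())  # True: the entire tuning run is aborted
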
@@ -175,6 +246,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
                 checkpoint_frequency=0,  # Disable periodic checkpointing
                 checkpoint_at_end=False,  # Disable final checkpoint
             ),
+            stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
         )

     elif args.run_mode == "remote":  # MLFlow, to MLFlow server
@@ -190,6 +262,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
             storage_path="/tmp/ray",
             callbacks=[mlflow_callback],
             checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False),
+            stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
         )
     else:
         raise ValueError("Unrecognized run mode.")
@@ -199,6 +272,8 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
         IsaacLabTuneTrainable,
         param_space=cfg,
         tune_config=tune.TuneConfig(
+            metric=args.metric,
+            mode=args.mode,
             search_alg=repeat_search,
             num_samples=args.num_samples,
             reuse_actors=True,
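
Passing metric and mode into tune.TuneConfig tells the search algorithm which reported scalar to optimize and in which direction; here the values come from CLI arguments defined elsewhere in this script. A minimal, self-contained illustration (the metric name is hypothetical, the real key depends on what the workflow logs to TensorBoard):

from ray import tune

tune_config = tune.TuneConfig(
    metric="rewards/time",  # hypothetical scalar key reported by the trainable
    mode="max",             # "max" to maximize the metric, "min" to minimize it
    num_samples=8,
)
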
@@ -306,8 +381,39 @@ def __init__(self, cfg: dict):
         default=3,
         help="How many times to repeat each hyperparameter config.",
     )
+    parser.add_argument(
+        "--process_response_timeout",
+        type=float,
+        default=PROCESS_RESPONSE_TIMEOUT,
+        help="Training workflow process response timeout.",
+    )
+    parser.add_argument(
+        "--max_lines_to_search_experiment_logs",
+        type=float,
+        default=MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS,
+        help="Max number of lines to search for experiment logs before terminating the training workflow process.",
+    )
+    parser.add_argument(
+        "--max_log_extraction_errors",
+        type=float,
+        default=MAX_LOG_EXTRACTION_ERRORS,
+        help="Max number of LogExtractionError failures before we abort the whole tuning run.",
+    )

     args = parser.parse_args()
+    PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout
+    MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS = int(args.max_lines_to_search_experiment_logs)
+    print(
+        "[INFO]: The max number of lines to search for experiment logs before (early) terminating the training "
+        f"workflow process is set to {MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS}.\n"
+        "[INFO]: The process response timeout, used while updating tensorboard scalars and searching for "
+        f"experiment logs, is set to {PROCESS_RESPONSE_TIMEOUT} seconds."
+    )
+    MAX_LOG_EXTRACTION_ERRORS = int(args.max_log_extraction_errors)
+    print(
+        "[INFO]: Max number of LogExtractionError failures before we abort the whole tuning run is "
+        f"set to {MAX_LOG_EXTRACTION_ERRORS}.\n"
+    )
     NUM_WORKERS_PER_NODE = args.num_workers_per_node
     print(f"[INFO]: Using {NUM_WORKERS_PER_NODE} workers per node.")
     if args.run_mode == "remote":
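
The assignments after parse_args rebind the module-level defaults so that code reading PROCESS_RESPONSE_TIMEOUT, MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS, and MAX_LOG_EXTRACTION_ERRORS later in this process picks up the CLI overrides, for example a launch that passes --process_response_timeout 300 --max_lines_to_search_experiment_logs 2000 --max_log_extraction_errors 3. A minimal sketch of that rebinding pattern in plain Python (names here are illustrative, not taken from the diff):

TIMEOUT = 200.0  # module-level default


def read_timeout() -> float:
    return TIMEOUT  # module globals are looked up at call time, so later rebinding is visible


if __name__ == "__main__":
    TIMEOUT = 300.0  # e.g. parsed from argparse
    assert read_timeout() == 300.0
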