This repository was archived by the owner on Dec 11, 2022. It is now read-only.

Commit a4662d1

Enable multi-process training with distributed Coach.
1 parent: c02333b

File tree: 9 files changed, +167 -108 lines


rl_coach/architectures/tensorflow_components/distributed_tf_utils.py

+2 -2

@@ -74,8 +74,8 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
     return server.target, device


-def create_monitored_session(target: tf.train.Server, task_index: int,
-                             checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
+def create_monitored_session(target: tf.train.Server, task_index: int, checkpoint_dir: str, checkpoint_save_secs: int,
+                             scaffold: tf.train.Scaffold, config: tf.ConfigProto=None) -> tf.Session:
     """
     Create a monitored session for the worker
     :param target: the target string for the tf.Session
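
The body of create_monitored_session is not visible in this hunk; a plausible sketch of how the new scaffold argument would be passed through to TF 1.x's tf.train.MonitoredTrainingSession (an assumption about the implementation, not taken from this commit):

    import tensorflow as tf

    def create_monitored_session(target, task_index, checkpoint_dir, checkpoint_save_secs,
                                 scaffold, config=None):
        # the chief (task 0) handles variable initialization and periodic checkpointing;
        # the scaffold carries the externally constructed Saver (see graph_manager.py below)
        return tf.train.MonitoredTrainingSession(master=target,
                                                 is_chief=(task_index == 0),
                                                 checkpoint_dir=checkpoint_dir,
                                                 save_checkpoint_secs=checkpoint_save_secs,
                                                 scaffold=scaffold,
                                                 config=config)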

rl_coach/architectures/tensorflow_components/savers.py

+1 -1 (whitespace-only change)

@@ -28,7 +28,7 @@ def __init__(self, name):
         # if graph is finalized, savers must have already already been added. This happens
         # in the case of a MonitoredSession
         self._variables = tf.global_variables()
-
+
         # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
         # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
         self._variables = [v for v in self._variables if '/target' not in v.name]

rl_coach/base_parameters.py

+2 -2

@@ -583,8 +583,8 @@ def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_on


 class DistributedTaskParameters(TaskParameters):
-    def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
-                 task_index: int, evaluate_only: int=None, num_tasks: int=None,
+    def __init__(self, framework_type: Frameworks=None, parameters_server_hosts: str=None, worker_hosts: str=None,
+                 job_type: str=None, task_index: int=None, evaluate_only: int=None, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
                  shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
                  checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
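
Making the cluster-specific arguments optional is what lets coach.py (below) construct a DistributedTaskParameters up front, before the parameter-server and worker hosts, job type, and task index are known. A minimal sketch of the new call pattern, with illustrative values only:

    from rl_coach.base_parameters import DistributedTaskParameters, Frameworks

    # cluster details (parameters_server_hosts, worker_hosts, job_type, task_index)
    # are left at their new None defaults and filled in later, per spawned process
    task_parameters = DistributedTaskParameters(
        framework_type=Frameworks.tensorflow,
        num_training_tasks=2,               # e.g. coach.py ... -n 2
        experiment_path='./experiments/sample',
        use_cpu=True
    )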

rl_coach/coach.py

+53 -82

@@ -34,7 +34,8 @@
 from multiprocessing.managers import BaseManager
 import subprocess
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
-from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
+from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
+    get_base_dir, start_multi_threaded_learning
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'

 def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     ckpt_inside_container = "/checkpoint"
+    non_dist_task_parameters = TaskParameters(
+        framework_type=args.framework,
+        evaluate_only=args.evaluate,
+        experiment_path=args.experiment_path,
+        seed=args.seed,
+        use_cpu=args.use_cpu,
+        checkpoint_save_secs=args.checkpoint_save_secs,
+        checkpoint_save_dir=args.checkpoint_save_dir,
+        export_onnx_graph=args.export_onnx_graph,
+        apply_stop_condition=args.apply_stop_condition
+    )

     memory_backend_params = None
     if args.memory_backend_params:
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
         graph_manager.data_store_params = data_store_params

     if args.distributed_coach_run_type == RunType.TRAINER:
+        if not args.distributed_training:
+            task_parameters = non_dist_task_parameters
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            args=args,
             is_multi_node_test=args.is_multi_node_test
         )

     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_dir = ckpt_inside_container
+        non_dist_task_parameters.checkpoint_restore_dir = ckpt_inside_container

         data_store = None
         if args.data_store_params:
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
             graph_manager=graph_manager,
             data_store=data_store,
             num_workers=args.num_workers,
-            task_parameters=task_parameters
+            task_parameters=non_dist_task_parameters
         )

@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
         parser.add_argument('-dc', '--distributed_coach',
                             help="(flag) Use distributed Coach.",
                             action='store_true')
+        parser.add_argument('-dt', '--distributed_training',
+                            help="(flag) Use distributed training with Coach. "
+                                 "Used only with the --distributed_coach flag; "
+                                 "ignored if --distributed_coach is not set.",
+                            action='store_true')
         parser.add_argument('-dcp', '--distributed_coach_config_path',
                             help="(string) Path to config file when using distributed rollout workers."
                                  "Only distributed Coach parameters should be provided through this config file."
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)

-        task_parameters = TaskParameters(
-            framework_type=args.framework,
-            evaluate_only=args.evaluate,
-            experiment_path=args.experiment_path,
-            seed=args.seed,
-            use_cpu=args.use_cpu,
-            checkpoint_save_secs=args.checkpoint_save_secs,
-            checkpoint_restore_dir=args.checkpoint_restore_dir,
-            checkpoint_save_dir=args.checkpoint_save_dir,
-            export_onnx_graph=args.export_onnx_graph,
-            apply_stop_condition=args.apply_stop_condition
-        )
+        if args.num_workers == 1:
+            task_parameters = TaskParameters(
+                framework_type=args.framework,
+                evaluate_only=args.evaluate,
+                experiment_path=args.experiment_path,
+                seed=args.seed,
+                use_cpu=args.use_cpu,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )
+        else:
+            task_parameters = DistributedTaskParameters(
+                framework_type=args.framework,
+                use_cpu=args.use_cpu,
+                num_training_tasks=args.num_workers,
+                experiment_path=args.experiment_path,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )

         # open dashboard
         if args.open_dashboard:
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp

         # Single-threaded runs
         if args.num_workers == 1:
-            self.start_single_threaded(task_parameters, graph_manager, args)
+            self.start_single_threaded_learning(task_parameters, graph_manager, args)
         else:
-            self.start_multi_threaded(graph_manager, args)
+            global start_graph
+            start_multi_threaded_learning(start_graph, (graph_manager, task_parameters),
+                                          task_parameters, graph_manager, args)

-    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    def start_single_threaded_learning(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
         # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

-    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
-        total_tasks = args.num_workers
-        if args.evaluation_worker:
-            total_tasks += 1
-
-        ps_hosts = "localhost:{}".format(get_open_port())
-        worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])
-
-        # Shared memory
-        class CommManager(BaseManager):
-            pass
-        CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
-        comm_manager = CommManager()
-        comm_manager.start()
-        shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
-
-        def start_distributed_task(job_type, task_index, evaluation_worker=False,
-                                   shared_memory_scratchpad=shared_memory_scratchpad):
-            task_parameters = DistributedTaskParameters(
-                framework_type=args.framework,
-                parameters_server_hosts=ps_hosts,
-                worker_hosts=worker_hosts,
-                job_type=job_type,
-                task_index=task_index,
-                evaluate_only=0 if evaluation_worker else None,  # 0 value for evaluation worker as it should run infinitely
-                use_cpu=args.use_cpu,
-                num_tasks=total_tasks,  # training tasks + 1 evaluation task
-                num_training_tasks=args.num_workers,
-                experiment_path=args.experiment_path,
-                shared_memory_scratchpad=shared_memory_scratchpad,
-                seed=args.seed+task_index if args.seed is not None else None,  # each worker gets a different seed
-                checkpoint_save_secs=args.checkpoint_save_secs,
-                checkpoint_restore_dir=args.checkpoint_restore_dir,
-                checkpoint_save_dir=args.checkpoint_save_dir,
-                export_onnx_graph=args.export_onnx_graph,
-                apply_stop_condition=args.apply_stop_condition
-            )
-            # we assume that only the evaluation workers are rendering
-            graph_manager.visualization_parameters.render = args.render and evaluation_worker
-            p = Process(target=start_graph, args=(graph_manager, task_parameters))
-            # p.daemon = True
-            p.start()
-            return p
-
-        # parameter server
-        parameter_server = start_distributed_task("ps", 0)
-
-        # training workers
-        # wait a bit before spawning the non chief workers in order to make sure the session is already created
-        workers = []
-        workers.append(start_distributed_task("worker", 0))
-        time.sleep(2)
-        for task_index in range(1, args.num_workers):
-            workers.append(start_distributed_task("worker", task_index))
-
-        # evaluation worker
-        if args.evaluation_worker or args.render:
-            evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
-
-        # wait for all workers
-        [w.join() for w in workers]
-        if args.evaluation_worker:
-            evaluation_worker.terminate()
-

 def main():
     launcher = CoachLauncher()
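
The relocated launcher itself, start_multi_threaded_learning in rl_coach/utils.py, is not shown in the visible hunks; only its positional arguments can be read off the new call site. A rough, hypothetical sketch of what it plausibly does, reconstructed from the removed start_multi_threaded body (the real helper may differ):

    from multiprocessing import Process
    import time

    from rl_coach.utils import get_open_port


    def start_multi_threaded_learning(target_function, target_function_args,
                                      task_parameters, graph_manager, args):
        # hypothetical sketch: one parameter server plus args.num_workers training workers,
        # mirroring the logic that used to live in CoachLauncher.start_multi_threaded;
        # graph_manager is accepted for parity with the call site
        total_tasks = args.num_workers
        ps_hosts = "localhost:{}".format(get_open_port())
        worker_hosts = ",".join("localhost:{}".format(get_open_port()) for _ in range(total_tasks))

        def start_distributed_task(job_type, task_index):
            # fill in the cluster details on the shared DistributedTaskParameters
            # just before forking each process
            task_parameters.parameters_server_hosts = ps_hosts
            task_parameters.worker_hosts = worker_hosts
            task_parameters.job_type = job_type
            task_parameters.task_index = task_index
            p = Process(target=target_function, args=target_function_args)
            p.start()
            return p

        # parameter server, then the chief worker, then the remaining workers
        start_distributed_task("ps", 0)
        workers = [start_distributed_task("worker", 0)]
        time.sleep(2)  # give the chief time to create the monitored session
        for task_index in range(1, args.num_workers):
            workers.append(start_distributed_task("worker", task_index))

        [w.join() for w in workers]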

rl_coach/data_stores/s3_data_store.py

+14 -6

@@ -88,16 +88,20 @@ def save_to_store(self):
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

+            ckpt_state_filename = CheckpointStateFile.checkpoint_state_filename
             state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
+                ckpt_name_prefix = ckpt_state.name
+
+            if ckpt_state_filename is not None and ckpt_name_prefix is not None:
                 checkpoint_file = None
                 for root, dirs, files in os.walk(self.params.checkpoint_dir):
                     for filename in files:
-                        if filename == CheckpointStateFile.checkpoint_state_filename:
+                        if filename == ckpt_state_filename:
                             checkpoint_file = (root, filename)
                             continue
-                        if filename.startswith(ckpt_state.name):
+                        if filename.startswith(ckpt_name_prefix):
                             abs_name = os.path.abspath(os.path.join(root, filename))
                             rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

@@ -131,6 +135,8 @@ def load_from_store(self):
         """
         try:
             state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            ckpt_state_filename = state_file.filename
+            ckpt_state_file_path = state_file.path

             # wait until lock is removed
             while True:

@@ -139,7 +145,7 @@ def load_from_store(self):
                 if next(objects, None) is None:
                     try:
                         # fetch checkpoint state file from S3
-                        self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
+                        self.mc.fget_object(self.params.bucket_name, ckpt_state_filename, ckpt_state_file_path)
                     except Exception as e:
                         continue
                     break

@@ -156,10 +162,12 @@ def load_from_store(self):
                     )
                 except Exception as e:
                     pass
+            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            ckpt_state = state_file.read()
+            ckpt_name_prefix = ckpt_state.name

-            checkpoint_state = state_file.read()
-            if checkpoint_state is not None:
-                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
+            if ckpt_name_prefix is not None:
+                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=ckpt_name_prefix, recursive=True)
                 for obj in objects:
                     filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
                     if not os.path.exists(filename):

rl_coach/graph_managers/graph_manager.py

+16 -7

@@ -226,11 +226,15 @@ def _create_session_tf(self, task_parameters: TaskParameters):
            else:
                checkpoint_dir = task_parameters.checkpoint_save_dir

+            self.checkpoint_saver = tf.train.Saver()
+            scaffold = tf.train.Scaffold(saver=self.checkpoint_saver)
+
            self.sess = create_monitored_session(target=task_parameters.worker_target,
                                                 task_index=task_parameters.task_index,
                                                 checkpoint_dir=checkpoint_dir,
                                                 checkpoint_save_secs=task_parameters.checkpoint_save_secs,
-                                                 config=config)
+                                                 config=config,
+                                                 scaffold=scaffold)
            # set the session for all the modules
            self.set_session(self.sess)
        else:

@@ -258,9 +262,11 @@ def create_session(self, task_parameters: TaskParameters):
            raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))

        # Create parameter saver
-        self.checkpoint_saver = SaverCollection()
-        for level in self.level_managers:
-            self.checkpoint_saver.update(level.collect_savers())
+        if not isinstance(task_parameters, DistributedTaskParameters):
+            self.checkpoint_saver = SaverCollection()
+            for level in self.level_managers:
+                self.checkpoint_saver.update(level.collect_savers())
+
        # restore from checkpoint if given
        self.restore_checkpoint()

@@ -540,8 +546,9 @@ def improve(self):
        count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
        while self.total_steps_counters[RunPhase.TRAIN] < count_end:
            self.train_and_act(self.steps_between_evaluation_periods)
-            if self.evaluate(self.evaluation_steps):
-                break
+            if self.task_parameters.task_index == 0 or self.task_parameters.task_index is None:
+                if self.evaluate(self.evaluation_steps):
+                    break

    def restore_checkpoint(self):
        self.verify_graph_was_created()

@@ -599,7 +606,9 @@ def save_checkpoint(self):
        if not isinstance(self.task_parameters, DistributedTaskParameters):
            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
        else:
-            saved_checkpoint_path = checkpoint_path
+            # FIXME: Explicitly managing Saver inside monitored training session is not recommended.
+            # https://github.com/tensorflow/tensorflow/issues/8425#issuecomment-286927528.
+            saved_checkpoint_path = self.checkpoint_saver.save(self.sess._tf_sess(), checkpoint_path)

        # this is required in order for agents to save additional information like a DND for example
        [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
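
For context on the Scaffold/Saver interplay introduced above: tf.train.MonitoredTrainingSession finalizes the graph and owns periodic checkpointing, so an externally created Saver has to be handed in through a Scaffold, and an explicitly triggered save has to go through the raw underlying session. A minimal, self-contained TF 1.x sketch of that pattern (paths and variable names are illustrative, not from this commit):

    import tensorflow as tf

    step = tf.Variable(0, name='global_step', trainable=False)
    increment_step = tf.assign_add(step, 1)

    # the Saver is created up front and handed to the session through a Scaffold,
    # mirroring _create_session_tf() above
    saver = tf.train.Saver()
    scaffold = tf.train.Scaffold(saver=saver)

    with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                           checkpoint_dir='/tmp/coach_ckpt_example',
                                           save_checkpoint_secs=600) as sess:
        sess.run(increment_step)
        # MonitoredSession is only a wrapper; Saver.save() needs the raw tf.Session,
        # hence the (private, discouraged) _tf_sess() call used in save_checkpoint()
        saver.save(sess._tf_sess(), '/tmp/coach_ckpt_example/manual')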

rl_coach/tests/test_dist_coach.py

+3 -1

@@ -73,7 +73,9 @@ def get_tests():
    """
    tests = [
        'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
-        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1'
+        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
+        'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2',
+        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2'
    ]
    return tests

0 commit comments
