@@ -34,7 +34,8 @@
 from multiprocessing.managers import BaseManager
 import subprocess
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
-from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
+from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
+    get_base_dir, start_multi_threaded_learning
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'

 def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     ckpt_inside_container = "/checkpoint"
+    non_dist_task_parameters = TaskParameters(
+        framework_type=args.framework,
+        evaluate_only=args.evaluate,
+        experiment_path=args.experiment_path,
+        seed=args.seed,
+        use_cpu=args.use_cpu,
+        checkpoint_save_secs=args.checkpoint_save_secs,
+        checkpoint_save_dir=args.checkpoint_save_dir,
+        export_onnx_graph=args.export_onnx_graph,
+        apply_stop_condition=args.apply_stop_condition
+    )

     memory_backend_params = None
     if args.memory_backend_params:
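The new non_dist_task_parameters duplicates, inside handle_distributed_coach_tasks, the plain TaskParameters that run_graph_manager builds for single-worker runs (see the later hunk), except that checkpoint_restore_dir is left unset: the rollout-worker branch below points it at the in-container checkpoint path instead. A minimal sketch of the same construction, assuming rl_coach is importable and substituting a hand-built Namespace for real CLI args:

import argparse
from rl_coach.base_parameters import TaskParameters

# Stand-in for parsed CLI arguments; field names mirror the patch.
args = argparse.Namespace(
    framework='tensorflow', evaluate=False, experiment_path='./experiments/demo',
    seed=None, use_cpu=True, checkpoint_save_secs=None, checkpoint_save_dir=None,
    export_onnx_graph=False, apply_stop_condition=False)

# Same keyword set as non_dist_task_parameters above; note that
# checkpoint_restore_dir is deliberately not passed here.
non_dist_task_parameters = TaskParameters(
    framework_type=args.framework,
    evaluate_only=args.evaluate,
    experiment_path=args.experiment_path,
    seed=args.seed,
    use_cpu=args.use_cpu,
    checkpoint_save_secs=args.checkpoint_save_secs,
    checkpoint_save_dir=args.checkpoint_save_dir,
    export_onnx_graph=args.export_onnx_graph,
    apply_stop_condition=args.apply_stop_condition)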
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     graph_manager.data_store_params = data_store_params

     if args.distributed_coach_run_type == RunType.TRAINER:
+        if not args.distributed_training:
+            task_parameters = non_dist_task_parameters
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            args=args,
             is_multi_node_test=args.is_multi_node_test
         )

     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_dir = ckpt_inside_container
+        non_dist_task_parameters.checkpoint_restore_dir = ckpt_inside_container

         data_store = None
         if args.data_store_params:
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
             graph_manager=graph_manager,
             data_store=data_store,
             num_workers=args.num_workers,
-            task_parameters=task_parameters
+            task_parameters=non_dist_task_parameters
         )


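Net effect of the hunks above: the trainer keeps the incoming (possibly distributed) task_parameters only when --distributed_training is passed, while rollout workers always run with the non-distributed parameters. A condensed sketch of that dispatch rule; resolve_task_parameters and the string run types are hypothetical summaries, not part of the patch:

def resolve_task_parameters(run_type, args, task_parameters, non_dist_task_parameters):
    """Summarizes the patch's parameter dispatch for distributed Coach."""
    if run_type == 'trainer':
        # Distributed training is opt-in via -dt/--distributed_training.
        chosen = task_parameters if args.distributed_training else non_dist_task_parameters
        chosen.checkpoint_save_dir = "/checkpoint"   # ckpt_inside_container
    else:  # rollout worker
        # Rollout workers never train distributed; they only restore checkpoints.
        chosen = non_dist_task_parameters
        chosen.checkpoint_restore_dir = "/checkpoint"
    return chosen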
@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
         parser.add_argument('-dc', '--distributed_coach',
                             help="(flag) Use distributed Coach.",
                             action='store_true')
+        parser.add_argument('-dt', '--distributed_training',
+                            help="(flag) Use distributed training with Coach. "
+                                 "Used only with --distributed_coach flag. "
+                                 "Ignored if --distributed_coach flag is not used.",
+                            action='store_true')
         parser.add_argument('-dcp', '--distributed_coach_config_path',
                             help="(string) Path to config file when using distributed rollout workers."
                                  "Only distributed Coach parameters should be provided through this config file."
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)

-        task_parameters = TaskParameters(
-            framework_type=args.framework,
-            evaluate_only=args.evaluate,
-            experiment_path=args.experiment_path,
-            seed=args.seed,
-            use_cpu=args.use_cpu,
-            checkpoint_save_secs=args.checkpoint_save_secs,
-            checkpoint_restore_dir=args.checkpoint_restore_dir,
-            checkpoint_save_dir=args.checkpoint_save_dir,
-            export_onnx_graph=args.export_onnx_graph,
-            apply_stop_condition=args.apply_stop_condition
-        )
+        if args.num_workers == 1:
+            task_parameters = TaskParameters(
+                framework_type=args.framework,
+                evaluate_only=args.evaluate,
+                experiment_path=args.experiment_path,
+                seed=args.seed,
+                use_cpu=args.use_cpu,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )
+        else:
+            task_parameters = DistributedTaskParameters(
+                framework_type=args.framework,
+                use_cpu=args.use_cpu,
+                num_training_tasks=args.num_workers,
+                experiment_path=args.experiment_path,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )

         # open dashboard
         if args.open_dashboard:
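Unlike the per-worker DistributedTaskParameters the removed start_multi_threaded built (next hunk), this shared object carries no job_type, task_index, host lists, per-worker seed, or evaluate_only; those are per-task values, so presumably start_multi_threaded_learning fills them in for each spawned process. A stub sketch of the branch, with placeholder classes in place of the rl_coach ones:

class TaskParameters: ...                         # stand-in for rl_coach.base_parameters
class DistributedTaskParameters(TaskParameters): ...

def make_task_parameters(num_workers):
    # Single worker: plain parameters. Multiple workers: a shared distributed
    # template whose per-task fields are presumably completed per process later.
    return TaskParameters() if num_workers == 1 else DistributedTaskParameters()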
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp

         # Single-threaded runs
         if args.num_workers == 1:
-            self.start_single_threaded(task_parameters, graph_manager, args)
+            self.start_single_threaded_learning(task_parameters, graph_manager, args)
         else:
-            self.start_multi_threaded(graph_manager, args)
+            global start_graph
+            start_multi_threaded_learning(start_graph, (graph_manager, task_parameters),
+                                          task_parameters, graph_manager, args)

-    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    def start_single_threaded_learning(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
         # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

-    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
-        total_tasks = args.num_workers
-        if args.evaluation_worker:
-            total_tasks += 1
-
-        ps_hosts = "localhost:{}".format(get_open_port())
-        worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])
-
-        # Shared memory
-        class CommManager(BaseManager):
-            pass
-        CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
-        comm_manager = CommManager()
-        comm_manager.start()
-        shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
-
-        def start_distributed_task(job_type, task_index, evaluation_worker=False,
-                                   shared_memory_scratchpad=shared_memory_scratchpad):
-            task_parameters = DistributedTaskParameters(
-                framework_type=args.framework,
-                parameters_server_hosts=ps_hosts,
-                worker_hosts=worker_hosts,
-                job_type=job_type,
-                task_index=task_index,
-                evaluate_only=0 if evaluation_worker else None,  # 0 value for evaluation worker as it should run infinitely
-                use_cpu=args.use_cpu,
-                num_tasks=total_tasks,  # training tasks + 1 evaluation task
-                num_training_tasks=args.num_workers,
-                experiment_path=args.experiment_path,
-                shared_memory_scratchpad=shared_memory_scratchpad,
-                seed=args.seed + task_index if args.seed is not None else None,  # each worker gets a different seed
-                checkpoint_save_secs=args.checkpoint_save_secs,
-                checkpoint_restore_dir=args.checkpoint_restore_dir,
-                checkpoint_save_dir=args.checkpoint_save_dir,
-                export_onnx_graph=args.export_onnx_graph,
-                apply_stop_condition=args.apply_stop_condition
-            )
-            # we assume that only the evaluation workers are rendering
-            graph_manager.visualization_parameters.render = args.render and evaluation_worker
-            p = Process(target=start_graph, args=(graph_manager, task_parameters))
-            # p.daemon = True
-            p.start()
-            return p
-
-        # parameter server
-        parameter_server = start_distributed_task("ps", 0)
-
-        # training workers
-        # wait a bit before spawning the non chief workers in order to make sure the session is already created
-        workers = []
-        workers.append(start_distributed_task("worker", 0))
-        time.sleep(2)
-        for task_index in range(1, args.num_workers):
-            workers.append(start_distributed_task("worker", task_index))
-
-        # evaluation worker
-        if args.evaluation_worker or args.render:
-            evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
-
-        # wait for all workers
-        [w.join() for w in workers]
-        if args.evaluation_worker:
-            evaluation_worker.terminate()
-

 def main():
     launcher = CoachLauncher()
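The roughly sixty removed lines (parameter-server and worker host allocation via get_open_port, the BaseManager-backed SharedMemoryScratchPad, per-task DistributedTaskParameters, and Process spawning) move into rl_coach.utils as start_multi_threaded_learning, imported in the first hunk. The diff never shows its definition; inferred from the call site alone, its signature is presumably close to this sketch (an educated guess, not the library's documented API):

def start_multi_threaded_learning(target, target_args, task_parameters,
                                  graph_manager, args):
    """Assumed shape: spawn a parameter server, the training workers, and an
    optional evaluation worker, each running target(*per_task_args) in its own
    multiprocessing.Process, then join the workers -- i.e., the behavior of the
    removed start_multi_threaded body, parameterized by `target`."""
    ...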