10
10
import copy
11
11
import dataclasses
12
12
import logging
13
+ import multiprocessing
13
14
import os
14
15
import tempfile
15
16
import unittest
16
- from typing import Any , Dict , List , Optional
17
+ from typing import Any , Callable , Dict , List , Optional
17
18
from unittest .mock import MagicMock , patch
18
19
19
20
import torch
20
21
import torch .distributed as dist
21
- import torch .distributed .launcher as pet
22
22
from torchrec .distributed .test_utils .multi_process import (
23
23
MultiProcessContext ,
24
24
MultiProcessTestBase ,
43
43
)
44
44
from torchrec .metrics .model_utils import parse_task_model_outputs
45
45
from torchrec .metrics .rec_metric import RecMetricList , RecTaskInfo
46
- from torchrec .metrics .test_utils import gen_test_batch , get_launch_config
46
+ from torchrec .metrics .test_utils import gen_test_batch
47
47
from torchrec .metrics .throughput import ThroughputMetric
48
- from torchrec .test_utils import seed_and_log , skip_if_asan_class
48
+ from torchrec .test_utils import get_free_port , seed_and_log , skip_if_asan_class
49
49
50
50
METRIC_MODULE_PATH = "torchrec.metrics.metric_module"
51
51
@@ -100,6 +100,47 @@ def _update_rec_metrics(
100
100
101
101
102
102
class MetricModuleTest (unittest .TestCase ):
103
@seed_and_log
def setUp(self) -> None:
    """Prepare the process environment used by the multi-process tests.

    Publishes the rendezvous address/port and the transport settings through
    ``os.environ`` and records the number of worker processes to spawn in
    ``self.WORLD_SIZE``.
    """
    # Chain to the base class so this fixture mirrors tearDown(), which
    # already calls super().tearDown().
    super().setUp()
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            # A fresh port is requested for every test.
            # NOTE(review): presumably this avoids port clashes between
            # consecutive tests -- confirm against get_free_port().
            "MASTER_PORT": str(get_free_port()),
            "GLOO_DEVICE_TRANSPORT": "TCP",
            "NCCL_SOCKET_IFNAME": "lo",
        }
    )
    self.WORLD_SIZE = 2
110
+
111
def tearDown(self) -> None:
    """Remove the transport-related environment variables after each test.

    Only ``GLOO_DEVICE_TRANSPORT`` and ``NCCL_SOCKET_IFNAME`` are cleared.
    """
    # pop() with a default instead of ``del``: if a test body already removed
    # one of the variables, cleanup must not raise KeyError, which would also
    # skip the remaining cleanup and the base-class tearDown.
    for name in ("GLOO_DEVICE_TRANSPORT", "NCCL_SOCKET_IFNAME"):
        os.environ.pop(name, None)
    super().tearDown()
115
+
116
+ def _run_multi_process_test (
117
+ self ,
118
+ world_size : int ,
119
+ backend : str ,
120
+ callable : Callable [..., None ],
121
+ * args : Any ,
122
+ ** kwargs : Any ,
123
+ ) -> None :
124
+ processes = []
125
+ ctx = multiprocessing .get_context ("spawn" )
126
+ for rank in range (world_size ):
127
+ p = ctx .Process (
128
+ target = callable ,
129
+ args = (
130
+ rank ,
131
+ world_size ,
132
+ backend ,
133
+ * args ,
134
+ ),
135
+ kwargs = kwargs ,
136
+ )
137
+ p .start ()
138
+ processes .append (p )
139
+
140
+ for p in processes :
141
+ p .join ()
142
+ self .assertEqual (0 , p .exitcode )
143
+
103
144
def test_metric_module (self ) -> None :
104
145
rec_metric_list_patch = patch (
105
146
METRIC_MODULE_PATH + ".RecMetricList" ,
@@ -184,11 +225,9 @@ def test_rectask_info(self) -> None:
184
225
)
185
226
186
227
@staticmethod
187
- def _run_trainer_checkpointing () -> None :
188
- world_size = int (os .environ ["WORLD_SIZE" ])
189
- rank = int (os .environ ["RANK" ])
228
+ def _run_trainer_checkpointing (rank : int , world_size : int , backend : str ) -> None :
190
229
dist .init_process_group (
191
- backend = "gloo" ,
230
+ backend = backend ,
192
231
world_size = world_size ,
193
232
rank = rank ,
194
233
)
@@ -263,18 +302,18 @@ def test_rank0_checkpointing(self) -> None:
263
302
metric_module .reset ()
264
303
# End of dummy codes
265
304
266
- with tempfile . TemporaryDirectory () as tmpdir :
267
- lc = get_launch_config (
268
- world_size = 2 , rdzv_endpoint = os . path . join ( tmpdir , "rdzv" )
269
- )
270
- pet . elastic_launch ( lc , entrypoint = self . _run_trainer_checkpointing )( )
305
+ self . _run_multi_process_test (
306
+ world_size = self . WORLD_SIZE ,
307
+ backend = "gloo" ,
308
+ callable = self . _run_trainer_checkpointing ,
309
+ )
271
310
272
311
@staticmethod
273
- def _run_trainer_initial_states_checkpointing () -> None :
274
- world_size = int ( os . environ [ "WORLD_SIZE" ])
275
- rank = int ( os . environ [ "RANK" ])
312
+ def _run_trainer_initial_states_checkpointing (
313
+ rank : int , world_size : int , backend : str
314
+ ) -> None :
276
315
dist .init_process_group (
277
- backend = "gloo" ,
316
+ backend = backend ,
278
317
world_size = world_size ,
279
318
rank = rank ,
280
319
)
@@ -352,13 +391,11 @@ def _run_trainer_initial_states_checkpointing() -> None:
352
391
)
353
392
354
393
def test_initial_states_rank0_checkpointing (self ) -> None :
355
- with tempfile .TemporaryDirectory () as tmpdir :
356
- lc = get_launch_config (
357
- world_size = 2 , rdzv_endpoint = os .path .join (tmpdir , "rdzv" )
358
- )
359
- pet .elastic_launch (
360
- lc , entrypoint = self ._run_trainer_initial_states_checkpointing
361
- )()
394
+ self ._run_multi_process_test (
395
+ world_size = self .WORLD_SIZE ,
396
+ backend = "gloo" ,
397
+ callable = self ._run_trainer_initial_states_checkpointing ,
398
+ )
362
399
363
400
def test_should_compute (self ) -> None :
364
401
metric_module = generate_metric_module (
@@ -381,6 +418,9 @@ def test_should_compute(self) -> None:
381
418
@patch ("torchrec.metrics.metric_module.RecMetricList" )
382
419
@patch ("torchrec.metrics.metric_module.time" )
383
420
def _test_adjust_compute_interval (
421
+ rank : int ,
422
+ world_size : int ,
423
+ backend : str ,
384
424
batch_time : float ,
385
425
min_interval : float ,
386
426
max_interval : float ,
@@ -390,10 +430,8 @@ def _test_adjust_compute_interval(
390
430
init_by_me = False
391
431
if not dist .is_initialized ():
392
432
init_by_me = True
393
- world_size = int (os .environ ["WORLD_SIZE" ])
394
- rank = int (os .environ ["RANK" ])
395
433
dist .init_process_group (
396
- backend = "gloo" ,
434
+ backend = backend ,
397
435
world_size = world_size ,
398
436
rank = rank ,
399
437
)
@@ -461,13 +499,14 @@ def _test_adjust_compute_interval_launcher(
461
499
min_interval : float = 0.0 ,
462
500
max_interval : float = float ("inf" ),
463
501
) -> None :
464
- with tempfile .TemporaryDirectory () as tmpdir :
465
- lc = get_launch_config (
466
- world_size = 2 , rdzv_endpoint = os .path .join (tmpdir , "rdzv" )
467
- )
468
- pet .elastic_launch (lc , entrypoint = self ._test_adjust_compute_interval )(
469
- batch_time , min_interval , max_interval
470
- )
502
+ self ._run_multi_process_test (
503
+ self .WORLD_SIZE ,
504
+ "gloo" ,
505
+ self ._test_adjust_compute_interval ,
506
+ batch_time ,
507
+ min_interval ,
508
+ max_interval ,
509
+ )
471
510
472
511
def test_adjust_compute_interval_not_set (self ) -> None :
473
512
self ._test_adjust_compute_interval_launcher (
@@ -482,15 +521,15 @@ def test_adjust_compute_interval_0_30(self) -> None:
482
521
)
483
522
484
523
# This is to ensure the test coverage is correct.
485
- with tempfile .NamedTemporaryFile (delete = True ) as backend :
524
+ with tempfile .NamedTemporaryFile (delete = True ) as backend_file :
486
525
dist .init_process_group (
487
526
backend = "gloo" ,
488
- init_method = f"file://{ backend .name } " ,
527
+ init_method = f"file://{ backend_file .name } " ,
489
528
world_size = 1 ,
490
529
rank = 0 ,
491
530
)
492
531
493
- self ._test_adjust_compute_interval (1 , 0.0 , 30.0 )
532
+ self ._test_adjust_compute_interval (0 , 1 , "gloo" , 1 , 0.0 , 30.0 )
494
533
# Needed to destroy the process group as _test_adjust_compute_interval
495
534
# won't since we initialize the process group for it.
496
535
dist .destroy_process_group ()
@@ -503,15 +542,15 @@ def test_adjust_compute_interval_15_inf(self) -> None:
503
542
)
504
543
505
544
# This is to ensure the test coverage is correct.
506
- with tempfile .NamedTemporaryFile (delete = True ) as backend :
545
+ with tempfile .NamedTemporaryFile (delete = True ) as backend_file :
507
546
dist .init_process_group (
508
547
backend = "gloo" ,
509
- init_method = f"file://{ backend .name } " ,
548
+ init_method = f"file://{ backend_file .name } " ,
510
549
world_size = 1 ,
511
550
rank = 0 ,
512
551
)
513
552
514
- self ._test_adjust_compute_interval (0.1 , 15.0 , float ("inf" ))
553
+ self ._test_adjust_compute_interval (0 , 1 , "gloo" , 0 .1 , 15.0 , float ("inf" ))
515
554
# Needed to destroy the process group as _test_adjust_compute_interval
516
555
# won't since we initialize the process group for it.
517
556
dist .destroy_process_group ()
@@ -524,15 +563,15 @@ def test_adjust_compute_interval_15_30(self) -> None:
524
563
)
525
564
526
565
# This is to ensure the test coverage is correct.
527
- with tempfile .NamedTemporaryFile (delete = True ) as backend :
566
+ with tempfile .NamedTemporaryFile (delete = True ) as backend_file :
528
567
dist .init_process_group (
529
568
backend = "gloo" ,
530
- init_method = f"file://{ backend .name } " ,
569
+ init_method = f"file://{ backend_file .name } " ,
531
570
world_size = 1 ,
532
571
rank = 0 ,
533
572
)
534
573
535
- self ._test_adjust_compute_interval (1 , 15.0 , 30.0 )
574
+ self ._test_adjust_compute_interval (0 , 1 , "gloo" , 1 , 15.0 , 30.0 )
536
575
# Needed to destroy the process group as _test_adjust_compute_interval
537
576
# won't since we initialize the process group for it.
538
577
dist .destroy_process_group ()
0 commit comments