Skip to content

Commit 5a73ae5

Browse files
nipung90 authored and facebook-github-bot committed
Fix RE failures on test_metric_module by removing the use of elastic capacity (pytorch#3176)
Summary: Pull Request resolved: pytorch#3176 The test_metric_module tests have been failing with timeouts since remote execution was enabled for them. The current setup uses elastic capacity for remote execution, which can run into various issues such as bugs or a lack of SLA on available capacity. Based on test_comm.py (https://www.internalfb.com/code/fbsource/[03b0ea7723d9]/fbcode/torchrec/distributed/tests/test_comm.py), which has the same remote execution condition and continues to succeed, this change alters how remote execution is kicked off. Reviewed By: iamzainhuda, kausv Differential Revision: D78018973 fbshipit-source-id: 150fe0c50cabd5e9b6f3d566991b3cc0fbb19a46
1 parent 407b339 commit 5a73ae5

File tree

1 file changed

+82
-43
lines changed

1 file changed

+82
-43
lines changed

torchrec/metrics/tests/test_metric_module.py

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
import copy
1111
import dataclasses
1212
import logging
13+
import multiprocessing
1314
import os
1415
import tempfile
1516
import unittest
16-
from typing import Any, Dict, List, Optional
17+
from typing import Any, Callable, Dict, List, Optional
1718
from unittest.mock import MagicMock, patch
1819

1920
import torch
2021
import torch.distributed as dist
21-
import torch.distributed.launcher as pet
2222
from torchrec.distributed.test_utils.multi_process import (
2323
MultiProcessContext,
2424
MultiProcessTestBase,
@@ -43,9 +43,9 @@
4343
)
4444
from torchrec.metrics.model_utils import parse_task_model_outputs
4545
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
46-
from torchrec.metrics.test_utils import gen_test_batch, get_launch_config
46+
from torchrec.metrics.test_utils import gen_test_batch
4747
from torchrec.metrics.throughput import ThroughputMetric
48-
from torchrec.test_utils import seed_and_log, skip_if_asan_class
48+
from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class
4949

5050
METRIC_MODULE_PATH = "torchrec.metrics.metric_module"
5151

@@ -100,6 +100,47 @@ def _update_rec_metrics(
100100

101101

102102
class MetricModuleTest(unittest.TestCase):
103+
@seed_and_log
104+
def setUp(self) -> None:
105+
os.environ["MASTER_ADDR"] = str("localhost")
106+
os.environ["MASTER_PORT"] = str(get_free_port())
107+
os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
108+
os.environ["NCCL_SOCKET_IFNAME"] = "lo"
109+
self.WORLD_SIZE = 2
110+
111+
def tearDown(self) -> None:
112+
del os.environ["GLOO_DEVICE_TRANSPORT"]
113+
del os.environ["NCCL_SOCKET_IFNAME"]
114+
super().tearDown()
115+
116+
def _run_multi_process_test(
117+
self,
118+
world_size: int,
119+
backend: str,
120+
callable: Callable[..., None],
121+
*args: Any,
122+
**kwargs: Any,
123+
) -> None:
124+
processes = []
125+
ctx = multiprocessing.get_context("spawn")
126+
for rank in range(world_size):
127+
p = ctx.Process(
128+
target=callable,
129+
args=(
130+
rank,
131+
world_size,
132+
backend,
133+
*args,
134+
),
135+
kwargs=kwargs,
136+
)
137+
p.start()
138+
processes.append(p)
139+
140+
for p in processes:
141+
p.join()
142+
self.assertEqual(0, p.exitcode)
143+
103144
def test_metric_module(self) -> None:
104145
rec_metric_list_patch = patch(
105146
METRIC_MODULE_PATH + ".RecMetricList",
@@ -184,11 +225,9 @@ def test_rectask_info(self) -> None:
184225
)
185226

186227
@staticmethod
187-
def _run_trainer_checkpointing() -> None:
188-
world_size = int(os.environ["WORLD_SIZE"])
189-
rank = int(os.environ["RANK"])
228+
def _run_trainer_checkpointing(rank: int, world_size: int, backend: str) -> None:
190229
dist.init_process_group(
191-
backend="gloo",
230+
backend=backend,
192231
world_size=world_size,
193232
rank=rank,
194233
)
@@ -263,18 +302,18 @@ def test_rank0_checkpointing(self) -> None:
263302
metric_module.reset()
264303
# End of dummy codes
265304

266-
with tempfile.TemporaryDirectory() as tmpdir:
267-
lc = get_launch_config(
268-
world_size=2, rdzv_endpoint=os.path.join(tmpdir, "rdzv")
269-
)
270-
pet.elastic_launch(lc, entrypoint=self._run_trainer_checkpointing)()
305+
self._run_multi_process_test(
306+
world_size=self.WORLD_SIZE,
307+
backend="gloo",
308+
callable=self._run_trainer_checkpointing,
309+
)
271310

272311
@staticmethod
273-
def _run_trainer_initial_states_checkpointing() -> None:
274-
world_size = int(os.environ["WORLD_SIZE"])
275-
rank = int(os.environ["RANK"])
312+
def _run_trainer_initial_states_checkpointing(
313+
rank: int, world_size: int, backend: str
314+
) -> None:
276315
dist.init_process_group(
277-
backend="gloo",
316+
backend=backend,
278317
world_size=world_size,
279318
rank=rank,
280319
)
@@ -352,13 +391,11 @@ def _run_trainer_initial_states_checkpointing() -> None:
352391
)
353392

354393
def test_initial_states_rank0_checkpointing(self) -> None:
355-
with tempfile.TemporaryDirectory() as tmpdir:
356-
lc = get_launch_config(
357-
world_size=2, rdzv_endpoint=os.path.join(tmpdir, "rdzv")
358-
)
359-
pet.elastic_launch(
360-
lc, entrypoint=self._run_trainer_initial_states_checkpointing
361-
)()
394+
self._run_multi_process_test(
395+
world_size=self.WORLD_SIZE,
396+
backend="gloo",
397+
callable=self._run_trainer_initial_states_checkpointing,
398+
)
362399

363400
def test_should_compute(self) -> None:
364401
metric_module = generate_metric_module(
@@ -381,6 +418,9 @@ def test_should_compute(self) -> None:
381418
@patch("torchrec.metrics.metric_module.RecMetricList")
382419
@patch("torchrec.metrics.metric_module.time")
383420
def _test_adjust_compute_interval(
421+
rank: int,
422+
world_size: int,
423+
backend: str,
384424
batch_time: float,
385425
min_interval: float,
386426
max_interval: float,
@@ -390,10 +430,8 @@ def _test_adjust_compute_interval(
390430
init_by_me = False
391431
if not dist.is_initialized():
392432
init_by_me = True
393-
world_size = int(os.environ["WORLD_SIZE"])
394-
rank = int(os.environ["RANK"])
395433
dist.init_process_group(
396-
backend="gloo",
434+
backend=backend,
397435
world_size=world_size,
398436
rank=rank,
399437
)
@@ -461,13 +499,14 @@ def _test_adjust_compute_interval_launcher(
461499
min_interval: float = 0.0,
462500
max_interval: float = float("inf"),
463501
) -> None:
464-
with tempfile.TemporaryDirectory() as tmpdir:
465-
lc = get_launch_config(
466-
world_size=2, rdzv_endpoint=os.path.join(tmpdir, "rdzv")
467-
)
468-
pet.elastic_launch(lc, entrypoint=self._test_adjust_compute_interval)(
469-
batch_time, min_interval, max_interval
470-
)
502+
self._run_multi_process_test(
503+
self.WORLD_SIZE,
504+
"gloo",
505+
self._test_adjust_compute_interval,
506+
batch_time,
507+
min_interval,
508+
max_interval,
509+
)
471510

472511
def test_adjust_compute_interval_not_set(self) -> None:
473512
self._test_adjust_compute_interval_launcher(
@@ -482,15 +521,15 @@ def test_adjust_compute_interval_0_30(self) -> None:
482521
)
483522

484523
# This is to ensure the test coverage is correct.
485-
with tempfile.NamedTemporaryFile(delete=True) as backend:
524+
with tempfile.NamedTemporaryFile(delete=True) as backend_file:
486525
dist.init_process_group(
487526
backend="gloo",
488-
init_method=f"file://{backend.name}",
527+
init_method=f"file://{backend_file.name}",
489528
world_size=1,
490529
rank=0,
491530
)
492531

493-
self._test_adjust_compute_interval(1, 0.0, 30.0)
532+
self._test_adjust_compute_interval(0, 1, "gloo", 1, 0.0, 30.0)
494533
# Needed to destroy the process group as _test_adjust_compute_interval
495534
# won't since we initialize the process group for it.
496535
dist.destroy_process_group()
@@ -503,15 +542,15 @@ def test_adjust_compute_interval_15_inf(self) -> None:
503542
)
504543

505544
# This is to ensure the test coverage is correct.
506-
with tempfile.NamedTemporaryFile(delete=True) as backend:
545+
with tempfile.NamedTemporaryFile(delete=True) as backend_file:
507546
dist.init_process_group(
508547
backend="gloo",
509-
init_method=f"file://{backend.name}",
548+
init_method=f"file://{backend_file.name}",
510549
world_size=1,
511550
rank=0,
512551
)
513552

514-
self._test_adjust_compute_interval(0.1, 15.0, float("inf"))
553+
self._test_adjust_compute_interval(0, 1, "gloo", 0.1, 15.0, float("inf"))
515554
# Needed to destroy the process group as _test_adjust_compute_interval
516555
# won't since we initialize the process group for it.
517556
dist.destroy_process_group()
@@ -524,15 +563,15 @@ def test_adjust_compute_interval_15_30(self) -> None:
524563
)
525564

526565
# This is to ensure the test coverage is correct.
527-
with tempfile.NamedTemporaryFile(delete=True) as backend:
566+
with tempfile.NamedTemporaryFile(delete=True) as backend_file:
528567
dist.init_process_group(
529568
backend="gloo",
530-
init_method=f"file://{backend.name}",
569+
init_method=f"file://{backend_file.name}",
531570
world_size=1,
532571
rank=0,
533572
)
534573

535-
self._test_adjust_compute_interval(1, 15.0, 30.0)
574+
self._test_adjust_compute_interval(0, 1, "gloo", 1, 15.0, 30.0)
536575
# Needed to destroy the process group as _test_adjust_compute_interval
537576
# won't since we initialize the process group for it.
538577
dist.destroy_process_group()

0 commit comments

Comments (0)