From 5c23c6b112ccc3d24de51c4f37421602a6d0959d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 07:10:13 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- test/float8/test_dtensor.py | 2 ++ test/float8/test_fsdp2_tp.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 9db046b749..a9ccb35b79 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -67,6 +67,8 @@ def setup_distributed(): device_mesh = init_device_mesh("cuda", (world_size,)) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index fa3d30410b..f04b791273 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -46,6 +46,8 @@ def setup_distributed(): ) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh