Commit 8b051b4

fix float8 training TP+SP integration tests (#2414)
Update [ghstack-poisoned]
1 parent d506cc7 commit 8b051b4

2 files changed, 4 additions and 0 deletions

test/float8/test_dtensor.py

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ def setup_distributed():
     device_mesh = init_device_mesh("cuda", (world_size,))
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh

test/float8/test_fsdp2_tp.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ def setup_distributed():
     )
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh
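
Note on the change: both hunks bind each process to a CUDA device by passing the value of `torch.distributed.get_rank()` to `torch.cuda.set_device`. `get_rank()` returns the global rank, which matches the per-node device index only when all ranks run on a single node. The sketch below is not part of this commit; it shows one possible way to resolve the device index so it would also hold on multi-node runs, assuming the job is launched with `torchrun` (which sets the `LOCAL_RANK` environment variable). The names `resolve_local_rank` and `bind_cuda_device` are hypothetical.

```python
import os


def resolve_local_rank(global_rank, device_count, env=None):
    """Pick the CUDA device index for this process (hypothetical helper).

    Prefers LOCAL_RANK, which torchrun sets per process. Otherwise falls
    back to the global rank modulo the number of visible devices, which
    assumes every node exposes the same number of GPUs.
    """
    env = os.environ if env is None else env
    if "LOCAL_RANK" in env:
        return int(env["LOCAL_RANK"])
    if device_count <= 0:
        raise ValueError("no CUDA devices visible")
    return global_rank % device_count


def bind_cuda_device():
    """Call after the process group is initialized (hypothetical helper)."""
    # Imported lazily so resolve_local_rank stays usable without torch.
    import torch

    local_rank = resolve_local_rank(
        torch.distributed.get_rank(), torch.cuda.device_count()
    )
    torch.cuda.set_device(local_rank)
    return local_rank
```

On a single node with one process per GPU, `resolve_local_rank` returns the same index as the `get_rank()` call used in the diff, so the behavior of these tests would be unchanged.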
