@@ -11,7 +11,11 @@
 import pytest
 import torch
 from torch import nn
-from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp import (
+    fully_shard,
+    CPUOffloadPolicy,
+    OffloadPolicy,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
@@ -427,16 +431,21 @@ def world_size(self) -> int:
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
     @skip_if_rocm("ROCm enablement in progress")
     def test_fsdp2(self):
-        optim_classes = [optim.AdamW8bit, optim.AdamW4bit]
+        # we do this to avoid testing all combinations
+        args_list = [
+            (optim.AdamW8bit, OffloadPolicy),
+            (optim.AdamW4bit, OffloadPolicy),
+            (optim.AdamW8bit, CPUOffloadPolicy),
+        ]
         if torch.cuda.get_device_capability() >= (8, 9):
-            optim_classes.append(optim.AdamWFp8)
+            args_list.append((optim.AdamWFp8, OffloadPolicy))
 
         self.run_subtests(
-            {"optim_cls": optim_classes},
+            {"args": args_list},
             self._test_fsdp2,
         )
 
-    def _test_fsdp2(self, optim_cls):
+    def _test_fsdp2(self, args):
         import torch.distributed as dist
         import torch.distributed.checkpoint as dcp
         import torch.utils._pytree as pytree
@@ -447,6 +456,8 @@ def _test_fsdp2(self, optim_cls):
             TransformerBlock,
         )
 
+        optim_cls, offload_policy = args
+
         batch_size = 3
         vocab_size = 1024
         seq_len = 64
@@ -466,8 +477,8 @@ def _test_fsdp2(self, optim_cls):
         fsdp_model = copy.deepcopy(base_model)
         for m in fsdp_model.modules():
             if isinstance(m, TransformerBlock):
-                fully_shard(m)
-        fully_shard(fsdp_model)
+                fully_shard(m, offload_policy=offload_policy)
+        fully_shard(fsdp_model, offload_policy=offload_policy)
         fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
 
         torch.manual_seed(42 + self.rank + 1)
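
For context (not part of the diff): a minimal standalone sketch of how an offload policy is threaded through FSDP2's fully_shard, mirroring the way the test now forwards its offload_policy argument to both fully_shard calls. It assumes a CUDA machine launched with torchrun (LOCAL_RANK set); the toy nn.Sequential model and the use of torch.optim.AdamW instead of the repo's low-bit optimizers are illustrative choices, not the test's actual fixtures.

# Sketch only: model and optimizer names below are illustrative, not the test's fixtures.
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    model = nn.Sequential(
        nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)
    ).cuda()

    # Shard the parameterized children first, then the root, passing the policy
    # to every call -- the same pattern the test uses for TransformerBlock modules.
    policy = CPUOffloadPolicy()
    for layer in model:
        if isinstance(layer, nn.Linear):
            fully_shard(layer, offload_policy=policy)
    fully_shard(model, offload_policy=policy)

    # The test exercises the repo's low-bit optimizers (e.g. optim.AdamW8bit);
    # plain AdamW keeps this sketch self-contained.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

    x = torch.randn(8, 1024, device="cuda")
    model(x).sum().backward()
    optimizer.step()
    optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()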