Skip to content

Commit e32d5b5

Browse files
committed
add additional tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent d2af054 commit e32d5b5

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

tests/test_utils/test_offload.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -344,9 +344,8 @@ def test_offload_to_weights_map():
344344

345345
@requires_gpu
346346
@requires_accelerate()
347-
def test_register_offload_module():
348-
execution_device = torch.device("cuda")
349-
347+
@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
348+
def test_register_offload_module(exec_device):
350349
# no offloading
351350
model = ExampleModel()
352351
child = torch.nn.Linear(2, 3)
@@ -358,37 +357,36 @@ def test_register_offload_module():
358357
# with offloading
359358
model = ExampleModel()
360359
child = torch.nn.Linear(2, 3)
361-
force_cpu_offload(model, execution_device)
360+
force_cpu_offload(model, exec_device)
362361
register_offload_module(model, "child", child)
363362
register_offload_module(model.linear, "child", child)
364363
assert child in model.children()
365364
assert child in model.linear.children()
366365

367366
# can run modules
368367
model(torch.empty(1))
369-
child(torch.empty(2, device=execution_device))
368+
child(torch.empty(2, device=exec_device))
370369

371370

372371
@requires_gpu
373372
@requires_accelerate()
374-
def test_force_cpu_offload():
375-
execution_device = torch.device("cuda")
376-
373+
@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
374+
def test_force_cpu_offload(exec_device):
377375
# single module
378376
module = torch.nn.Linear(1, 2)
379-
module = force_cpu_offload(module, execution_device)
377+
module = force_cpu_offload(module, exec_device)
380378
assert has_offloaded_params(module)
381379
assert module._hf_hook.offload
382380
assert module.weight.device == torch.device("meta")
383381
assert "weight" in module._hf_hook.weights_map
384382
assert module._hf_hook.tied_params_map is not None
385383

386384
# can run
387-
module(torch.empty(1, device=execution_device))
385+
module(torch.empty(1, device=exec_device))
388386

389387
# model
390388
model = ExampleModel()
391-
model = force_cpu_offload(model, execution_device)
389+
model = force_cpu_offload(model, exec_device)
392390
assert not has_offloaded_params(model)
393391

394392
assert has_offloaded_params(model.linear)
@@ -398,4 +396,4 @@ def test_force_cpu_offload():
398396
assert model.linear._hf_hook.tied_params_map is not None
399397

400398
# can run
401-
model(torch.empty(1, device=execution_device))
399+
model(torch.empty(1, device=exec_device))

0 commit comments

Comments (0)