
Commit 7dbbb40

evgri243 authored and meta-codesync[bot] committed
Evgri243/multi device models (#796)
Summary:

## Types of changes
- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue
This PR adds support for multi-device training scenarios where model parameters are distributed across multiple GPU devices (e.g., when assigning different layers directly with `module.to(device[i])` or using `device_map="auto"` with accelerate).

**Problem solved:** When training large models that don't fit on a single GPU, parameters and gradients can be spread across multiple devices. The existing Opacus optimizers and gradient clipping modules assumed all tensors were on the same device, causing runtime errors during norm computation and gradient clipping operations.

**Changes:**
1. **Sequential multi-device execution support (#9):** Modified `DPOptimizer` and `AdaClipDPOptimizer` to move tensors to appropriate devices before operations like `torch.stack()` and `torch.einsum()`, preventing device mismatch errors during gradient clipping and accumulation.
2. **Multi-device support in GradSampleModuleFastGradientClipping (#10):** Extended multi-device handling to `GradSampleModuleFastGradientClipping`, `DPPerLayerOptimizer`, and additional edge cases in optimizers that were previously uncovered.

## How Has This Been Tested
- The code was used to train a 7B Zetta model with LoRA on an 8xH200 GPU node.
- Added a test suite in `multidevice_optimizer_test.py` covering:
  - `DPOptimizer`, `AdaClipDPOptimizer`, and `DPPerLayerOptimizer` with multi-device models
  - Both `clip_and_accumulate()` and full `step()` operations
  - The helper function `_clip_and_accumulate_parameter()` with multi-device parameters
- Added additional tests in `grad_sample_module_fast_gradient_clipping_test.py` for:
  - `get_norm_sample()` with parameters on different devices
  - `get_clipping_coef()` with parameters on different devices
- All tests require at least 2 GPUs and verify that operations complete without device mismatch errors while maintaining correctness.

## Checklist
- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #796
Reviewed By: iden-kalemaj
Differential Revision: D85355821
fbshipit-source-id: 19da3c47ba5308748e839984194d1ce4b802d52f
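As an illustration of the failure mode and fix pattern described above, here is a minimal, standalone sketch. It is not part of this commit: the device choice, the CPU fallback, and all variable names are assumptions made purely for the example.

```python
# Minimal sketch (not from this PR) of the device-mismatch problem and the
# "move everything to one target device" fix applied in this commit.
import torch

# With fewer than 2 GPUs the example falls back to CPU and no mismatch occurs.
devices = (
    [torch.device("cuda:0"), torch.device("cuda:1")]
    if torch.cuda.device_count() >= 2
    else [torch.device("cpu"), torch.device("cpu")]
)

# Per-parameter norm tensors end up on whichever device holds that parameter.
batch_size = 4
per_param_norms = [torch.randn(batch_size, device=d).abs() for d in devices]

# torch.stack requires all inputs on one device, so this line raises
# "Expected all tensors to be on the same device" when the norms are split
# across cuda:0 and cuda:1:
#   per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

# Fix: move every norm tensor to a single target device before stacking.
target_device = per_param_norms[0].device
per_param_norms = [n.to(target_device) for n in per_param_norms]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
print(per_sample_norms.shape)  # torch.Size([4])
```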
1 parent 7dd02da commit 7dbbb40

File tree: 5 files changed (+504, -6 lines)


opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

Lines changed: 5 additions & 3 deletions
@@ -140,9 +140,11 @@ def get_clipping_coef(self) -> torch.Tensor:
 
     def get_norm_sample(self) -> torch.Tensor:
         """Get per-example gradient norms."""
-        norm_sample = torch.stack(
-            [param._norm_sample for param in self.trainable_parameters], dim=0
-        ).norm(2, dim=0)
+        norm_samples = [param._norm_sample for param in self.trainable_parameters]
+        if norm_samples:
+            target_device = norm_samples[0].device
+            norm_samples = [norm.to(target_device) for norm in norm_samples]
+        norm_sample = torch.stack(norm_samples, dim=0).norm(2, dim=0)
         self.per_sample_gradient_norms = norm_sample
         return norm_sample

opacus/optimizers/adaclipoptimizer.py

Lines changed: 8 additions & 1 deletion
@@ -97,6 +97,11 @@ def clip_and_accumulate(self):
         per_param_norms = [
             g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
+
+        if per_param_norms:
+            target_device = per_param_norms[0].device
+            per_param_norms = [norm.to(target_device) for norm in per_param_norms]
+
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
             max=1.0
@@ -112,7 +117,9 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+
+            clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device)
+            grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad

opacus/optimizers/optimizer.py

Lines changed: 9 additions & 2 deletions
@@ -444,6 +444,11 @@ def clip_and_accumulate(self):
         per_param_norms = [
             g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
+
+        if per_param_norms:
+            target_device = per_param_norms[0].device
+            per_param_norms = [norm.to(target_device) for norm in per_param_norms]
+
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (
             self.max_grad_norm / (per_sample_norms + 1e-6)
@@ -457,8 +462,10 @@ def clip_and_accumulate(self):
             # for mixed precision, optimizer parameters are usually in FP32
             # lower precision grads will be cast up to FP32
             grad_sample = grad_sample.to(p.dtype)
-            per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+            clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device).to(
+                p.dtype
+            )
+            grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
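For context on the einsum in the hunk above: `per_sample_clip_factor` has one entry per example and `grad_sample` has shape `(batch, *param_shape)`, so `"i,i..."` contracts over the batch dimension and yields the clipped, summed gradient. The sketch below reproduces that pattern with the device move applied; shapes, the device fallback, and the use of `grad_sample.dtype` (rather than the parameter's dtype used in the actual code) are illustrative assumptions, not Opacus internals.

```python
import torch

batch_size, out_f, in_f = 4, 20, 10
grad_device = torch.device("cuda:1" if torch.cuda.device_count() >= 2 else "cpu")

# Per-example gradients for one parameter, living on that parameter's device.
grad_sample = torch.randn(batch_size, out_f, in_f, device=grad_device)

# Clip factors computed elsewhere (e.g. on cuda:0), one scalar per example.
per_sample_clip_factor = torch.rand(batch_size).clamp(max=1.0)

# Move the clip factors to the gradient's device (and dtype) before einsum,
# mirroring the optimizer.py change above.
clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device).to(
    grad_sample.dtype
)

# "i,i..." sums clip_i * grad_i over the batch dimension.
grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)
print(grad.shape)  # torch.Size([20, 10])
```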

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 94 additions & 0 deletions
@@ -422,3 +422,97 @@ def test_gradient_calculation(self):
         logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def test_multidevice_get_norm_sample(self):
+        """Test that get_norm_sample handles parameters on different devices."""
+        device1 = torch.device("cuda:0")
+        device2 = torch.device("cuda:1")
+
+        # Create a simple model with parameters on different devices
+        class MultiDeviceModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(10, 20).to(device1)
+                self.fc2 = nn.Linear(20, 5).to(device2)
+
+            def forward(self, x):
+                x = x.to(device1)
+                x = torch.relu(self.fc1(x))
+                x = x.to(device2)
+                return self.fc2(x)
+
+        model = MultiDeviceModel()
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            model, max_grad_norm=1.0, use_ghost_clipping=False
+        )
+
+        # Simulate _norm_sample on different devices
+        batch_size = 4
+        for param in grad_sample_module.trainable_parameters:
+            param._norm_sample = torch.randn(batch_size, device=param.device)
+
+        # This should not raise any device mismatch errors
+        try:
+            norm_sample = grad_sample_module.get_norm_sample()
+            success = True
+        except RuntimeError as e:
+            if "Expected all tensors to be on the same device" in str(e):
+                success = False
+                self.fail(f"Device mismatch error in get_norm_sample: {e}")
+            else:
+                raise
+
+        self.assertTrue(
+            success, "get_norm_sample should handle multi-device parameters"
+        )
+        self.assertEqual(norm_sample.shape[0], batch_size)
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def test_multidevice_get_clipping_coef(self):
+        """Test that get_clipping_coef handles parameters on different devices."""
+        device1 = torch.device("cuda:0")
+        device2 = torch.device("cuda:1")
+
+        # Create a simple model with parameters on different devices
+        class MultiDeviceModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(10, 20).to(device1)
+                self.fc2 = nn.Linear(20, 5).to(device2)
+
+            def forward(self, x):
+                x = x.to(device1)
+                x = torch.relu(self.fc1(x))
+                x = x.to(device2)
+                return self.fc2(x)
+
+        model = MultiDeviceModel()
+        max_grad_norm = 1.0
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            model, max_grad_norm=max_grad_norm, use_ghost_clipping=False
+        )
+
+        # Simulate _norm_sample on different devices
+        batch_size = 4
+        for param in grad_sample_module.trainable_parameters:
+            # Create norms with values that will require clipping
+            param._norm_sample = torch.ones(batch_size, device=param.device) * 2.0
+
+        # This should not raise any device mismatch errors
+        try:
+            clipping_coef = grad_sample_module.get_clipping_coef()
+            success = True
+        except RuntimeError as e:
+            if "Expected all tensors to be on the same device" in str(e):
+                success = False
+                self.fail(f"Device mismatch error in get_clipping_coef: {e}")
+            else:
+                raise
+
+        self.assertTrue(
+            success, "get_clipping_coef should handle multi-device parameters"
+        )
+        self.assertEqual(clipping_coef.shape[0], batch_size)
+        # Verify clipping coefficients are correct
+        self.assertTrue(torch.all(clipping_coef <= 1.0))
