
Commit 57fa0d9

Merge pull request #26 from lxuechen/v0.2.2
V0.2.2; fixes #25

2 parents e7ac941 + 3336191

5 files changed: +18 −37 lines

private_transformers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 from .privacy_engine import PrivacyEngine
 from .transformers_support import freeze_isolated_params_for_vit
 
-__version__ = '0.2.1'
+__version__ = '0.2.2'

private_transformers/autograd_grad_sample.py

Lines changed: 1 addition & 15 deletions

@@ -52,20 +52,6 @@ def requires_grad(module: nn.Module, recurse: bool = False) -> bool:
     return requires_grad
 
 
-def get_layer_type(layer: nn.Module) -> str:
-    """
-    Returns the name of the type of the given layer.
-
-    Args:
-        layer: The module corresponding to the layer whose type
-            is being queried.
-
-    Returns:
-        Name of the class of the layer
-    """
-    return layer.__class__.__name__
-
-
 def add_hooks(model: nn.Module, loss_reduction: str = "mean"):
     r"""
     Adds hooks to model to save activations and backprop values.

@@ -86,7 +72,7 @@ def add_hooks(model: nn.Module, loss_reduction: str = "mean"):
 
     handles = []
     for name, layer in model.named_modules():
-        if get_layer_type(layer) in _supported_layers_grad_samplers.keys():
+        if type(layer) in _supported_layers_grad_samplers:
             # Check if the layer has trainable parameters.
             is_trainable = False
             for p in layer.parameters(recurse=False):
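
For context, a minimal sketch of what the simplified lookup in `add_hooks` now does (illustrative only; `demo_samplers` and `demo_linear_sampler` are made-up names, not library code): layers are matched by their class object via `type(layer)`, which is why the string-returning `get_layer_type` helper could be deleted.

# Illustrative sketch, plain PyTorch only.
import torch.nn as nn

def demo_linear_sampler(layer, activations, backprops):
    ...  # placeholder; the real samplers live in supported_layers_grad_samplers.py

demo_samplers = {nn.Linear: demo_linear_sampler}  # keyed by the class object itself

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, layer in model.named_modules():
    # Old: layer.__class__.__name__ in {"Linear", ...}   (match by class-name string)
    # New: type(layer) in demo_samplers                   (match by class object)
    if type(layer) in demo_samplers:
        print(f"would register hooks on {name!r} ({type(layer).__name__})")

Note that `type(layer) in demo_samplers` matches exact classes only, so subclasses of `nn.Linear` would not be picked up; that is the same behavior as the old exact-name comparison.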

private_transformers/privacy_engine.py

Lines changed: 7 additions & 14 deletions

@@ -248,6 +248,11 @@ def step(
         # This option was included to help with another spectrum analysis project.
         callback: Optional[Callable] = None,
     ):
+        if loss.dim() != 1:
+            raise ValueError(
+                f"Expected `loss` to be the per-example loss 1-D tensor, but got a tensor with dims={loss.dim()}."
+            )
+
         if self.clipping_mode == ClippingMode.ghost:
             if callback is not None:
                 raise ValueError("Ghost clipping does not support `callback` in `optimizer.step`.")

@@ -359,11 +364,6 @@ def _ghost_virtual_step(self, loss: torch.Tensor):
     @torch.enable_grad()
     def _double_backward(self, loss: torch.Tensor):
         """Given per-example losses, backward twice to accumulate summed clipped gradients in `.grad`."""
-        if loss.dim() != 1:
-            raise ValueError(
-                f"Expected `loss` to be the per-example loss 1-D tensor, but got a tensor with dims={loss.dim()}."
-            )
-
         first_loss = loss.sum()
         first_loss.backward(retain_graph=True)
 

@@ -437,9 +437,6 @@ def _accumulate_summed_grad(self, loss, scale):
 
         Removes `.grad_sample` and `.grad` for each variable that requires grad at the end.
         """
-        if loss.dim() != 1:
-            raise ValueError(f"Expected `loss` to be a the per-example loss 1-D tensor.")
-
         with torch.enable_grad():
             loss.sum(dim=0).backward()
 

@@ -466,12 +463,8 @@ def _accumulate_summed_grad(self, loss, scale):
             for tensor in norm_sample:
                 shapes[tensor.size()] += 1
 
-            major_shape = None
-            major_count = 0
-            for shape, count in shapes.items():
-                if count > major_count:
-                    major_shape = shape
-            del shape, count
+            # Get the shape that most tensors have.
+            major_shape, major_count = max(shapes.items(), key=lambda x: x[1])
 
             # Check which tensors don't have the major shape!
             extra_msg = f" \n*** Major shape: {major_shape}"
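
With this change the shape check runs once at the top of `step`, so every clipping mode rejects a mis-shaped loss before any backward pass. The contract is unchanged: the loss handed to the engine through `optimizer.step(loss=...)` must be the per-example loss, a 1-D tensor of length batch_size. A small sketch of what now passes and what raises (plain PyTorch; the tensors are toy stand-ins for real model losses):

import torch

batch_size = 4
per_example_loss = torch.rand(batch_size)        # OK: loss.dim() == 1
token_level_loss = torch.rand(batch_size, 10)    # would raise ValueError in step(): dim() == 2
scalar_loss = per_example_loss.mean()            # would also raise: dim() == 0

# Reduce a token-level loss to per-example form before passing it to the engine,
# e.g. optimizer.step(loss=per_example_loss) once the PrivacyEngine is attached.
per_example_loss = token_level_loss.mean(dim=1)  # shape (batch_size,)
print(per_example_loss.shape)                    # torch.Size([4])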

private_transformers/supported_layers_grad_samplers.py

Lines changed: 8 additions & 7 deletions

@@ -16,6 +16,7 @@
 
 import numpy as np
 import torch
+import transformers.pytorch_utils
 from opt_einsum import contract
 from torch import nn
 from torch.functional import F

@@ -313,11 +314,11 @@ def _compute_conv2d_grad_sample(layer: nn.Conv2d, activations: Tuple[torch.Tenso
 
 
 _supported_layers_grad_samplers = {
-    "Embedding": _compute_embedding_grad_sample,
-    "Linear": _compute_linear_grad_sample,
-    "Conv2d": _compute_conv2d_grad_sample,  # nn.Conv2d.
-    "LayerNorm": _compute_layer_norm_grad_sample,
-    "Conv1D": _custom_compute_conv1d_grad_sample,  # HuggingFace Open-AI GPT-2.
-    "T5LayerNorm": _compute_t5_layer_norm_grad_sample,
-    "OPTLearnedPositionalEmbedding": _compute_opt_learned_positional_embedding_grad_sample,
+    nn.Embedding: _compute_embedding_grad_sample,
+    nn.Linear: _compute_linear_grad_sample,
+    nn.Conv2d: _compute_conv2d_grad_sample,
+    nn.LayerNorm: _compute_layer_norm_grad_sample,
+    transformers.pytorch_utils.Conv1D: _custom_compute_conv1d_grad_sample,
+    transformers.models.t5.modeling_t5.T5LayerNorm: _compute_t5_layer_norm_grad_sample,
+    OPTLearnedPositionalEmbedding: _compute_opt_learned_positional_embedding_grad_sample,
 }
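
Because the table is now keyed by class objects, checking whether a model's sub-modules are covered is a direct membership test against the dict. A hypothetical helper as a sketch (`report_supported_layers` is not part of the library, and `_supported_layers_grad_samplers` is an internal name that may change):

import torch.nn as nn
from private_transformers.supported_layers_grad_samplers import _supported_layers_grad_samplers

def report_supported_layers(model: nn.Module) -> None:
    """Hypothetical helper: print which sub-modules have a registered grad sampler."""
    for name, layer in model.named_modules():
        if type(layer) in _supported_layers_grad_samplers:
            print(f"{name or '<root>'}: {type(layer).__name__} has a grad sampler")

report_supported_layers(nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16)))
# expected output:
#   0: Linear has a grad sampler
#   1: LayerNorm has a grad sampler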

setup.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
         "jupyter",
         "ml-swissknife",
         "opt_einsum",
+        "pytest"
     ],
     python_requires='~=3.8',
     classifiers=[
