Opacus + GRU/LSTM fails with NotImplementedError #783

@AbdessamedSed

Description

🐛 Bug

Opacus + GRU/LSTM fails with
NotImplementedError: Cannot access storage of TensorWrapper

To Reproduce

Steps to reproduce the behavior:

  1. Define a simple GRU model (see code below).

  2. Wrap it with PrivacyEngine.make_private(...).

  3. Run a training step on dummy data → error occurs inside _VF.gru.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from opacus import PrivacyEngine

# === GRU model ===

class GRUNet(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=16, output_dim=1):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])

# Dummy dataset

X = torch.randn(32, 5, 10)
y = torch.randn(32, 1)
dataloader = DataLoader(TensorDataset(X, y), batch_size=8)

# Model + optimizer

model = GRUNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# PrivacyEngine

privacy_engine = PrivacyEngine()
model, optimizer, dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

# Training step (crashes)

for x, y in dataloader:
    preds = model(x)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()
    break

Error Trace

NotImplementedError: Cannot access storage of TensorWrapper
File "torch/nn/modules/rnn.py", line 1391, in forward
result = _VF.gru(input, hx, self._flat_weights, ...)

Expected behavior

The GRU model should train under DP-SGD with Opacus' PrivacyEngine attached, without needing to disable cuDNN or hack around RNN internals.

Environment

PyTorch Version: 2.3.1 (Colab)
Opacus Version: 1.4.0
OS: Ubuntu 22.04 (Google Colab)
How installed: pip
Python version: 3.12
CUDA/cuDNN: CUDA 12.1 / cuDNN 8.9
GPU: Tesla T4 (Colab)

Additional context

Disabling the flatten_parameters() call removes one error, but the _VF.gru call still fails.

Setting torch.backends.cudnn.enabled = False avoids the crash, but it slows down training.
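
For completeness, this is how I apply that workaround today (a minimal sketch; the slowdown presumably comes from PyTorch falling back to its non-fused RNN kernels):

import torch

# Workaround: disable cuDNN globally so nn.GRU takes the non-fused path
# and never hits the .data_ptr() access described below.
torch.backends.cudnn.enabled = False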

This seems related to the cuDNN fast path calling .data_ptr() on Opacus' TensorWrapper.

Is there an official workaround or fix planned for using GRU/LSTM with Opacus?
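
One direction I'm considering is swapping in Opacus' own RNN implementation from opacus.layers, which sidesteps the fused _VF.gru / cuDNN kernels entirely. A sketch, assuming opacus.layers.DPGRU mirrors the nn.GRU constructor signature (it is exported by recent Opacus releases, though I haven't verified it against 1.4.0 here):

import torch.nn as nn
from opacus.layers import DPGRU

class DPGRUNet(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=16, output_dim=1):
        super().__init__()
        # DPGRU is implemented with per-timestep ops instead of a fused
        # kernel, so there are no flattened weights to call .data_ptr() on.
        self.gru = DPGRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])

If DPGRU/DPLSTM is the intended route, confirming that here (and perhaps raising a clearer error when a plain nn.GRU reaches make_private) would fully answer this.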
