Skip to content

Commit 0ab52dd

Browse files
authored
[BugFix] ConvNet forward method with tensors of more than 4 dimensions (#686)
* cnn forward fix * more general code * cnn testing * precommit run check * convnet tests
1 parent 26881b3 commit 0ab52dd

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

test/test_modules.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def test_mlp(
124124
)
125125
@pytest.mark.parametrize("squeeze_output", [False])
126126
@pytest.mark.parametrize("device", get_available_devices())
127+
@pytest.mark.parametrize("batch", [(2,), (2, 2)])
127128
def test_convnet(
129+
batch,
128130
in_features,
129131
depth,
130132
num_cells,
@@ -145,7 +147,6 @@ def test_convnet(
145147
seed=0,
146148
):
147149
torch.manual_seed(seed)
148-
batch = 2
149150
convnet = ConvNet(
150151
in_features=in_features,
151152
depth=depth,
@@ -165,9 +166,9 @@ def test_convnet(
165166
)
166167
if in_features is None:
167168
in_features = 5
168-
x = torch.randn(batch, in_features, input_size, input_size, device=device)
169+
x = torch.randn(*batch, in_features, input_size, input_size, device=device)
169170
y = convnet(x)
170-
assert y.shape == torch.Size([batch, expected_features])
171+
assert y.shape == torch.Size([*batch, expected_features])
171172

172173

173174
@pytest.mark.parametrize(

torchrl/modules/models/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,15 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module:
463463
layers.append(Squeeze2dLayer())
464464
return layers
465465

466+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
467+
*batch, C, L, W = inputs.shape
468+
if len(batch) > 1:
469+
inputs = inputs.flatten(0, len(batch) - 1)
470+
out = super(ConvNet, self).forward(inputs)
471+
if len(batch) > 1:
472+
out = out.unflatten(0, batch)
473+
return out
474+
466475

467476
class DuelingMlpDQNet(nn.Module):
468477
"""Creates a Dueling MLP Q-network.

0 commit comments

Comments
 (0)