Skip to content

Commit 9280b54

Browse files
authored
Merge pull request #18 from MachineLearningLifeScience/nnj_maxpooling
added maxpooling to nnj and corresponding tests
2 parents be84b2a + e2965e3 commit 9280b54

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

stochman/nnj.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,72 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
307307
return jac
308308

309309

310+
class MaxPool1d(AbstractJacobian, nn.MaxPool1d):
311+
def forward(self, input: Tensor):
312+
val, idx = F.max_pool1d(
313+
input, self.kernel_size, self.stride,
314+
self.padding, self.dilation, self.ceil_mode,
315+
return_indices=True
316+
)
317+
self.idx = idx
318+
return val
319+
320+
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
321+
b, c1, l1 = x.shape
322+
c2, l2 = val.shape[1:]
323+
324+
jac_in_orig_shape = jac_in.shape
325+
jac_in = jac_in.reshape(-1, l1, *jac_in_orig_shape[3:])
326+
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), l2).long()
327+
idx = self.idx.reshape(-1)
328+
jac_in = jac_in[arange_repeated, idx, :, :].reshape(*val.shape, *jac_in_orig_shape[3:])
329+
return jac_in
330+
331+
332+
class MaxPool2d(AbstractJacobian, nn.MaxPool2d):
333+
def forward(self, input: Tensor):
334+
val, idx = F.max_pool2d(
335+
input, self.kernel_size, self.stride,
336+
self.padding, self.dilation, self.ceil_mode,
337+
return_indices=True
338+
)
339+
self.idx = idx
340+
return val
341+
342+
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
343+
b, c1, h1, w1 = x.shape
344+
c2, h2, w2 = val.shape[1:]
345+
346+
jac_in_orig_shape = jac_in.shape
347+
jac_in = jac_in.reshape(-1, h1 * w1, *jac_in_orig_shape[4:])
348+
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
349+
idx = self.idx.reshape(-1)
350+
jac_in = jac_in[arange_repeated, idx, :, :, :].reshape(*val.shape, *jac_in_orig_shape[4:])
351+
return jac_in
352+
353+
354+
class MaxPool3d(AbstractJacobian, nn.MaxPool3d):
355+
def forward(self, input: Tensor):
356+
val, idx = F.max_pool3d(
357+
input, self.kernel_size, self.stride,
358+
self.padding, self.dilation, self.ceil_mode,
359+
return_indices=True
360+
)
361+
self.idx = idx
362+
return val
363+
364+
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
365+
b, c1, d1, h1, w1 = x.shape
366+
c2, d2, h2, w2 = val.shape[1:]
367+
368+
jac_in_orig_shape = jac_in.shape
369+
jac_in = jac_in.reshape(-1, d1 * h1 * w1, *jac_in_orig_shape[5:])
370+
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * d2 * w2).long()
371+
idx = self.idx.reshape(-1)
372+
jac_in = jac_in[arange_repeated, idx, :, :].reshape(*val.shape, *jac_in_orig_shape[5:])
373+
return jac_in
374+
375+
310376
class Sigmoid(AbstractActivationJacobian, nn.Sigmoid):
311377
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
312378
jac = val * (1.0 - val)

tests/test_nnj.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
130130
(nnj.Sequential(nnj.Conv1d(_features, 3, 3), nnj.BatchNorm1d(3)), _1d_conv_input_shape),
131131
(nnj.Sequential(nnj.Conv2d(_features, 3, 3), nnj.BatchNorm2d(3)), _2d_conv_input_shape),
132132
(nnj.Sequential(nnj.Conv3d(_features, 3, 3), nnj.BatchNorm3d(3)), _3d_conv_input_shape),
133+
(nnj.Sequential(nnj.Conv1d(_features, 3, 3), nnj.MaxPool1d(2)), _1d_conv_input_shape),
134+
(nnj.Sequential(nnj.Conv2d(_features, 3, 3), nnj.MaxPool2d(2)), _2d_conv_input_shape),
135+
(nnj.Sequential(nnj.Conv3d(_features, 3, 3), nnj.MaxPool3d(2)), _3d_conv_input_shape),
133136
],
134137
)
135138
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])

0 commit comments

Comments
 (0)