@@ -307,6 +307,72 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
307
307
return jac
308
308
309
309
310
+ class MaxPool1d (AbstractJacobian , nn .MaxPool1d ):
311
+ def forward (self , input : Tensor ):
312
+ val , idx = F .max_pool1d (
313
+ input , self .kernel_size , self .stride ,
314
+ self .padding , self .dilation , self .ceil_mode ,
315
+ return_indices = True
316
+ )
317
+ self .idx = idx
318
+ return val
319
+
320
+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
321
+ b , c1 , l1 = x .shape
322
+ c2 , l2 = val .shape [1 :]
323
+
324
+ jac_in_orig_shape = jac_in .shape
325
+ jac_in = jac_in .reshape (- 1 , l1 , * jac_in_orig_shape [3 :])
326
+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), l2 ).long ()
327
+ idx = self .idx .reshape (- 1 )
328
+ jac_in = jac_in [arange_repeated , idx , :, :].reshape (* val .shape , * jac_in_orig_shape [3 :])
329
+ return jac_in
330
+
331
+
332
+ class MaxPool2d (AbstractJacobian , nn .MaxPool2d ):
333
+ def forward (self , input : Tensor ):
334
+ val , idx = F .max_pool2d (
335
+ input , self .kernel_size , self .stride ,
336
+ self .padding , self .dilation , self .ceil_mode ,
337
+ return_indices = True
338
+ )
339
+ self .idx = idx
340
+ return val
341
+
342
+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
343
+ b , c1 , h1 , w1 = x .shape
344
+ c2 , h2 , w2 = val .shape [1 :]
345
+
346
+ jac_in_orig_shape = jac_in .shape
347
+ jac_in = jac_in .reshape (- 1 , h1 * w1 , * jac_in_orig_shape [4 :])
348
+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * w2 ).long ()
349
+ idx = self .idx .reshape (- 1 )
350
+ jac_in = jac_in [arange_repeated , idx , :, :, :].reshape (* val .shape , * jac_in_orig_shape [4 :])
351
+ return jac_in
352
+
353
+
354
+ class MaxPool3d (AbstractJacobian , nn .MaxPool3d ):
355
+ def forward (self , input : Tensor ):
356
+ val , idx = F .max_pool3d (
357
+ input , self .kernel_size , self .stride ,
358
+ self .padding , self .dilation , self .ceil_mode ,
359
+ return_indices = True
360
+ )
361
+ self .idx = idx
362
+ return val
363
+
364
+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
365
+ b , c1 , d1 , h1 , w1 = x .shape
366
+ c2 , d2 , h2 , w2 = val .shape [1 :]
367
+
368
+ jac_in_orig_shape = jac_in .shape
369
+ jac_in = jac_in .reshape (- 1 , d1 * h1 * w1 , * jac_in_orig_shape [5 :])
370
+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * d2 * w2 ).long ()
371
+ idx = self .idx .reshape (- 1 )
372
+ jac_in = jac_in [arange_repeated , idx , :, :].reshape (* val .shape , * jac_in_orig_shape [5 :])
373
+ return jac_in
374
+
375
+
310
376
class Sigmoid (AbstractActivationJacobian , nn .Sigmoid ):
311
377
def _jacobian (self , x : Tensor , val : Tensor ) -> Tensor :
312
378
jac = val * (1.0 - val )
0 commit comments