+from builtins import breakpoint
 from math import prod
 from typing import Optional, Tuple, Union
@@ -126,10 +127,26 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
     def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         # non parametric, so return empty
-        return []
+        return None
 
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
-        raise NotImplementedError
+        b, c1, h1, w1 = x.shape
+        c2, h2, w2 = val.shape[1:]
+
+        weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)
+
+        tmp = F.conv2d(
+            tmp.reshape(-1, c2, h2, w2),
+            weight=weight,
+            bias=None,
+            stride=int(self.scale_factor),
+            padding=0,
+            dilation=1,
+            groups=1,
+        )
+
+        return tmp.reshape(b, c1 * h1 * w1)
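The hunk above implements the Jacobian sandwich of nearest-neighbour upsampling: since the Jacobian J only copies each input pixel into a scale_factor x scale_factor output block, J^T diag(tmp) J is diagonal, and its diagonal is the blockwise sum of tmp, which is exactly what the all-ones conv2d kernel computes. A minimal sanity check of that identity, assuming a single channel and scale factor 2 (the sizes are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

b, c, h, w, s = 1, 1, 2, 2, 2
x = torch.randn(b, c, h, w)

# dense Jacobian of nearest upsampling: exactly one 1 per row
J = torch.autograd.functional.jacobian(
    lambda t: F.interpolate(t, scale_factor=s, mode="nearest"), x
).reshape(c * s * h * s * w, c * h * w)

tmp = torch.randn(c * s * h * s * w)          # diagonal "filling" in output space
dense = J.T @ torch.diag(tmp) @ J             # explicit J^T diag(tmp) J

# the conv trick from the diff: blockwise sum with an all-ones kernel
ones = torch.ones(c, c, s, s)
conv = F.conv2d(tmp.reshape(1, c, s * h, s * w), ones, stride=s).reshape(-1)

print(torch.allclose(dense.diagonal(), conv))  # True
```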
@@ -176,7 +193,7 @@ def compute_reversed_padding(padding, kernel_size=1):
     return kernel_size - 1 - padding
 
 
-class Conv2d(nn.Conv2d):
+class Conv2d(AbstractJacobian, nn.Conv2d):
     def __init__(
         self,
         in_channels,
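Putting AbstractJacobian in front of nn.Conv2d is the mixin pattern already used elsewhere in this file (Flatten, MaxPool3d, ...): method lookup resolves the shared Jacobian helpers first, while nn.Conv2d keeps providing forward and the parameters. A toy illustration of the resolution order; the demo class names are hypothetical, not from the repo:

```python
import torch
import torch.nn as nn

class JacobianMixinDemo:
    def sandwich(self):
        return "helper resolved via the mixin"

class Conv2dDemo(JacobianMixinDemo, nn.Conv2d):
    pass

layer = Conv2dDemo(3, 8, kernel_size=3)
print(layer.sandwich())                       # comes from the mixin
print(layer(torch.randn(1, 3, 8, 8)).shape)   # forward still comes from nn.Conv2d
```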
@@ -348,7 +365,7 @@ def _jacobian_wrt_weight_T_mult_right(
 
         if use_less_memory:
             # define moving sum for Jt_tmp
-            Jt_tmp = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_cols)
+            Jt_tmp = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_cols, device=x.device)
             for i in range(b):
                 # set the weight to the convolution
                 input_single_batch = x[i : i + 1, :, :, :]
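The added device=x.device matters because torch.zeros allocates on the CPU by default; without it, the per-batch writes Jt_tmp[i, :, :] = ... raise a device-mismatch error whenever x lives on the GPU. The failure mode in miniature (guarded so the sketch also runs on CPU-only machines):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 3, device=device)

acc = torch.zeros(4, 3, device=x.device)  # allocate alongside the input
acc[0] = x[0]                             # fine on any device
# acc = torch.zeros(4, 3); acc[0] = x[0]  would raise RuntimeError when device == "cuda"
```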
@@ -376,7 +393,6 @@ def _jacobian_wrt_weight_T_mult_right(
                 Jt_tmp[i, :, :] = Jt_tmp_single_batch
 
         else:
-            # TODO: only works for batch size 1?
             reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)
 
             # convolve each column
@@ -399,6 +415,80 @@ def _jacobian_wrt_weight_T_mult_right(
 
         return Jt_tmp
 
+    def _jacobian_wrt_weight_mult_left(
+        self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
+    ) -> Tensor:
+        b, c1, h1, w1 = x.shape
+        c2, h2, w2 = val.shape[1:]
+        kernel_h, kernel_w = self.kernel_size
+        num_of_rows = tmp.shape[-2]
+
+        # expand rows as cubes [(output channel) x (output height) x (output width)]
+        tmp_rows = tmp.movedim(-1, -2).reshape(b, c2, h2, w2, num_of_rows)
+        # see rows as columns of the transposed matrix
+        tmpt_cols = tmp_rows
+        # flip the images in (output height) x (output width)
+        tmpt_cols = torch.flip(tmpt_cols, [-3, -2])
+        # switch batch size and output channel
+        tmpt_cols = tmpt_cols.movedim(0, 1)
+
+        if use_less_memory:
+            tmp_J = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_rows, device=x.device)
+            for i in range(b):
+                # use the i-th input as the convolution weight
+                input_single_batch = x[i : i + 1, :, :, :]
+                reversed_input_single_batch = torch.flip(input_single_batch, [-2, -1]).movedim(0, 1)
+
+                tmp_single_batch = tmpt_cols[:, i : i + 1, :, :, :]
+
+                # convolve each column
+                tmp_J_single_batch = (
+                    F.conv2d(
+                        tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
+                        weight=reversed_input_single_batch,
+                        bias=None,
+                        stride=self.stride,
+                        padding=self.dw_padding,
+                        dilation=self.dilation,
+                        groups=self.groups,
+                    )
+                    .reshape(c2, *tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
+                    .movedim((-3, -2, -1), (1, 2, 3))
+                )
+
+                # reshape as a (num of weights) x (num of rows) matrix
+                tmp_J_single_batch = tmp_J_single_batch.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
+                tmp_J[i, :, :] = tmp_J_single_batch
+
+            # transpose
+            tmp_J = tmp_J.movedim(-1, -2)
+        else:
+            # use the full input as the convolution weight
+            reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)
+
+            # convolve each column
+            Jt_tmptt_cols = (
+                F.conv2d(
+                    tmpt_cols.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, b, h2, w2),
+                    weight=reversed_inputs,
+                    bias=None,
+                    stride=self.stride,
+                    padding=self.dw_padding,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                )
+                .reshape(c2, *tmpt_cols.shape[4:], c1, kernel_h, kernel_w)
+                .movedim((-3, -2, -1), (1, 2, 3))
+            )
+
+            # reshape as a (num of weights) x (num of rows) matrix
+            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
+            # transpose
+            tmp_J = Jt_tmptt_cols.movedim(0, 1)
+
+        return tmp_J
+
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         if diag:
             return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
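The new _jacobian_wrt_weight_mult_left is the companion of _jacobian_wrt_weight_T_mult_right: it computes M @ J_w, the product of a matrix of rows with the Jacobian of the output w.r.t. the kernel weights, realised as a convolution that uses the (flipped) input as the filter. A hedged autograd reference such a routine can be checked against; the sizes and names below are illustrative, not the repo's test code:

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

x = torch.randn(1, 2, 5, 5)              # one sample, tiny feature maps
weight = torch.randn(3, 2, 3, 3)         # c2=3, c1=2, 3x3 kernel

def f(w):
    return F.conv2d(x, w).reshape(-1)    # flattened conv output

# dense Jacobian of the flattened output w.r.t. the flattened weight
Jw = jacobian(f, weight).reshape(-1, weight.numel())  # (c2*h2*w2, c2*c1*kh*kw)

M = torch.randn(4, Jw.shape[0])          # 4 arbitrary rows in output space
reference = M @ Jw                       # what M @ J_w should equal, up to reshaping
print(reference.shape)                   # torch.Size([4, 54])
```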
@@ -467,15 +557,15 @@ def _jacobian_wrt_weight_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
 
                 output_tmp_single_batch = (
                     F.conv2d(
-                        input_tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1),
+                        input_tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
                         weight=weigth_sq,
                         bias=None,
                         stride=self.stride,
                         padding=self.dw_padding,
                         dilation=self.dilation,
                         groups=self.groups,
                     )
-                    .reshape(b, *input_tmp_single_batch.shape[4:], c2, h2, w2)
+                    .reshape(c2, *input_tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
                     .movedim((-3, -2, -1), (1, 2, 3))
                 )
 
@@ -559,6 +649,12 @@ def forward(self, x: Tensor) -> Tensor:
     def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         return jac_in.reshape(jac_in.shape[0], *self.dims, *jac_in.shape[2:])
 
+    def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return tmp
+
+    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return None
+
 
 class Flatten(AbstractJacobian, nn.Module):
     def __init__(self):
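The Reshape additions fall straight out of the math: on the flattened vector a reshape's Jacobian is the identity, so the input sandwich returns tmp untouched, and the weight sandwich is None because the layer has no parameters. A quick check that the Jacobian really is the identity:

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(1, 2, 3, 4)
f = lambda t: t.reshape(1, -1)                   # a pure reshape "layer"
J = jacobian(f, x).reshape(x.numel(), x.numel())
print(torch.equal(J, torch.eye(x.numel())))      # True, so J.T @ T @ J == T
```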
@@ -576,6 +672,12 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         if jac_in.ndim == 9:  # 3d conv
             return jac_in.reshape(jac_in.shape[0], -1, *jac_in.shape[5:])
 
+    def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return tmp
+
+    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return None
+
 
 class AbstractActivationJacobian:
     def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
@@ -685,20 +787,24 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
 
     def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         # non parametric, so return empty
-        return []
+        return None
 
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
-        return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
+        return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
 
-    def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+    def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+        b, c1, h1, w1 = x.shape
+        c2, h2, w2 = val.shape[1:]
 
         new_tmp = torch.zeros_like(x)
-        new_tmp[self.idx] = tmp
-
-        return new_tmp
+        new_tmp = new_tmp.reshape(b * c1, h1 * w1)
+        idx = self.idx.reshape(b * c2, h2 * w2)
+        # one row index per pooled element, so each (batch, channel) row scatters into itself
+        arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
+        arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
+
+        new_tmp[arange_repeated, idx] = tmp.reshape(b * c2, h2 * w2)
 
-    def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
-        pass
+        return new_tmp.reshape(b, c1 * h1 * w1)
 
 
 class MaxPool3d(AbstractJacobian, nn.MaxPool3d):
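MaxPool2d's diagonal sandwich works because its Jacobian selects exactly one input per output: self.idx (from return_indices=True) holds the flat argmax position of every pooled value, so propagating a diagonal back through the pool is just a scatter of tmp onto those positions, which the paired arange_repeated/idx indexing performs row by row. The same trick in miniature, with toy sizes and standalone tensors:

```python
import torch

x = torch.randn(1, 1, 4, 4)
pool = torch.nn.MaxPool2d(2, return_indices=True)
val, idx = pool(x)                     # idx: flat argmax per pooled output

tmp = torch.randn(val.numel())         # diagonal entries in output space
new_tmp = torch.zeros(1, 4 * 4)        # one (batch*channel) row, h1*w1 columns
rows = torch.repeat_interleave(torch.arange(1), val.numel()).reshape(1, -1)
new_tmp[rows, idx.reshape(1, -1)] = tmp.reshape(1, -1)
print(new_tmp.reshape(4, 4))           # tmp scattered to the argmax positions
```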
@@ -788,7 +894,7 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
 
     def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         # non parametric, so return empty
-        return []
+        return None
 
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         if diag:
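Taken together, the *_sandwich methods give every layer the two primitives needed to push a (diagonal) Gauss-Newton factor backward through a network: transform J^T tmp J through the input, and emit a weight block where parameters exist, with None now uniformly marking the non-parametric layers. A sketch of the backward driver loop this interface suggests; the function and the diag convention are assumptions, not code from the repo:

```python
import torch

def sandwich_backward(layers, activations, tmp):
    """Propagate a diagonal sandwich factor from the network output to its input.

    layers[i] maps activations[i] -> activations[i + 1]; each layer is assumed
    to expose the _jacobian_wrt_*_sandwich interface from the diff above.
    """
    weight_blocks = []
    for i in reversed(range(len(layers))):
        layer, x, val = layers[i], activations[i], activations[i + 1]
        wb = layer._jacobian_wrt_weight_sandwich(x, val, tmp, diag=True)
        if wb is not None:              # None marks non-parametric layers
            weight_blocks.append(wb)
        tmp = layer._jacobian_wrt_input_sandwich(x, val, tmp, diag=True)
    return tmp, weight_blocks
```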