@@ -77,8 +77,10 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
     def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         return F.linear(jac_in.movedim(1, -1), self.weight.T, bias=None).movedim(-1, 1)
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
-
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
+
         b, c = x.shape
         diag_elements = torch.diagonal(tmp, dim1=1, dim2=2)
         feat_k2 = (x ** 2).unsqueeze(1)
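A hedged sketch (not part of the commit) of what the method above computes for `nn.Linear`: since `out = x @ W.T`, the Jacobian w.r.t. `W` is block-diagonal in the output index, so the diagonal of `J_W^T @ tmp @ J_W` factors into `diag(tmp)` and `x ** 2`. Shapes below are illustrative only:

```python
import torch

b, c_in, c_out = 2, 3, 4
x = torch.randn(b, c_in)                 # layer input
tmp = torch.randn(b, c_out, c_out)       # output-space matrix being "sandwiched"

diag_elements = torch.diagonal(tmp, dim1=1, dim2=2)   # (b, c_out)
feat_k2 = (x ** 2).unsqueeze(1)                       # (b, 1, c_in)
# diagonal of J_W^T @ tmp @ J_W, flattened over the (c_out, c_in) weight entries
sandwich_diag = (diag_elements.unsqueeze(-1) * feat_k2).reshape(b, c_out * c_in)
```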
@@ -125,7 +127,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
             .movedim(dims2, dims1)
         )
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None
@@ -137,17 +141,16 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
         weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)
 
         tmp = F.conv2d(
-            tmp.reshape(-1, c2, h2, w2),
-            weight=weight,
-            bias=None,
-            stride=int(self.scale_factor),
-            padding=0,
-            dilation=1,
-            groups=1,
-        )
-
-        return tmp.reshape(b, c1 * h1 * w1)
+            tmp.reshape(-1, c2, h2, w2),
+            weight=weight,
+            bias=None,
+            stride=int(self.scale_factor),
+            padding=0,
+            dilation=1,
+            groups=1,
+        )
 
+        return tmp.reshape(b, c1 * h1 * w1)
 
 
 class Conv1d(AbstractJacobian, nn.Conv1d):
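The upsampling sandwich above relies on a block-sum identity: each input pixel of `nn.Upsample(scale_factor=s)` fans out to an `s x s` output block, so multiplying by `J^T` on both sides sums `tmp` over those blocks, and a stride-`s` convolution with an all-ones kernel computes exactly that. A minimal single-channel sketch with illustrative shapes (not from the commit):

```python
import torch
import torch.nn.functional as F

s, h1, w1 = 2, 4, 4
tmp = torch.randn(1, 1, h1 * s, w1 * s)   # per-pixel values on the upsampled grid
ones = torch.ones(1, 1, s, s)             # all-ones kernel, as in the diff

block_sum = F.conv2d(tmp, ones, bias=None, stride=s)       # (1, 1, h1, w1)
manual = tmp.reshape(1, 1, h1, s, w1, s).sum(dim=(3, 5))   # explicit block sum
assert torch.allclose(block_sum, manual)
```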
@@ -416,36 +419,36 @@ def _jacobian_wrt_weight_T_mult_right(
         return Jt_tmp
 
     def _jacobian_wrt_weight_mult_left(
-        self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
-    ) -> Tensor:
+        self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
+    ) -> Tensor:
         b, c1, h1, w1 = x.shape
         c2, h2, w2 = val.shape[1:]
         kernel_h, kernel_w = self.kernel_size
         num_of_rows = tmp.shape[-2]
 
         # expand rows as cubes [(output channel)x(output height)x(output width)]
-        tmp_rows = tmp.movedim(-1,-2).reshape(b, c2, h2, w2, num_of_rows)
+        tmp_rows = tmp.movedim(-1, -2).reshape(b, c2, h2, w2, num_of_rows)
         # see rows as columns of the transposed matrix
         tmpt_cols = tmp_rows
         # transpose the images in (output height)x(output width)
         tmpt_cols = torch.flip(tmpt_cols, [-3, -2])
         # switch batch size and output channel
-        tmpt_cols = tmpt_cols.movedim(0,1)
+        tmpt_cols = tmpt_cols.movedim(0, 1)
 
         if use_less_memory:
 
-            tmp_J = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_rows, device=x.device)
+            tmp_J = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_rows, device=x.device)
             for i in range(b):
                 # set the weight to the convolution
-                input_single_batch = x[i : i + 1,:,:, :]
-                reversed_input_single_batch = torch.flip(input_single_batch, [-2,-1]).movedim(0,1)
-
-                tmp_single_batch = tmpt_cols[:,i : i + 1,:,:, :]
+                input_single_batch = x[i : i + 1, :, :, :]
+                reversed_input_single_batch = torch.flip(input_single_batch, [-2, -1]).movedim(0, 1)
+
+                tmp_single_batch = tmpt_cols[:, i : i + 1, :, :, :]
 
                 # convolve each column
                 tmp_J_single_batch = (
                     F.conv2d(
-                        tmpt_cols.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
+                        tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
                         weight=reversed_input_single_batch,
                         bias=None,
                         stride=self.stride,
@@ -458,14 +461,14 @@ def _jacobian_wrt_weight_mult_left(
                 )
 
                 # reshape as a (num of weights)x(num of column) matrix
-                tmp_J_single_batch = tmp_J_single_batch.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
+                tmp_J_single_batch = tmp_J_single_batch.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
                 tmp_J[i, :, :] = tmp_J_single_batch
 
             # transpose
-            tmp_J = tmp_J.movedim(-1,-2)
-        else:
+            tmp_J = tmp_J.movedim(-1, -2)
+        else:
             # set the weight to the convolution
-            reversed_inputs = torch.flip(x, [-2,-1]).movedim(0,1)
+            reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)
 
             # convolve each column
             Jt_tmptt_cols = (
@@ -483,9 +486,9 @@ def _jacobian_wrt_weight_mult_left(
             )
 
             # reshape as a (num of input)x(num of output) matrix, one for each batch size
-            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2 * c1 * kernel_h * kernel_w,num_of_rows)
+            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
             # transpose
-            tmp_J = Jt_tmptt_cols.movedim(0,1)
+            tmp_J = Jt_tmptt_cols.movedim(0, 1)
 
         return tmp
 
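The per-batch loop in this method (note the real fix it carries: convolving `tmp_single_batch` rather than the whole `tmpt_cols`) leans on the classic identity that the gradient of a convolution w.r.t. its weight is itself a cross-correlation of the input with the output-space factor, obtained by swapping batch and channel dimensions, which is what the `movedim(0, 1)` calls do. A hedged sketch of that base identity with illustrative shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)                      # input
W = torch.randn(5, 3, 3, 3, requires_grad=True)  # conv weight
g = torch.randn(1, 5, 6, 6)                      # one "column" in output space

(F.conv2d(x, W) * g).sum().backward()            # autograd reference: d<g, conv(x, W)>/dW

# same quantity as a conv: batch and channel dims swapped, g acting as the kernel
manual = F.conv2d(x.movedim(0, 1), g.movedim(0, 1), bias=None)  # (3, 5, 3, 3)
assert torch.allclose(W.grad, manual.movedim(0, 1), atol=1e-4)
```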
@@ -495,7 +498,9 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
         else:
             return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         if diag:
             return self._jacobian_wrt_weight_diag_sandwich(x, val, tmp)
         else:
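The `diag` flag in these dispatchers trades exactness for memory: given only the diagonal `h` of the output-space matrix, `diag(J^T @ diag(h) @ J)` reduces to `(J ** 2).T @ h`, so nothing quadratic in the layer width is ever materialized. A small sketch with an illustrative dense Jacobian `J` (the layers above never form `J` explicitly):

```python
import torch

J = torch.randn(4, 6)        # illustrative Jacobian (out x in)
h = torch.randn(4)           # diagonal of the output-space matrix

diag_out = (J ** 2).T @ h    # diag(J^T @ diag(h) @ J) without forming it
full = J.T @ torch.diag(h) @ J
assert torch.allclose(diag_out, torch.diagonal(full), atol=1e-6)
```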
@@ -652,7 +657,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         return tmp
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         return None
 
@@ -675,7 +682,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         return tmp
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         return None
 
@@ -785,7 +794,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
         jac_in = jac_in[arange_repeated, idx, :, :, :].reshape(*val.shape, *jac_in_orig_shape[4:])
         return jac_in
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None
@@ -800,9 +811,9 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor)
         new_tmp = new_tmp.reshape(b * c1, h1 * w1)
         idx = self.idx.reshape(b * c2, h2 * w2)
         arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
-        arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
-
-        new_tmp[arange_repeated, idx] = tmp.reshape(b * c2, h2 * w2)
+        arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
+
+        new_tmp[arange_repeated, idx] = tmp.reshape(b * c2, h2 * w2)
 
         return new_tmp.reshape(b, c1 * h1 * w1)
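The fancy-indexing scatter above routes output-space diagonal entries back to the input positions selected by max pooling, using the flat per-channel indices that `return_indices=True` provides. A hedged, self-contained sketch of the same pattern (pool and tensor sizes are illustrative):

```python
import torch

pool = torch.nn.MaxPool2d(2, return_indices=True)
x = torch.randn(2, 3, 4, 4)
val, idx = pool(x)                    # idx: flat per-channel input index of each max

b, c, h2, w2 = val.shape
h1 = w1 = 4
tmp = torch.randn(b, c, h2 * w2)      # output-space diagonal

new_tmp = torch.zeros(b * c, h1 * w1)
rows = torch.repeat_interleave(torch.arange(b * c), h2 * w2).reshape(b * c, h2 * w2)
new_tmp[rows, idx.reshape(b * c, h2 * w2)] = tmp.reshape(b * c, h2 * w2)
new_tmp = new_tmp.reshape(b, c * h1 * w1)   # back to an input-space diagonal
```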
@@ -892,7 +903,9 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
         jac = 1.0 - val ** 2
         return jac
 
-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None
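The `_jacobian` in this hunk is the standard elementwise identity d tanh(x)/dx = 1 - tanh(x)^2, reusing the already-computed activation `val`; a quick check against autograd:

```python
import torch

x = torch.randn(5, requires_grad=True)
val = torch.tanh(x)
val.sum().backward()
assert torch.allclose(x.grad, 1.0 - val.detach() ** 2)
```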