
Commit cc35fd6

Author: Frederik Rahbaek Warburg
Parent: c59cefa

    all functions implemented, but still not fully tested

File tree: 1 file changed (+122 / -16 lines)


stochman/nnj.py

Lines changed: 122 additions & 16 deletions
@@ -1,3 +1,4 @@
+from builtins import breakpoint
 from math import prod
 from typing import Optional, Tuple, Union
 
@@ -126,10 +127,26 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
 
     def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         # non parametric, so return empty
-        return []
+        return None
 
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
-        raise NotImplementedError
+
+        b, c1, h1, w1 = x.shape
+        c2, h2, w2 = val.shape[1:]
+
+        weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)
+
+        tmp = F.conv2d(
+            tmp.reshape(-1, c2, h2, w2),
+            weight=weight,
+            bias=None,
+            stride=int(self.scale_factor),
+            padding=0,
+            dilation=1,
+            groups=1,
+        )
+
+        return tmp.reshape(b, c1 * h1 * w1)
 
 
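The new `_jacobian_wrt_input_sandwich` for the upsampling layer evaluates the sandwich product Jᵀ diag(tmp) J with a single strided convolution: an all-ones scale_factor x scale_factor kernel sums, for every input pixel, the tmp entries of all output pixels that pixel was copied to. A minimal standalone check of that identity (the shapes, the nearest-neighbour setting, and c1 == c2 are illustrative assumptions, not stochman API):

import torch
import torch.nn.functional as F

b, c, h, w, s = 2, 3, 4, 5, 2
tmp = torch.rand(b, c, h * s, w * s)  # diagonal factor in output space

# strided all-ones convolution, as in the diff; with weight shape (c, c, s, s)
# and groups=1 the kernel also sums across channels
weight = torch.ones(c, c, s, s)
out_conv = F.conv2d(tmp, weight=weight, bias=None, stride=s)

# reference: sum every s x s block, then sum over channels
blocks = tmp.reshape(b, c, h, s, w, s).sum(dim=(-3, -1))
out_ref = blocks.sum(dim=1, keepdim=True).expand(-1, c, -1, -1)

print(torch.allclose(out_conv, out_ref, atol=1e-5))  # True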

@@ -176,7 +193,7 @@ def compute_reversed_padding(padding, kernel_size=1):
     return kernel_size - 1 - padding
 
 
-class Conv2d(nn.Conv2d):
+class Conv2d(AbstractJacobian, nn.Conv2d):
     def __init__(
         self,
         in_channels,
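The swapped base-class order matters because of Python's method resolution order: with AbstractJacobian listed first, its shared Jacobian helpers are found before anything nn.Conv2d defines under the same name. A toy illustration of that lookup order (hypothetical classes, not the stochman ones):

class Mixin:
    def describe(self):
        return "mixin wins"

class Base:
    def describe(self):
        return "base wins"

class Child(Mixin, Base):  # mirrors Conv2d(AbstractJacobian, nn.Conv2d)
    pass

print(Child().describe())  # "mixin wins": Mixin precedes Base in the MRO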
@@ -348,7 +365,7 @@ def _jacobian_wrt_weight_T_mult_right(
 
         if use_less_memory:
             # define moving sum for Jt_tmp
-            Jt_tmp = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_cols)
+            Jt_tmp = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_cols, device=x.device)
             for i in range(b):
                 # set the weight to the convolution
                 input_single_batch = x[i : i + 1, :, :, :]
@@ -376,7 +393,6 @@ def _jacobian_wrt_weight_T_mult_right(
                 Jt_tmp[i, :, :] = Jt_tmp_single_batch
 
         else:
-            # TODO: only works for batch size 1?
             reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)
 
             # convolve each column
@@ -399,6 +415,80 @@ def _jacobian_wrt_weight_T_mult_right(
 
         return Jt_tmp
 
+    def _jacobian_wrt_weight_mult_left(
+        self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
+    ) -> Tensor:
+        b, c1, h1, w1 = x.shape
+        c2, h2, w2 = val.shape[1:]
+        kernel_h, kernel_w = self.kernel_size
+        num_of_rows = tmp.shape[-2]
+
+        # expand rows as cubes [(output channel) x (output height) x (output width)]
+        tmp_rows = tmp.movedim(-1, -2).reshape(b, c2, h2, w2, num_of_rows)
+        # see rows as columns of the transposed matrix
+        tmpt_cols = tmp_rows
+        # transpose the images in (output height) x (output width)
+        tmpt_cols = torch.flip(tmpt_cols, [-3, -2])
+        # switch batch size and output channel
+        tmpt_cols = tmpt_cols.movedim(0, 1)
+
+        if use_less_memory:
+            tmp_J = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_rows, device=x.device)
+            for i in range(b):
+                # set the weight to the convolution
+                input_single_batch = x[i : i + 1, :, :, :]
+                reversed_input_single_batch = torch.flip(input_single_batch, [-2, -1]).movedim(0, 1)
+
+                tmp_single_batch = tmpt_cols[:, i : i + 1, :, :, :]
+
+                # convolve each column
+                tmp_J_single_batch = (
+                    F.conv2d(
+                        tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
+                        weight=reversed_input_single_batch,
+                        bias=None,
+                        stride=self.stride,
+                        padding=self.dw_padding,
+                        dilation=self.dilation,
+                        groups=self.groups,
+                    )
+                    .reshape(c2, *tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
+                    .movedim((-3, -2, -1), (1, 2, 3))
+                )
+
+                # reshape as a (num of weights) x (num of rows) matrix
+                tmp_J_single_batch = tmp_J_single_batch.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
+                tmp_J[i, :, :] = tmp_J_single_batch
+
+            # transpose
+            tmp_J = tmp_J.movedim(-1, -2)
+        else:
+            # set the weight to the convolution
+            reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)
+
+            # convolve each column
+            Jt_tmptt_cols = (
+                F.conv2d(
+                    tmpt_cols.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, b, h2, w2),
+                    weight=reversed_inputs,
+                    bias=None,
+                    stride=self.stride,
+                    padding=self.dw_padding,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                )
+                .reshape(c2, *tmpt_cols.shape[4:], c1, kernel_h, kernel_w)
+                .movedim((-3, -2, -1), (1, 2, 3))
+            )
+
+            # reshape as a (num of weights) x (num of rows) matrix
+            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
+            # transpose
+            tmp_J = Jt_tmptt_cols.movedim(0, 1)
+
+        return tmp_J
+
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         if diag:
             return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
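The new `_jacobian_wrt_weight_mult_left` multiplies a matrix of rows `tmp` onto the weight Jacobian J_W of the convolution by re-expressing each row product as a convolution of the flipped input with that row, one batch element at a time in the memory-saving branch. For reference, each row of tmp @ J_W is an ordinary vector-Jacobian product, so autograd gives an independent (slower) way to compute the same quantity; the sketch below uses assumed toy shapes and is not the stochman code:

import torch
import torch.nn.functional as F

b, c1, c2, h, w, k = 1, 2, 3, 5, 5, 3
x = torch.randn(b, c1, h, w)
W = torch.randn(c2, c1, k, k, requires_grad=True)

y = F.conv2d(x, W).reshape(-1)   # flattened convolution output
tmp = torch.randn(4, y.numel())  # four arbitrary rows to multiply onto J_W

rows = []
for v in tmp:
    # row i of (tmp @ J_W) is the VJP of y with row i as the cotangent
    (g,) = torch.autograd.grad(y, W, grad_outputs=v, retain_graph=True)
    rows.append(g.reshape(-1))
tmp_J = torch.stack(rows)
print(tmp_J.shape)  # torch.Size([4, 54]) == (rows, c2 * c1 * k * k)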
@@ -467,15 +557,15 @@ def _jacobian_wrt_weight_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
 
                 output_tmp_single_batch = (
                     F.conv2d(
-                        input_tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1),
+                        input_tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
                         weight=weigth_sq,
                         bias=None,
                         stride=self.stride,
                         padding=self.dw_padding,
                         dilation=self.dilation,
                         groups=self.groups,
                     )
-                    .reshape(b, *input_tmp_single_batch.shape[4:], c2, h2, w2)
+                    .reshape(c2, *input_tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
                     .movedim((-3, -2, -1), (1, 2, 3))
                 )
 
@@ -559,6 +649,12 @@ def forward(self, x: Tensor) -> Tensor:
     def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         return jac_in.reshape(jac_in.shape[0], *self.dims, *jac_in.shape[2:])
 
+    def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return tmp
+
+    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return None
+
 
 class Flatten(AbstractJacobian, nn.Module):
     def __init__(self):
@@ -576,6 +672,12 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         if jac_in.ndim == 9:  # 3d conv
             return jac_in.reshape(jac_in.shape[0], -1, *jac_in.shape[5:])
 
+    def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return tmp
+
+    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+        return None
+
 
 class AbstractActivationJacobian:
     def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
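For Reshape and Flatten the added sandwich methods simply pass `tmp` through: reshaping is a bijective re-indexing of the flat vector, so its Jacobian is the identity and Jᵀ M J = M (and there are no weights, hence None). A quick autograd confirmation under assumed toy shapes:

import torch

x = torch.randn(2, 3, 4)
# Jacobian of a flatten over the batch, shape (2, 12, 2, 3, 4)
J = torch.autograd.functional.jacobian(lambda t: t.reshape(2, -1), x)
J0 = J[0, :, 0].reshape(12, 12)  # per-sample block
print(torch.allclose(J0, torch.eye(12)))  # True: the identity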
685787

686788
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
687789
# non parametric, so return empty
688-
return []
790+
return None
689791

690792
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
691-
return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
793+
return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
692794

693-
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
795+
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
796+
b, c1, h1, w1 = x.shape
797+
c2, h2, w2 = val.shape[1:]
694798

695799
new_tmp = torch.zeros_like(x)
696-
new_tmp[self.idx] = tmp
697-
698-
return new_tmp
800+
new_tmp = new_tmp.reshape(b * c1, h1 * w1)
801+
idx = self.idx.reshape(b * c2, h2 * w2)
802+
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
803+
arange_repeated = arange_repeated.reshape(b*c2, h2*w2)
804+
805+
new_tmp[arange_repeated, idx] = tmp.reshape(b*c2, h2*w2)
699806

700-
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
701-
pass
807+
return new_tmp.reshape(b, c1 * h1 * w1)
702808

703809

704810
class MaxPool3d(AbstractJacobian, nn.MaxPool3d):
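The rewritten `_jacobian_wrt_input_diag_sandwich` for max pooling scatters each pooled `tmp` value back to the argmax position stored in `self.idx`, building the row index with `repeat_interleave` so the whole write is one vectorised fancy-indexing assignment. A standalone sketch of the same scatter (the `return_indices` setup and shapes are assumptions for illustration):

import torch
import torch.nn.functional as F

b, c, h, w, k = 2, 3, 4, 4, 2
x = torch.randn(b, c, h, w)
val, idx = F.max_pool2d(x, kernel_size=k, return_indices=True)
h2, w2 = val.shape[-2:]
tmp = torch.rand(b, c, h2, w2)  # diagonal factor in pooled (output) space

new_tmp = torch.zeros(b * c, h * w)
rows = torch.repeat_interleave(torch.arange(b * c), h2 * w2).reshape(b * c, h2 * w2)
new_tmp[rows, idx.reshape(b * c, h2 * w2)] = tmp.reshape(b * c, h2 * w2)

result = new_tmp.reshape(b, c * h * w)  # diagonal factor in input space
print(result.shape)  # torch.Size([2, 48])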
@@ -788,7 +894,7 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
 
     def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         # non parametric, so return empty
-        return []
+        return None
 
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         if diag:
