The implementation of Residual() in repvit.py is incorrect:
if isinstance(self.m, Conv2d_BN):
    m = self.m.fuse()
    assert(m.groups == m.in_channels)
    identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
    identity = torch.nn.functional.pad(identity, [1,1,1,1])
    m.weight += identity.to(m.weight.device)
    return m
This is how a 1x1 conv is converted to a 3x3 conv. For an identity connection, the implementation should be more like:
# indexing [i, i] assumes m.weight.shape[1] == m.weight.shape[0], i.e. groups == 1
identity = torch.zeros_like(m.weight)
for i in range(m.weight.shape[0]):
    identity[i, i, 1, 1] = 1.0  # center of 3x3 kernel
m.weight += identity
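
To make the difference concrete, here is a minimal framework-free sketch that compares the two kernels. It assumes the PyTorch weight layout [out_channels, in_channels // groups, kH, kW] and uses nested lists in place of tensors. The helpers `identity_kernel`, `padded_ones_kernel` and `conv2d` are hypothetical names for illustration and are not part of repvit.py. It suggests that the two kernels coincide when groups == in_channels (the case the assert enforces) and differ when groups == 1, where the padded ones kernel sums all input channels.

```python
def identity_kernel(channels, groups, k=3):
    """Kernel [channels][channels // groups][k][k] that acts as an identity.

    Assumes in_channels == out_channels == channels and an odd kernel size.
    """
    assert channels % groups == 0 and k % 2 == 1
    per_group = channels // groups
    c = k // 2
    w = [[[[0.0] * k for _ in range(k)] for _ in range(per_group)]
         for _ in range(channels)]
    for o in range(channels):
        # Output channel o reads input channel o; the index is local to its group.
        w[o][o % per_group][c][c] = 1.0
    return w


def padded_ones_kernel(channels, per_group, k=3):
    """A 1x1 kernel of ones zero-padded to k x k, as in the quoted fuse code."""
    c = k // 2
    return [[[[1.0 if (y == c and x == c) else 0.0 for x in range(k)]
              for y in range(k)] for _ in range(per_group)]
            for _ in range(channels)]


def conv2d(x, w, groups, pad):
    """Naive grouped cross-correlation with stride 1 on x[C][H][W]."""
    h, wd = len(x[0]), len(x[0][0])
    c_out, per_group, k = len(w), len(w[0]), len(w[0][0])
    out_per_group = c_out // groups
    out = [[[0.0] * wd for _ in range(h)] for _ in range(c_out)]
    for o in range(c_out):
        base = (o // out_per_group) * per_group  # first input channel of o's group
        for i in range(per_group):
            for y in range(h):
                for col in range(wd):
                    acc = 0.0
                    for dy in range(k):
                        for dx in range(k):
                            yy, xx = y + dy - pad, col + dx - pad
                            if 0 <= yy < h and 0 <= xx < wd:
                                acc += w[o][i][dy][dx] * x[base + i][yy][xx]
                    out[o][y][col] += acc
    return out


# Small 4-channel, 4x4 input with distinct values per channel.
x = [[[float(c * 100 + y * 10 + col) for col in range(4)]
      for y in range(4)] for c in range(4)]

depthwise_same = padded_ones_kernel(4, 1) == identity_kernel(4, groups=4)
dense_same = padded_ones_kernel(4, 4) == identity_kernel(4, groups=1)
dense_ones_out = conv2d(x, padded_ones_kernel(4, 4), groups=1, pad=1)
dense_identity_out = conv2d(x, identity_kernel(4, groups=1), groups=1, pad=1)
```

In this sketch `depthwise_same` is true and `dense_same` is false, and only `dense_identity_out` reproduces `x`. So the per-channel indexing matters as soon as the conv is not depthwise.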