Skip to content

Commit b3a73f0

Browse files
committed
optimize base class
1 parent fea645d commit b3a73f0

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

change_detection_pytorch/base/model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def initialize(self):
1010
if self.classification_head is not None:
1111
init.initialize_head(self.classification_head)
1212

13-
def forward(self, x1, x2):
14-
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
13+
def base_forward(self, x1, x2):
14+
"""Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads"""
1515
if self.siam_encoder:
1616
features = self.encoder(x1), self.encoder(x2)
1717
else:
@@ -27,11 +27,15 @@ def forward(self, x1, x2):
2727

2828
return masks
2929

30-
def predict(self, x):
31-
"""Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
30+
def forward(self, x1, x2):
31+
"""Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads"""
32+
return self.base_forward(x1, x2)
33+
34+
def predict(self, x1, x2):
35+
"""Inference method. Switch model to `eval` mode, call `.forward(x1, x2)` with `torch.no_grad()`
3236
3337
Args:
34-
x: 4D torch tensor with shape (batch_size, channels, height, width)
38+
x1, x2: 4D torch tensor with shape (batch_size, channels, height, width)
3539
3640
Return:
3741
prediction: 4D torch tensor with shape (batch_size, classes, height, width)
@@ -41,6 +45,6 @@ def predict(self, x):
4145
self.eval()
4246

4347
with torch.no_grad():
44-
x = self.forward(x)
48+
x = self.forward(x1, x2)
4549

4650
return x

0 commit comments

Comments
 (0)