@@ -10,8 +10,8 @@ def initialize(self):
10
10
if self .classification_head is not None :
11
11
init .initialize_head (self .classification_head )
12
12
13
- def forward (self , x1 , x2 ):
14
- """Sequentially pass `x ` trough model`s encoder, decoder and heads"""
13
+ def base_forward (self , x1 , x2 ):
14
+ """Sequentially pass `x1` `x2 ` trough model`s encoder, decoder and heads"""
15
15
if self .siam_encoder :
16
16
features = self .encoder (x1 ), self .encoder (x2 )
17
17
else :
@@ -27,11 +27,15 @@ def forward(self, x1, x2):
27
27
28
28
return masks
29
29
30
- def predict (self , x ):
31
- """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
30
+ def forward (self , x1 , x2 ):
31
+ """Sequentially pass `x1` `x2` trough model`s encoder, decoder and heads"""
32
+ return self .base_forward (x1 , x2 )
33
+
34
+ def predict (self , x1 , x2 ):
35
+ """Inference method. Switch model to `eval` mode, call `.forward(x1, x2)` with `torch.no_grad()`
32
36
33
37
Args:
34
- x : 4D torch tensor with shape (batch_size, channels, height, width)
38
+ x1, x2 : 4D torch tensor with shape (batch_size, channels, height, width)
35
39
36
40
Return:
37
41
prediction: 4D torch tensor with shape (batch_size, classes, height, width)
@@ -41,6 +45,6 @@ def predict(self, x):
41
45
self .eval ()
42
46
43
47
with torch .no_grad ():
44
- x = self .forward (x )
48
+ x = self .forward (x1 , x2 )
45
49
46
50
return x
0 commit comments