Skip to content

Commit b217f5f

Browse files
authored
Merge pull request Kohulan#67 from Kohulan/development
fix: use checkpoints properly for predictions Kohulan#65 and Kohulan#66
2 parents 7be1e9f + 7057fab commit b217f5f

File tree

10 files changed

+386
-549
lines changed

10 files changed

+386
-549
lines changed

DECIMER/Efficient_Net_encoder.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,31 @@
11
# EfficientNet-V2 config
22
import tensorflow as tf
3-
import DECIMER.efficientnetv2 as efficientnetv2
3+
import DECIMER.efficientnetv2
4+
from DECIMER.efficientnetv2 import effnetv2_model
5+
from DECIMER.efficientnetv2 import effnetv2_configs
46

57
BATCH_SIZE_DEBUG = 2
6-
MODEL = "efficientnetv2-b3" # @param
8+
MODEL = "efficientnetv2-m" # @param
79

810

911
# Define encoder
1012
def get_efficientnetv2_backbone(
11-
model_name, include_top=False, input_shape=(299, 299, 3), pooling=None, weights=None
13+
model_name, include_top=False, input_shape=(512, 512, 3), pooling=None, weights=None
1214
):
13-
"""Initiate and get the desired Efficient-Net V2 backbone as encoder
14-
15-
Args:
16-
model_name (str): Name of the Efficient-Net V2 model
17-
include_top (bool, optional): Defaults to False.
18-
input_shape (tuple, optional): Image shape. Defaults to (299, 299, 3).
19-
pooling (int, optional): Max pooling values. Defaults to None.
20-
weights ( optional): Pretrained weights. Defaults to None.
21-
22-
Raises:
23-
NotImplementedError: At this time we only want to use the raw
24-
25-
Returns:
26-
Efficient Net V2 backbone
27-
"""
2815
# Catch unsupported arguments
2916
if pooling or weights or include_top:
3017
raise NotImplementedError(
3118
"\n...At this time we only want to use the raw "
3219
"(no pretraining), headless, features with no pooling ...\n"
3320
)
34-
backbone = efficientnetv2.effnetv2_model.EffNetV2Model(model_name=model_name)
21+
backbone = effnetv2_model.EffNetV2Model(model_name=model_name)
3522
backbone(
3623
tf.ones((BATCH_SIZE_DEBUG, *input_shape)), training=False, features_only=True
3724
)
3825
return backbone
3926

4027

4128
class Encoder(tf.keras.Model):
42-
"""Encoder class
43-
44-
Args:
45-
tf (_type_): tensorflow model module
46-
"""
47-
4829
def __init__(
4930
self,
5031
image_embedding_dim,
@@ -56,6 +37,7 @@ def __init__(
5637
pretrained_weights=None,
5738
scale_factor=0,
5839
):
40+
5941
super(Encoder, self).__init__()
6042

6143
self.image_embedding_dim = image_embedding_dim

DECIMER/Predictor_EfficientNet2.py

Lines changed: 0 additions & 160 deletions
This file was deleted.

0 commit comments

Comments
 (0)