11# EfficientNet-V2 config
22import tensorflow as tf
3- import DECIMER .efficientnetv2 as efficientnetv2
3+ import DECIMER .efficientnetv2
4+ from DECIMER .efficientnetv2 import effnetv2_model
5+ from DECIMER .efficientnetv2 import effnetv2_configs
46
57BATCH_SIZE_DEBUG = 2
6- MODEL = "efficientnetv2-b3 " # @param
8+ MODEL = "efficientnetv2-m " # @param
79
810
911# Define encoder
1012def get_efficientnetv2_backbone (
11- model_name , include_top = False , input_shape = (299 , 299 , 3 ), pooling = None , weights = None
13+ model_name , include_top = False , input_shape = (512 , 512 , 3 ), pooling = None , weights = None
1214):
13- """Initiate and get the desired Efficient-Net V2 backbone as encoder
14-
15- Args:
16- model_name (str): Name of the Efficient-Net V2 model
17- include_top (bool, optional): Defaults to False.
18- input_shape (tuple, optional): Image shape. Defaults to (299, 299, 3).
19- pooling (int, optional): Max pooling values. Defaults to None.
20- weights ( optional): Pretrained weights. Defaults to None.
21-
22- Raises:
23- NotImplementedError: At this time we only want to use the raw
24-
25- Returns:
26- Efficient Net V2 backbone
27- """
2815 # Catch unsupported arguments
2916 if pooling or weights or include_top :
3017 raise NotImplementedError (
3118 "\n ...At this time we only want to use the raw "
3219 "(no pretraining), headless, features with no pooling ...\n "
3320 )
34- backbone = efficientnetv2 . effnetv2_model .EffNetV2Model (model_name = model_name )
21+ backbone = effnetv2_model .EffNetV2Model (model_name = model_name )
3522 backbone (
3623 tf .ones ((BATCH_SIZE_DEBUG , * input_shape )), training = False , features_only = True
3724 )
3825 return backbone
3926
4027
4128class Encoder (tf .keras .Model ):
42- """Encoder class
43-
44- Args:
45- tf (_type_): tensorflow model module
46- """
47-
4829 def __init__ (
4930 self ,
5031 image_embedding_dim ,
@@ -56,6 +37,7 @@ def __init__(
5637 pretrained_weights = None ,
5738 scale_factor = 0 ,
5839 ):
40+
5941 super (Encoder , self ).__init__ ()
6042
6143 self .image_embedding_dim = image_embedding_dim
0 commit comments