@@ -1,10 +1,13 @@
 import abc
+import pathlib
 import typing
+import urllib.parse

 from keras import KerasTensor
 from keras import backend
 from keras import layers
 from keras import models
+from keras import utils
 from keras.src.applications import imagenet_utils


@@ -14,53 +17,79 @@ def __init__(
         inputs,
         outputs,
         features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
-        feature_keys: typing.Optional[typing.List[str]] = None,
         **kwargs,
     ):
-        self.feature_extractor = kwargs.pop("feature_extractor", False)
-        self.feature_keys = feature_keys
-        if self.feature_extractor:
-            if features is None:
-                raise ValueError(
-                    "`features` must be set when "
-                    f"`feature_extractor=True`. Received features={features}"
-                )
-            if self.feature_keys is None:
-                self.feature_keys = list(features.keys())
-            filtered_features = {}
-            for k in self.feature_keys:
-                if k not in features:
-                    raise KeyError(
-                        f"'{k}' is not a key of `features`. Available keys "
-                        f"are: {list(features.keys())}"
-                    )
-                filtered_features[k] = features[k]
-            # add outputs
-            if backend.is_keras_tensor(outputs):
-                filtered_features["TOP"] = outputs
-            super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
-        else:
+        if not hasattr(self, "_feature_extractor"):
             del features
             super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+        else:
+            if not hasattr(self, "_feature_keys"):
+                raise AttributeError(
+                    "`self._feature_keys` must be set when initializing "
+                    "BaseModel"
+                )
+            if self._feature_extractor:
+                if features is None:
+                    raise ValueError(
+                        "`features` must be set when `feature_extractor=True`. "
+                        f"Received features={features}"
+                    )
+                if self._feature_keys is None:
+                    self._feature_keys = list(features.keys())
+                filtered_features = {}
+                for k in self._feature_keys:
+                    if k not in features:
+                        raise KeyError(
+                            f"'{k}' is not a key of `features`. Available keys "
+                            f"are: {list(features.keys())}"
+                        )
+                    filtered_features[k] = features[k]
+                # Add outputs
+                if backend.is_keras_tensor(outputs):
+                    filtered_features["TOP"] = outputs
+                super().__init__(
+                    inputs=inputs, outputs=filtered_features, **kwargs
+                )
+            else:
+                del features
+                super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+        if hasattr(self, "_weights_url"):
+            self.load_pretrained_weights(self._weights_url)

-    def parse_kwargs(
+    def set_properties(
         self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
     ):
-        result = {
-            "input_tensor": kwargs.pop("input_tensor", None),
-            "input_shape": kwargs.pop("input_shape", None),
-            "include_preprocessing": kwargs.pop("include_preprocessing", True),
-            "include_top": kwargs.pop("include_top", True),
-            "pooling": kwargs.pop("pooling", None),
-            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
-            "classes": kwargs.pop("classes", 1000),
-            "classifier_activation": kwargs.pop(
-                "classifier_activation", "softmax"
-            ),
-            "weights": kwargs.pop("weights", "imagenet"),
-            "default_size": kwargs.pop("default_size", default_size),
-        }
-        return result
+        """Must be called in the initialization of the class.
+
+        This method will add the following common properties to the model object:
+        - input_shape
+        - include_preprocessing
+        - include_top
+        - pooling
+        - dropout_rate
+        - classes
+        - classifier_activation
+        - _weights
+        - weights_url
+        - default_size
+        """
+        self._input_shape = kwargs.pop("input_shape", None)
+        self._include_preprocessing = kwargs.pop("include_preprocessing", True)
+        self._include_top = kwargs.pop("include_top", True)
+        self._pooling = kwargs.pop("pooling", None)
+        self._dropout_rate = kwargs.pop("dropout_rate", 0.0)
+        self._classes = kwargs.pop("classes", 1000)
+        self._classifier_activation = kwargs.pop(
+            "classifier_activation", "softmax"
+        )
+        self._weights = kwargs.pop("weights", None)
+        self._weights_url = kwargs.pop("weights_url", None)
+        self._default_size = kwargs.pop("default_size", default_size)
+        # feature extractor
+        self._feature_extractor = kwargs.pop("feature_extractor", False)
+        self._feature_keys = kwargs.pop("feature_keys", None)
+        print("self._feature_keys", self._feature_keys)

     def determine_input_tensor(
         self,
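For orientation, here is a minimal sketch (mine, not part of this diff) of how a concrete subclass is expected to use the refactored flow: set_properties() fills in the private attributes before BaseModel.__init__ runs, which is why __init__ can branch on hasattr(self, "_feature_extractor") and hasattr(self, "_weights_url"). The subclass name and the stem layer are placeholders, and a real subclass would also need to implement the base class's abstract methods.

# Hypothetical subclass, for illustration only; assumes the module imports above.
class ToyModel(BaseModel):
    def __init__(self, **kwargs):
        # Populate _include_top, _pooling, _feature_extractor, _weights_url, ...
        # before BaseModel.__init__ inspects them via hasattr().
        self.set_properties(kwargs, default_size=224)
        inputs = layers.Input(shape=self._input_shape or (224, 224, 3))
        x = self.build_preprocessing(inputs, "imagenet")
        x = layers.Conv2D(32, 3, padding="same", name="stem")(x)  # placeholder body
        features = {"STEM": x}
        x = self.build_head(x)
        super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)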
@@ -87,10 +116,12 @@ def determine_input_tensor(
             if not backend.is_keras_tensor(input_tensor):
                 x = layers.Input(tensor=input_tensor, shape=input_shape)
             else:
-                x = input_tensor
+                x = utils.get_source_inputs(input_tensor)
         return x

     def build_preprocessing(self, inputs, mode="imagenet"):
+        if self._include_preprocessing is False:
+            return inputs
         if mode == "imagenet":
             # [0, 255] to [0, 1] and apply ImageNet mean and variance
             x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
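The switch from x = input_tensor to utils.get_source_inputs(input_tensor) relies on a Keras utility that traces a tensor back to the Input layer(s) that produced it, so the functional model is built from the true graph inputs. A tiny standalone check (my example, not from the PR):

from keras import layers
from keras import utils

inp = layers.Input(shape=(32, 32, 3))
feat = layers.Conv2D(8, 3)(inp)
# get_source_inputs walks back from `feat` and returns the original Input tensor(s).
print(utils.get_source_inputs(feat))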
@@ -118,15 +149,30 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate):
         )(x)
         return x

-    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
-        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
-        self.include_top = parsed_kwargs["include_top"]
-        self.pooling = parsed_kwargs["pooling"]
-        self.dropout_rate = parsed_kwargs["dropout_rate"]
-        self.classes = parsed_kwargs["classes"]
-        self.classifier_activation = parsed_kwargs["classifier_activation"]
-        # `self.weights` is been used internally
-        self._weights = parsed_kwargs["weights"]
+    def build_head(self, inputs):
+        x = inputs
+        if self._include_top:
+            x = self.build_top(
+                x,
+                self._classes,
+                self._classifier_activation,
+                self._dropout_rate,
+            )
+        else:
+            if self._pooling == "avg":
+                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
+            elif self._pooling == "max":
+                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
+        return x
+
+    def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
+        if weights_url is not None:
+            result = urllib.parse.urlparse(weights_url)
+            file_name = pathlib.Path(result.path).name
+            weights_path = utils.get_file(
+                file_name, weights_url, cache_subdir="kimm_models"
+            )
+            self.load_weights(weights_path)

     @staticmethod
     @abc.abstractmethod
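To make the new weight-loading path concrete, the resolution performed by load_pretrained_weights can be reproduced standalone: the file name is taken from the URL path, and keras.utils.get_file caches the download under the Keras cache directory (subfolder "kimm_models") before load_weights is called. The URL below is a made-up placeholder:

import pathlib
import urllib.parse

from keras import utils

weights_url = "https://example.com/download/v0.1/toymodel_imagenet.weights.h5"  # placeholder
file_name = pathlib.Path(urllib.parse.urlparse(weights_url).path).name
weights_path = utils.get_file(file_name, weights_url, cache_subdir="kimm_models")
# weights_path now points at the locally cached copy, ready for model.load_weights(weights_path).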
@@ -141,20 +187,25 @@ def get_config(self):
             # models.Model
             "name": self.name,
             "trainable": self.trainable,
-            # feature extractor
-            "feature_extractor": self.feature_extractor,
-            "feature_keys": self.feature_keys,
-            # common
             "input_shape": self.input_shape[1:],
-            "include_preprocessing": self.include_preprocessing,
-            "include_top": self.include_top,
-            "pooling": self.pooling,
-            "dropout_rate": self.dropout_rate,
-            "classes": self.classes,
-            "classifier_activation": self.classifier_activation,
+            # common
+            "include_preprocessing": self._include_preprocessing,
+            "include_top": self._include_top,
+            "pooling": self._pooling,
+            "dropout_rate": self._dropout_rate,
+            "classes": self._classes,
+            "classifier_activation": self._classifier_activation,
             "weights": self._weights,
+            "weights_url": self._weights_url,
+            # feature extractor
+            "feature_extractor": self._feature_extractor,
+            "feature_keys": self._feature_keys,
         }
         return config

     def fix_config(self, config: typing.Dict):
         return config
+
+    @property
+    def default_origin(self):
+        return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0"