@@ -22,6 +22,14 @@ def _is_non_trainable_weights(name: str):
22
22
23
23
@kimm_export (parent_path = ["kimm.timm_utils" ])
24
24
def separate_torch_state_dict (state_dict : typing .OrderedDict ):
25
+ """Separate the torch state dict into trainable and non-trainable parts.
26
+
27
+ Args:
28
+ state_dict: A `collections.OrderedDict`.
29
+
30
+ Returns:
31
+ A tuple containing the trainable and non-trainable state dicts.
32
+ """
25
33
trainable_state_dict = state_dict .copy ()
26
34
non_trainable_state_dict = state_dict .copy ()
27
35
trainable_remove_keys = []
@@ -44,6 +52,15 @@ def separate_torch_state_dict(state_dict: typing.OrderedDict):
44
52
45
53
@kimm_export (parent_path = ["kimm.timm_utils" ])
46
54
def separate_keras_weights (keras_model : keras .Model ):
55
+ """Separate the Keras model into trainable and non-trainable parts.
56
+
57
+ Args:
58
+ keras_model: A `keras.Model` instance.
59
+
60
+ Returns:
61
+ A tuple containing the trainable and non-trainable state lists. Each
62
+ list contains (`keras.Variable`, name) pairs.
63
+ """
47
64
trainable_weights = []
48
65
non_trainable_weights = []
49
66
for layer in keras_model .layers :
@@ -75,6 +92,20 @@ def separate_keras_weights(keras_model: keras.Model):
75
92
def assign_weights (
76
93
keras_name : str , keras_weight : keras .Variable , torch_weight : np .ndarray
77
94
):
95
+ """Assign the torch weights to the keras weights based on the arguments.
96
+
97
+ Some basic criterion:
98
+ 1. 4D must be a convolution weights (also check the name)
99
+ 2. 2D must be a dense weights
100
+ 3. 1D must be a vector weights
101
+ 4. 0D must be a scalar weights
102
+
103
+ Args:
104
+ keras_name: A `str` representing the name of the target weights.
105
+ keras_weights: A `keras.Variable` representing the target weights.
106
+ torch_weights: A `numpy.ndarray` representing the original source
107
+ weights.
108
+ """
78
109
if len (keras_weight .shape ) == 4 :
79
110
if (
80
111
"conv" in keras_name
@@ -119,6 +150,19 @@ def is_same_weights(
119
150
torch_name : str ,
120
151
torch_weights : np .ndarray ,
121
152
):
153
+ """Check whether the given keras weights and torch weigths are the same.
154
+
155
+ Args:
156
+ keras_name: A `str` representing the name of the target weights.
157
+ keras_weights: A `keras.Variable` representing the target weights.
158
+ torch_name: A `str` representing the name of the original source
159
+ weights.
160
+ torch_weights: A `numpy.ndarray` representing the original source
161
+ weights.
162
+
163
+ Returns:
164
+ A boolean indicating whether the two weights are the same.
165
+ """
122
166
if np .sum (keras_weights .shape ) != np .sum (torch_weights .shape ):
123
167
if np .sum (keras_weights .shape ) == 0 : # Deal with scalar
124
168
if np .sum (torch_weights .shape ) == 1 :
0 commit comments