Skip to content

Commit ec8e096

Browse files
Add docstrings for kimm.timm_utils.* (#51)
1 parent 927370b commit ec8e096

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

kimm/_src/utils/timm_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def _is_non_trainable_weights(name: str):
2222

2323
@kimm_export(parent_path=["kimm.timm_utils"])
2424
def separate_torch_state_dict(state_dict: typing.OrderedDict):
25+
"""Separate the torch state dict into trainable and non-trainable parts.
26+
27+
Args:
28+
state_dict: A `collections.OrderedDict`.
29+
30+
Returns:
31+
A tuple containing the trainable and non-trainable state dicts.
32+
"""
2533
trainable_state_dict = state_dict.copy()
2634
non_trainable_state_dict = state_dict.copy()
2735
trainable_remove_keys = []
@@ -44,6 +52,15 @@ def separate_torch_state_dict(state_dict: typing.OrderedDict):
4452

4553
@kimm_export(parent_path=["kimm.timm_utils"])
4654
def separate_keras_weights(keras_model: keras.Model):
55+
"""Separate the Keras model into trainable and non-trainable parts.
56+
57+
Args:
58+
keras_model: A `keras.Model` instance.
59+
60+
Returns:
61+
A tuple containing the trainable and non-trainable state lists. Each
62+
list contains (`keras.Variable`, name) pairs.
63+
"""
4764
trainable_weights = []
4865
non_trainable_weights = []
4966
for layer in keras_model.layers:
@@ -75,6 +92,20 @@ def separate_keras_weights(keras_model: keras.Model):
7592
def assign_weights(
7693
keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray
7794
):
95+
"""Assign the torch weights to the keras weights based on the arguments.
96+
97+
Some basic criteria:
98+
1. 4D must be a convolution weight (also check the name)
99+
2. 2D must be a dense weight
100+
3. 1D must be a vector weight
101+
4. 0D must be a scalar weight
102+
103+
Args:
104+
keras_name: A `str` representing the name of the target weights.
105+
keras_weight: A `keras.Variable` representing the target weights.
106+
torch_weight: A `numpy.ndarray` representing the original source
107+
weights.
108+
"""
78109
if len(keras_weight.shape) == 4:
79110
if (
80111
"conv" in keras_name
@@ -119,6 +150,19 @@ def is_same_weights(
119150
torch_name: str,
120151
torch_weights: np.ndarray,
121152
):
153+
"""Check whether the given keras weights and torch weights are the same.
154+
155+
Args:
156+
keras_name: A `str` representing the name of the target weights.
157+
keras_weights: A `keras.Variable` representing the target weights.
158+
torch_name: A `str` representing the name of the original source
159+
weights.
160+
torch_weights: A `numpy.ndarray` representing the original source
161+
weights.
162+
163+
Returns:
164+
A boolean indicating whether the two weights are the same.
165+
"""
122166
if np.sum(keras_weights.shape) != np.sum(torch_weights.shape):
123167
if np.sum(keras_weights.shape) == 0: # Deal with scalar
124168
if np.sum(torch_weights.shape) == 1:

0 commit comments

Comments
 (0)