Skip to content

Commit 1ee511a

Browse files
committed
Statistics computation for weight and bias matrices.
1 parent ea899de commit 1ee511a

File tree

7 files changed

+46
-5
lines changed

7 files changed

+46
-5
lines changed

src/fflib/interfaces/iff_recurrent_layer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.nn import Module
44
from fflib.enums import SparsityType
55
from abc import ABC, abstractmethod
6-
from typing import Callable, Tuple, List, Dict
6+
from typing import Callable, Tuple, List, Dict, Any
77

88

99
class IFFRecurrentLayer(ABC, Module):
@@ -65,3 +65,7 @@ def strip_down(self) -> None:
6565
@abstractmethod
6666
def sparsity(self, type: SparsityType) -> Dict[str, float]:
6767
pass
68+
69+
@abstractmethod
def stats(self) -> Dict[str, Any]:
    """Return basic statistics for this layer's learnable parameters.

    Concrete layers report a mapping from parameter name to a statistics
    dictionary; layers without parameters may return an empty mapping.
    """
    ...

src/fflib/nn/ff_linear.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from torch.optim import Adam, Optimizer
44

55
from fflib.enums import SparsityType
6-
from fflib.utils.maths import ComputeSparsity
6+
from fflib.utils.maths import ComputeSparsity, ComputeStats
77

8-
from typing import Callable, Tuple, Any, cast
8+
from typing import Callable, Tuple, Dict, Any, cast
99

1010

1111
class FFLinear(Linear):
@@ -140,3 +140,6 @@ def strip_down(self) -> None:
140140
def sparsity(self, type: SparsityType) -> torch.Tensor:
141141
"""Computes the sparsity of the weight's matrix."""
142142
return ComputeSparsity(torch.flatten(self.weight), type)
143+
144+
def stats(self) -> Dict[str, Any]:
    """Return basic statistics (min/max/mean/std) of this layer's parameters.

    Returns:
        Dict with keys "weight" and "bias". The "bias" entry is None when
        the layer was constructed without a bias term (nn.Linear(bias=False)
        stores bias as None, which ComputeStats cannot handle).
    """
    return {
        "weight": ComputeStats(self.weight),
        "bias": ComputeStats(self.bias) if self.bias is not None else None,
    }

src/fflib/nn/ff_net.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,6 @@ def sparsity(self, type: SparsityType) -> Dict[str, float]:
118118
return {
119119
f"layer_{i}": float(layer.sparsity(type).item()) for i, layer in enumerate(self.layers)
120120
}
121+
122+
def stats(self) -> Dict[str, Any]:
    """Collect parameter statistics from every layer, keyed by layer index."""
    result: Dict[str, Any] = {}
    for index, layer in enumerate(self.layers):
        result[f"layer_{index}"] = layer.stats()
    return result

src/fflib/nn/ff_recurrent_layer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.optim import Adam, Optimizer
66
from fflib.interfaces.iff_recurrent_layer import IFFRecurrentLayer
77
from fflib.enums import SparsityType
8-
from fflib.utils.maths import ComputeSparsity
8+
from fflib.utils.maths import ComputeSparsity, ComputeStats
99
from math import sqrt
1010
from typing import Callable, List, Tuple, Dict, cast, Any
1111

@@ -164,6 +164,13 @@ def sparsity(self, type: SparsityType) -> Dict[str, float]:
164164
),
165165
}
166166

167+
def stats(self) -> Dict[str, Any]:
    """Return min/max/mean/std statistics for each weight matrix of this layer.

    Covers the forward ("fw"), backward ("bw"), and lateral/bias ("fb")
    matrices.
    """
    matrices = {"fw": self.fw, "bw": self.bw, "fb": self.fb}
    return {name: ComputeStats(tensor) for name, tensor in matrices.items()}
173+
167174

168175
class FFRecurrentLayerDummy(IFFRecurrentLayer):
169176
def __init__(self, dimensions: int):
@@ -204,3 +211,6 @@ def strip_down(self) -> None:
204211

205212
def sparsity(self, type: SparsityType) -> Dict[str, float]:
206213
return {}
214+
215+
def stats(self) -> Dict[str, Any]:
    """A dummy layer has no learnable parameters, so there is nothing to report."""
    return {}

src/fflib/nn/ff_rnn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,7 @@ def strip_down(self) -> None:
236236
def sparsity(self, type: SparsityType) -> Dict[str, Dict[str, float]]:
237237
"""Returns a dictionary of dictionaries describing the sparsity levels at each layer."""
238238
return {f"layer_{i}": layer.sparsity(type) for i, layer in enumerate(self.layers)}
239+
240+
def stats(self) -> Dict[str, Any]:
    """Return per-layer parameter statistics for the hidden layers.

    The first and last entries of ``self.layers`` are skipped — presumably
    input/output dummy layers without learnable parameters (TODO confirm
    against the constructor). Keys carry the layer's absolute index in
    ``self.layers`` so they line up with the numbering used by
    ``sparsity`` (i.e. ``layer_1`` is ``self.layers[1]``); previously the
    keys restarted at ``layer_0``, mislabeling the reported layers.
    """
    return {
        f"layer_{i}": layer.stats()
        for i, layer in enumerate(self.layers[1:-1], start=1)
    }

src/fflib/nn/ffc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fflib.interfaces.iff import IFF
66
from fflib.nn.ff_linear import FFLinear
77
from fflib.enums import SparsityType
8-
from fflib.utils.maths import ComputeSparsity
8+
from fflib.utils.maths import ComputeSparsity, ComputeStats
99

1010
from typing import List, Dict, Callable, Any
1111

@@ -106,3 +106,11 @@ def sparsity(self, type: SparsityType) -> Dict[str, float]:
106106
ComputeSparsity(torch.flatten(self.classifier.weight), type).item()
107107
)
108108
return result
109+
110+
def stats(self) -> Dict[str, Any]:
    """Return parameter statistics for every FF layer and the classifier head.

    Returns:
        Dict mapping ``layer_{i}`` to each FF layer's stats, plus a
        "classifier" entry with "weight"/"bias" statistics. The "bias"
        entry is None when the classifier has no bias term (ComputeStats
        cannot handle a None tensor).
    """
    result: Dict[str, Any] = {
        f"layer_{i}": layer.stats() for i, layer in enumerate(self.layers)
    }
    result["classifier"] = {
        "weight": ComputeStats(self.classifier.weight),
        "bias": (
            ComputeStats(self.classifier.bias)
            if self.classifier.bias is not None
            else None
        ),
    }
    return result

src/fflib/utils/maths.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,12 @@ def ComputeAllSparsityTypes(x: torch.Tensor) -> Dict[str, torch.Tensor]:
3737
for type in SparsityType:
3838
result[str(type).split(".")[1]] = ComputeSparsity(x, type)
3939
return result
40+
41+
42+
def ComputeStats(x: torch.Tensor) -> Dict[str, float]:
    """Compute basic descriptive statistics of a tensor.

    Args:
        x: Tensor of any shape; statistics are taken over all elements.
           May require grad (it is detached before reading values).

    Returns:
        Dict with "min", "max", "mean" and "std" as Python floats. "std"
        is reported as 0.0 for tensors with fewer than two elements,
        where torch.std (Bessel-corrected, divides by n-1) would be NaN.
    """
    # Detach so statistics of grad-requiring parameters never interact
    # with autograd bookkeeping.
    data = x.detach()
    return {
        "min": data.min().item(),
        "max": data.max().item(),
        "mean": data.mean().item(),
        "std": data.std().item() if data.numel() > 1 else 0.0,
    }

0 commit comments

Comments
 (0)