@@ -3,19 +3,15 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
-import os
 from typing import Dict, Optional
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 
 __all__ = [
     "find_multiple",
-    "log_with_rank",
-    "clear_logs",
     "compute_error",
-    "apply_logging_hook",
+    "_apply_logging_hook",
     "get_model_size_in_bytes",
 ]
 
@@ -26,25 +22,6 @@ def find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
-def log_with_rank(*args):
-    # append
-    #
-    # {thing_to_log}
-    #
-    # to {file}_{rank}.txt, for printing stuff from multiple GPUs
-    if not os.path.exists(log_dir):
-        os.makedirs(log_dir)
-    with open(log_fname, "a") as f:
-        f.write(" ".join([str(s) for s in args]) + "\n")
-    if local_rank == 0:
-        print(*args)
-
-
-def clear_logs():
-    if os.path.isfile(log_fname):
-        os.remove(log_fname)
-
-
 # basic SQNR
 def compute_error(x, y):
     Ps = torch.linalg.norm(x)
@@ -65,13 +42,13 @@ def forward_hook(module, input):
     return forward_hook
 
 
-def apply_logging_hook(model):
+def _apply_logging_hook(model):
     for name, mod in model.named_modules():
         mod.register_forward_pre_hook(_get_logging_hook(name))
 
 
 # collections.defaultdict printing is weird with lambdas, so hand writing for now
-fqn_to_op_to_shape_to_count: Dict[
+_fqn_to_op_to_shape_to_count: Dict[
     Optional[str], Dict[Optional[str], Dict[Optional[str], int]]
 ] = {}
 
@@ -90,13 +67,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if shape_str != "":
             shape_str = shape_str[:-2]
 
-        if _cur_fqn not in fqn_to_op_to_shape_to_count:
-            fqn_to_op_to_shape_to_count[_cur_fqn] = {}
-        if op_name not in fqn_to_op_to_shape_to_count[_cur_fqn]:
-            fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {}
-        if shape_str not in fqn_to_op_to_shape_to_count[_cur_fqn][op_name]:
-            fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0
-        fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1
+        if _cur_fqn not in _fqn_to_op_to_shape_to_count:
+            _fqn_to_op_to_shape_to_count[_cur_fqn] = {}
+        if op_name not in _fqn_to_op_to_shape_to_count[_cur_fqn]:
+            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {}
+        if shape_str not in _fqn_to_op_to_shape_to_count[_cur_fqn][op_name]:
+            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0
+        _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1
 
         return rs
 
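For context, a minimal usage sketch of the helpers touched above. This is hypothetical and not part of the diff: it assumes _apply_logging_hook and compute_error are imported from the utils module being edited, and the toy model and input are made up for illustration.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(8, 16)

# Registers a forward pre-hook on every submodule; the TorchDispatchMode-based
# logger defined elsewhere in this file uses these hooks to attribute ops and
# shapes to module fqns in _fqn_to_op_to_shape_to_count.
_apply_logging_hook(model)
y = model(x)

# compute_error reports a basic SQNR between two tensors (higher means closer).
y_approx = y + 0.01 * torch.randn_like(y)
print(compute_error(y, y_approx))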