
Commit 793fa01

fixes to move to fbcode

Summary: removing unused code
Test Plan: python test/test.py
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: f6a92d1
Pull Request resolved: #25

1 parent b56f51d · commit 793fa01

3 files changed: +14 additions, −41 deletions

test/test.py

Lines changed: 4 additions & 4 deletions

@@ -48,10 +48,10 @@
     Int4WeightOnlyQuantizedLinearWeight
 )
 from torchao.quantization.utils import (
-    apply_logging_hook,
+    _apply_logging_hook,
     compute_error,
     compute_error as SQNR,
-    fqn_to_op_to_shape_to_count,
+    _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx

@@ -1111,12 +1111,12 @@ def test_shape_logger(self):
             ),
         )

-        apply_logging_hook(m)
+        _apply_logging_hook(m)
         with LoggingTensorMode():
             m(x)
             m(x)

-        for fqn, d1 in fqn_to_op_to_shape_to_count.items():  # noqa: PERF102
+        for fqn, d1 in _fqn_to_op_to_shape_to_count.items():  # noqa: PERF102
             for op, d2 in d1.items():  # noqa: PERF102
                 for shape, count in d2.items():  # noqa: PERF102
                     # print(fqn, op, shape, count)

torchao/quantization/__init__.py

Lines changed: 0 additions & 4 deletions

@@ -38,11 +38,7 @@
     "Int8DynamicallyQuantizedLinearWeight",
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
-    "log_with_rank",
-    "clear_logs",
     "compute_error",
-    "forward_hook",
-    "apply_logging_hook",
     "get_model_size_in_bytes",
     "WeightOnlyInt8QuantLinear",
 ]

torchao/quantization/utils.py

Lines changed: 10 additions & 33 deletions

@@ -3,19 +3,15 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
-import os
 from typing import Dict, Optional

 import torch
 from torch.utils._python_dispatch import TorchDispatchMode

 __all__ = [
     "find_multiple",
-    "log_with_rank",
-    "clear_logs",
     "compute_error",
-    "apply_logging_hook",
+    "_apply_logging_hook",
     "get_model_size_in_bytes",
 ]

@@ -26,25 +22,6 @@ def find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)


-def log_with_rank(*args):
-    # append
-    #
-    #   {thing_to_log}
-    #
-    # to {file}_{rank}.txt, for printing stuff from multiple GPUs
-    if not os.path.exists(log_dir):
-        os.makedirs(log_dir)
-    with open(log_fname, "a") as f:
-        f.write(" ".join([str(s) for s in args]) + "\n")
-    if local_rank == 0:
-        print(*args)
-
-
-def clear_logs():
-    if os.path.isfile(log_fname):
-        os.remove(log_fname)
-
-
 # basic SQNR
 def compute_error(x, y):
     Ps = torch.linalg.norm(x)

@@ -65,13 +42,13 @@ def forward_hook(module, input):
         return forward_hook


-def apply_logging_hook(model):
+def _apply_logging_hook(model):
     for name, mod in model.named_modules():
         mod.register_forward_pre_hook(_get_logging_hook(name))


 # collections.defaultdict printing is weird with lambdas, so hand writing for now
-fqn_to_op_to_shape_to_count: Dict[
+_fqn_to_op_to_shape_to_count: Dict[
     Optional[str], Dict[Optional[str], Dict[Optional[str], int]]
 ] = {}

@@ -90,13 +67,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if shape_str != "":
             shape_str = shape_str[:-2]

-        if _cur_fqn not in fqn_to_op_to_shape_to_count:
-            fqn_to_op_to_shape_to_count[_cur_fqn] = {}
-        if op_name not in fqn_to_op_to_shape_to_count[_cur_fqn]:
-            fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {}
-        if shape_str not in fqn_to_op_to_shape_to_count[_cur_fqn][op_name]:
-            fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0
-        fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1
+        if _cur_fqn not in _fqn_to_op_to_shape_to_count:
+            _fqn_to_op_to_shape_to_count[_cur_fqn] = {}
+        if op_name not in _fqn_to_op_to_shape_to_count[_cur_fqn]:
+            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {}
+        if shape_str not in _fqn_to_op_to_shape_to_count[_cur_fqn][op_name]:
+            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0
+        _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1

         return rs
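For reference, a minimal sketch of how the renamed private helpers are driven, mirroring the pattern in test/test.py above. The toy model and input shapes here are assumptions for illustration and are not part of this commit:

import torch
import torch.nn as nn
from torchao.quantization.utils import (
    _apply_logging_hook,
    _fqn_to_op_to_shape_to_count,
    LoggingTensorMode,
)

# toy model and input are assumptions for illustration
m = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 16)

# register forward pre-hooks that record which module (by fully qualified
# name) is currently executing
_apply_logging_hook(m)

# run the model under the dispatch mode so each op and its input shapes are
# counted per module
with LoggingTensorMode():
    m(x)
    m(x)

# nested dict: module FQN -> op name -> shape string -> call count
for fqn, d1 in _fqn_to_op_to_shape_to_count.items():
    for op, d2 in d1.items():
        for shape, count in d2.items():
            print(fqn, op, shape, count)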
