AUTOQUANT_CACHE = {}

-def check_cache(shape, cls):
-    if shape in AUTOQUANT_CACHE:
-        return AUTOQUANT_CACHE[shape].get(cls, None)
-    else:
-        return None
+def check_cache(cls, shape, dtype):
+    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

-def update_cache(shape, cls, res):
-    if not shape in AUTOQUANT_CACHE:
-        AUTOQUANT_CACHE[shape] = {}
-    AUTOQUANT_CACHE[shape][cls] = res
+def update_cache(cls, shape, dtype, res):
+    AUTOQUANT_CACHE[(cls, shape, dtype)] = res

class AutoQuantizableLinearWeight(torch.Tensor):
    """
@@ -43,7 +38,8 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
        self.weight = weight
        self.qtensor_class_list = qtensor_class_list
-        self.cache_shape = None
+        self.logged_shape = None
+        self.logged_dtype = None

    def __repr__(self):
        return (
@@ -52,36 +48,46 @@ def __repr__(self):
        )

    @staticmethod
-    def tune_autoquant(act_mat, w_autoquant, bias):
+    def log_shape(act_mat, w_autoquant, bias):
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
-        cache_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
-        w_autoquant.cache_shape = cache_shape
-        for cur_cls in w_autoquant.qtensor_class_list:
-            if check_cache(cache_shape, cur_cls) is None:
-                with torch.no_grad():
-                    print(cur_cls, cache_shape)
-                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
-                    res = cur_cls._autoquant_test(act_mat.clone(), w_autoquant.weight.clone(), None if bias is None else bias.clone())
-                    update_cache(cache_shape, cur_cls, res)
-                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
+        logged_dtype = act_mat.dtype
+        w_autoquant.logged_shape = logged_shape
+        w_autoquant.logged_dtype = logged_dtype
+        for q_cls in w_autoquant.qtensor_class_list:
+            if check_cache(q_cls, logged_shape, logged_dtype) is None:
+                update_cache(q_cls, logged_shape, logged_dtype, None)
        y = torch.mm(act_mat, w_autoquant.weight.t())
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y

+    def tune_autoquant(self, q_cls):
+        act_shape, w_shape, bias_shape = self.logged_shape
+        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+            with torch.no_grad():
+                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
+                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
+                print(q_cls, self.logged_shape, self.logged_dtype)
+                print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+                res = q_cls._autoquant_test(act_mat, self.weight, bias)
+                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
+
    def to_quantized(self):
-        if self.cache_shape is None or self.cache_shape not in AUTOQUANT_CACHE:
-            raise RuntimeError("must run module normally to find best quantization option")
+        if self.logged_shape is None or self.logged_dtype is None:
+            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
        best_time = torch.inf
        best_cls = None
-        for cur_cls in self.qtensor_class_list:
-            cls_res = AUTOQUANT_CACHE[self.cache_shape].get(cur_cls, torch.inf)
+        for q_cls in self.qtensor_class_list:
+            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+                self.tune_autoquant(q_cls)
+            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
            if best_time >= cls_res:
                best_time = cls_res
-                best_cls = cur_cls
-                # need to handle random cls args/kwargs?
+                best_cls = q_cls
+        # TODO handle random cls args/kwargs? or should they be curried
        self = best_cls.from_float(self.weight)
        return self
@@ -113,7 +119,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                args[1],
                args[2] if len(args) > 2 else None
            )
-            return cls.tune_autoquant(mat1, w_autoquant, bias)
+            return cls.log_shape(mat1, w_autoquant, bias)

        try:
            with torch._C.DisableTorchFunctionSubclass():
@@ -155,9 +161,10 @@ def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
+    Int8DynamicallyQuantizedLinearWeight,
    DefaultLinear,
    Int8WeightOnlyQuantizedLinearWeight,
-    Int8DynamicallyQuantizedLinearWeight,
+
]

if False:
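Below is a minimal sketch (not part of the diff) of how the reworked cache keying behaves, assuming only the check_cache/update_cache definitions added above: entries are now keyed by a flat (cls, shape, dtype) tuple instead of the old nested {shape: {cls: res}} dict. FakeQuantClass and the example shapes are hypothetical placeholders, not classes from this PR.

import torch

AUTOQUANT_CACHE = {}

def check_cache(cls, shape, dtype):
    # flat tuple key; returns None for configurations that have not been benchmarked yet
    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

def update_cache(cls, shape, dtype, res):
    AUTOQUANT_CACHE[(cls, shape, dtype)] = res

class FakeQuantClass:  # hypothetical stand-in for a qtensor class from qtensor_class_list
    pass

# shapes in the form log_shape records them: (activation shape, weight shape, bias shape or None)
shape = (torch.Size([16, 1024]), torch.Size([4096, 1024]), None)
dtype = torch.float16

assert check_cache(FakeQuantClass, shape, dtype) is None   # unseen config -> None, so it would be tuned
update_cache(FakeQuantClass, shape, dtype, 0.123)          # store a benchmark result (e.g. from _autoquant_test)
assert check_cache(FakeQuantClass, shape, dtype) == 0.123  # later lookups for the same (cls, shape, dtype) hit the cache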