@@ -94,7 +94,11 @@ class Test8BitBlockwiseQuantizeFunctional:
94
94
@pytest .mark .parametrize ("blocksize" , [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ])
95
95
@pytest .mark .parametrize ("signed" , TRUE_FALSE , ids = id_formatter ("signed" ))
96
96
def test_dynamic_blockwise_quantization (self , device , dtype , nested , blocksize , signed ):
97
+ iters = 100
98
+
97
99
if device == "cpu" :
100
+ iters = 10
101
+
98
102
# This test is slow on CPU, so avoid atypical use cases.
99
103
if nested :
100
104
pytest .skip ("Not a typical use case." )
@@ -106,7 +110,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
106
110
107
111
diffs = []
108
112
reldiffs = []
109
- for i in range (100 ):
113
+ for i in range (iters ):
110
114
A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
111
115
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
112
116
A2 = F .dequantize_blockwise (C , S )
@@ -116,15 +120,13 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
116
120
reldiffs .append (reldiff .mean ().item ())
117
121
abserr = sum (diffs ) / len (diffs )
118
122
relerr = sum (reldiffs ) / len (reldiffs )
119
- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
120
- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
121
123
assert abserr < 0.011
122
124
assert relerr < 0.018
123
125
assert A2 .dtype == dtype
124
126
125
127
diffs = []
126
128
code = F .create_dynamic_map (signed = signed )
127
- for i in range (100 ):
129
+ for i in range (iters ):
128
130
A1 = torch .rand (1024 , 1024 , device = device , dtype = dtype )
129
131
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
130
132
A2 = F .dequantize_blockwise (C , S )
@@ -142,29 +144,29 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
142
144
assert abserr < 0.00175
143
145
assert relerr < 0.012
144
146
assert A2 .dtype == dtype
145
- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
146
- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
147
147
148
- def test_blockwise_cpu_large (self ):
148
+ @pytest .mark .skipif ("cpu" not in get_available_devices (), reason = "CPU is required" )
149
+ @pytest .mark .parametrize ("hidden" , [128 ])
150
+ @pytest .mark .parametrize ("blocksize" , [4096 , 16384 ])
151
+ def test_blockwise_cpu_large (self , hidden , blocksize ):
149
152
diffs = []
150
153
reldiffs = []
151
154
batch = 128
152
155
seq = 128
153
- for hidden in [128 ]: # , 14336]:
154
- for blocksize in [4096 , 16384 ]:
155
- for i in range (2 ):
156
- A1 = torch .randn (batch , seq , hidden , device = "cpu" )
157
- t0 = time .time ()
158
- C , S = F .quantize_blockwise (A1 , blocksize = blocksize )
159
- A2 = F .dequantize_blockwise (C , S , blocksize = blocksize )
160
- print (time .time () - t0 )
161
- diff = torch .abs (A1 - A2 )
162
- reldiff = diff / torch .abs (A1 + 1e-8 )
163
- diffs .append (diff .mean ().item ())
164
- reldiffs .append (reldiff .mean ().item ())
165
- assert diffs [- 1 ] < 0.011
166
- # print(sum(diffs)/len(diffs))
167
- # print(sum(reldiffs)/len(reldiffs))
156
+
157
+ for i in range (2 ):
158
+ A1 = torch .randn (batch , seq , hidden , device = "cpu" )
159
+ t0 = time .time ()
160
+ C , S = F .quantize_blockwise (A1 , blocksize = blocksize )
161
+ A2 = F .dequantize_blockwise (C , S , blocksize = blocksize )
162
+ print (time .time () - t0 )
163
+ diff = torch .abs (A1 - A2 )
164
+ reldiff = diff / torch .abs (A1 + 1e-8 )
165
+ diffs .append (diff .mean ().item ())
166
+ reldiffs .append (reldiff .mean ().item ())
167
+ assert diffs [- 1 ] < 0.011
168
+ # print(sum(diffs)/len(diffs))
169
+ # print(sum(reldiffs)/len(reldiffs))
168
170
169
171
@pytest .mark .parametrize ("device" , get_available_devices ())
170
172
@pytest .mark .parametrize ("bits" , range (2 , 9 ), ids = id_formatter ("bits" ))
0 commit comments