Commit 8f6fdda

fixing tests that aren't skipping
Parent: 89ec74b

1 file changed, 55 additions, 22 deletions

test/quantization/test_moe_quant.py
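Why the move from decorators to in-body skips: the diff itself doesn't say, but the commit title points at a known pitfall. `@parameterized.expand` generates the real test methods at class-creation time, and with affected versions of `parameterized`, `unittest.skipIf` decorators stacked above it attach only to the template method that `expand` discards, not to the generated tests, so the skip conditions were silently ignored and the tests ran on machines that should have skipped them. A condition checked inside the body with `self.skipTest(...)` is evaluated at run time by every expanded variant. A minimal sketch of the pitfall and the fix; the `HAS_GPU` flag is a hypothetical stand-in for `torch.cuda.is_available()`:

import unittest

from parameterized import parameterized

HAS_GPU = False  # hypothetical stand-in for torch.cuda.is_available()

class SkipDemo(unittest.TestCase):
    # Broken: skipIf sits above parameterized.expand, so it wraps the
    # template method that expand discards; with affected parameterized
    # versions, the generated variants (test_broken_0_a, test_broken_1_b)
    # run anyway and hit the fail() below.
    @unittest.skipIf(not HAS_GPU, "Need GPU available")
    @parameterized.expand([("a",), ("b",)])
    def test_broken(self, name):
        self.fail("ran even though the skip condition holds")

    # Fixed: the condition lives in the body, so every expanded variant
    # evaluates it at run time and skips cleanly.
    @parameterized.expand([("a",), ("b",)])
    def test_fixed(self, name):
        if not HAS_GPU:
            self.skipTest("Need GPU available")

if __name__ == "__main__":
    unittest.main()

Running this sketch on a machine where `HAS_GPU` is False shows the broken variants failing while the fixed ones report as skipped, which is exactly the pattern applied throughout the diff below.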
@@ -92,15 +92,18 @@ def _test_impl_moe_quant(
         self.assertGreaterEqual(compute_error(out_q, out), 10)
         self.assertGreaterEqual(compute_error(out_qc, out), 10)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("single_token", 1, False),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = MoEQuantConfig(Int4WeightOnlyConfig())
         tensor_impl_class = TensorCoreTiledAQTTensorImpl

@@ -111,16 +114,20 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("single_token", 1, True),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_int4wo_base(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = Int4WeightOnlyConfig()
         tensor_impl_class = TensorCoreTiledAQTTensorImpl

@@ -131,15 +138,18 @@ def test_int4wo_base(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("single_token", 1, False),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = MoEQuantConfig(Int8WeightOnlyConfig())
         tensor_impl_class = PlainAQTTensorImpl

@@ -150,15 +160,18 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("single_token", 1, True),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_int8wo_base(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = Int8WeightOnlyConfig()
         tensor_impl_class = PlainAQTTensorImpl

@@ -169,14 +182,16 @@ def test_int8wo_base(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("single_token", 1, True),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = Int8WeightOnlyConfig()
         tensor_impl_class = PlainAQTTensorImpl

@@ -188,14 +203,17 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
             device="cpu",
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("multiple_tokens", 32, False),
         ]
     )
     def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
         base_class = LinearActivationQuantizedTensor

@@ -207,14 +225,17 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
     @parameterized.expand(
         [
             ("multiple_tokens", 32, False),
         ]
     )
     def test_int8dq_base(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            self.skipTest("Test only enabled for 2.5+")
+
         config = Int8DynamicActivationInt8WeightConfig()
         base_class = LinearActivationQuantizedTensor

@@ -226,15 +247,18 @@ def test_int8dq_base(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     @parameterized.expand(
         [
             ("single_token", 1, False),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
         config = MoEQuantConfig(Float8WeightOnlyConfig())
         tensor_impl_class = Float8AQTTensorImpl

@@ -245,15 +269,18 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     @parameterized.expand(
         [
             ("single_token", 1, True),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_fp8wo_base(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
         config = Float8WeightOnlyConfig()
         tensor_impl_class = Float8AQTTensorImpl

@@ -264,15 +291,18 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     @parameterized.expand(
         [
             ("single_token", 1, False),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
         config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
         base_class = LinearActivationQuantizedTensor

@@ -283,15 +313,18 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     @parameterized.expand(
         [
             ("single_token", 1, True),
             ("multiple_tokens", 8, False),
         ]
     )
     def test_fp8dq_base(self, name, num_tokens, fullgraph):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
         config = Float8DynamicActivationFloat8WeightConfig()
         base_class = LinearActivationQuantizedTensor

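One way to confirm the fix locally (a standard pytest flag, not part of this commit): run the file on a CPU-only machine with skip reporting enabled, and the CUDA-gated cases should now be listed as skipped with their reasons rather than running and failing:

python -m pytest test/quantization/test_moe_quant.py -rs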