Skip to content

Commit 7a52c00

Browse files
authored
Reenable int4 integration tests (#91)
1 parent 6204f07 commit 7a52c00

File tree

3 files changed

+121
-107
lines changed

3 files changed

+121
-107
lines changed

.github/workflows/regression_test.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ jobs:
2222
- name: Install dependencies
2323
run: |
2424
python -m pip install --upgrade pip
25+
pip install torch
2526
pip install -r requirements.txt
2627
pip install -r dev-requirements.txt
27-
pip install torch
2828
2929
3030
- name: Install package
@@ -48,9 +48,9 @@ jobs:
4848
- name: Install dependencies
4949
run: |
5050
python -m pip install --upgrade pip
51+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
5152
pip install -r requirements.txt
5253
pip install -r dev-requirements.txt
53-
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
5454
5555
5656
- name: Install package
@@ -74,9 +74,9 @@ jobs:
7474
- name: Install dependencies
7575
run: |
7676
python -m pip install --upgrade pip
77+
pip install torch --index-url https://download.pytorch.org/whl/cpu
7778
pip install -r requirements.txt
7879
pip install -r dev-requirements.txt
79-
pip install torch --index-url https://download.pytorch.org/whl/cpu
8080
8181
8282
- name: Install package
@@ -100,9 +100,9 @@ jobs:
100100
- name: Install dependencies
101101
run: |
102102
python -m pip install --upgrade pip
103+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
103104
pip install -r requirements.txt
104105
pip install -r dev-requirements.txt
105-
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
106106
107107
108108
- name: Install package

test/dtypes/test_nf4.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,13 @@ def test_smoketest_linear(self):
195195
@unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.")
196196
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
197197
def test_smoketest_linear_compile(self):
198-
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
199-
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
200-
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
201-
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
198+
for dtype in [torch.bfloat16, torch.float16]:
199+
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
200+
self.skipTest("test requires SM capability of at least (8, 0).")
201+
a = torch.randn(32, 32, dtype=dtype, device='cuda')
202+
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
203+
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
204+
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
202205

203206

204207

0 commit comments

Comments
 (0)