Skip to content

Commit 55ebaac

Browse files
Tests: don't require grad on weights for test_kbit_backprop
1 parent 318a86e commit 55ebaac

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tests/test_modules.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,24 +285,22 @@ def test_linear_kbit_fp32_bias(device, module):
285285
@pytest.mark.parametrize("device", get_available_devices())
286286
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
287287
def test_kbit_backprop(device, module):
288-
if device == "cpu":
289-
pytest.xfail("Test is not yet supported on CPU")
290-
291288
b = 16
292289
dim1 = 36
293290
dim2 = 84
294291
# dim1 = 37
295292
# dim2 = 83
296293

297294
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
298-
# ref[1].weight.requires_grad = False
299295
torch.nn.init.kaiming_normal_(ref[0].weight)
300296
torch.nn.init.kaiming_normal_(ref[1].weight)
297+
ref[1].weight.requires_grad_(False)
301298
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
302299
kbit[0].weight.detach().copy_(ref[0].weight)
303300
kbit[1].weight.detach().copy_(ref[1].weight)
304301
kbit[0].bias.detach().copy_(ref[0].bias)
305302
kbit[1].bias.detach().copy_(ref[1].bias)
303+
kbit[1].weight.requires_grad_(False)
306304
ref = ref.half().to(device)
307305
kbit = kbit.half().to(device)
308306
kbit = kbit.half().to(device)

0 commit comments

Comments
 (0)