Skip to content

Commit a59db0a

Browse files
committed
Updated to_nf4 test
1 parent 5b3ee6e commit a59db0a

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

test/modules/test_nf4_linear.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,18 @@ def test_load_from_nf4_diff_meta(self):
164164
assert other_mod.param.block_size == 64
165165
assert other_mod.param.scaler_block_size == 1
166166

167+
def test_to_copy(self):
168+
inpt_tensor = torch.rand(128, device='cpu')
169+
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
170+
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
171+
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
172+
173+
if torch.cuda.is_available():
174+
inpt_tensor = torch.rand(128, device='cuda')
175+
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
176+
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
177+
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
178+
179+
167180
if __name__ == "__main__":
168181
unittest.main()

torchao/dtypes/nf4tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def noop_detach(func, *args, **kwargs):
3838

3939
@implements([torch.ops.aten._to_copy.default])
4040
def _to_copy(func, *args, **kwargs):
41+
print("func: ", func)
4142
return args[0][0].get_original_weight().to(args[1]['dtype'])
4243

4344

0 commit comments

Comments
 (0)