import torch
import torchao
from torchao.dtypes import to_nf4

# To create coverage for a new nf4 op we first attempt to run it

# Construct a small nf4 tensor of the desired shape
a = torch.randn(64)
a[0] = 0

# Don't forget to pick a block size and scaler block size that work for your shape
a_nf4 = to_nf4(a, 32, 2)

# Trust is good, printing is better
print(f"a: {a}")
print(f"a_nf4: {a_nf4}")

# If GELU is not supported you'll get the following error:
# NotImplementedError: NF4Tensor dispatch: attempting to run aten.gelu.default, this is not supported
# torch.nn.functional.gelu(a_nf4)
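
# A small sketch for probing coverage at runtime instead of crashing: if the op
# isn't covered, the dispatch raises the NotImplementedError shown above.
try:
    torch.nn.functional.gelu(a_nf4)
except NotImplementedError as e:
    print(f"gelu not covered yet: {e}")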

# Next you can add an implementation for this op using the implements decorator
@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
def gelu(func, *args, **kwargs):
    # The torch dispatch convention is to pass all args and kwargs via the
    # args input.
    # args[0] here corresponds to the original *args
    # args[1] here corresponds to the original **kwargs
    # We're getting the first argument of the original args
    inp = args[0][0]
    # Here's a very inefficient way to implement it: dequantize to float32,
    # run the regular gelu and quantize the result back to nf4
    return to_nf4(torch.nn.functional.gelu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)

print(f"gelu(a): {torch.nn.functional.gelu(a)}")
print(f"gelu(a_nf4): {torch.nn.functional.gelu(a_nf4)}")

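# A rough sanity check (sketch): dequantize via .to(torch.float32), the same call
# the kernel above uses, and compare against the fp32 reference. NF4 is lossy, so
# just look at the worst-case absolute error rather than asserting exact equality.
gelu_ref = torch.nn.functional.gelu(a)
gelu_nf4 = torch.nn.functional.gelu(a_nf4).to(torch.float32)
print(f"max abs error: {(gelu_ref - gelu_nf4).abs().max()}")
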
# We collect these implementations in torchao.dtypes.nf4tensor, but you can also
# just roll your own.
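
# A minimal "roll your own" sketch: the same pattern, registered from your own code.
# relu is only an assumed example here and may already be covered by your torchao
# version; this just shows that the decorator works the same outside the torchao tree.
@torchao.dtypes.nf4tensor.implements([torch.ops.aten.relu.default])
def relu(func, *args, **kwargs):
    inp = args[0][0]
    # Same dequantize -> compute -> requantize approach as the gelu example above
    return to_nf4(torch.nn.functional.relu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)

print(f"relu(a_nf4): {torch.nn.functional.relu(a_nf4)}")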