
Commit 2063c34

Simple Python file that exemplifies how to add an op to NF4Tensor (#51)
* Simple Python file that exemplifies how to add an op to NF4Tensor
* Wording
* Move test location
1 parent 94853d4 commit 2063c34

3 files changed: +43 -0 lines changed
One file renamed without changes (the test move noted in the commit message).

torchao/dtypes/nf4tensor.py

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ def noop_detach(func, *args, **kwargs):
 def _to_copy(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[1]['dtype'])
 
+@implements([torch.ops.aten.to.dtype])
+def to_dtype(func, *args, **kwargs):
+    return args[0][0].get_original_weight().to(args[0][1])
+
 
 @implements(
     [
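
The handler added above registers torch.ops.aten.to.dtype, the positional-dtype overload behind Tensor.to(dtype), and dequantizes through get_original_weight(). A minimal sketch of exercising it, assuming the to_nf4 constructor used in the tutorial below; calling the aten op directly guarantees this overload is hit rather than the _to_copy path:

import torch
from torchao.dtypes import to_nf4

# Build a small NF4 tensor (block size 32, scaler block size 2,
# matching the tutorial below).
a = torch.randn(64)
a_nf4 = to_nf4(a, 32, 2)

# Invoke the op the new handler serves; the result is a plain,
# dequantized torch.Tensor in the requested dtype.
a_bf16 = torch.ops.aten.to.dtype(a_nf4, torch.bfloat16)
print(a_bf16.dtype)  # torch.bfloat16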

tutorials/add_an_op.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import torch
+import torchao
+from torchao.dtypes import to_nf4
+
+# To create coverage for a new nf4 op, we first attempt to run it.
+
+# Construct a small nf4 tensor of the desired shape.
+a = torch.randn(64)
+a[0] = 0
+
+# Don't forget to pick a block size and scaler block size that work for your shape.
+a_nf4 = to_nf4(a, 32, 2)
+
+# Trust is good, printing is better.
+print(f"a: {a}")
+print(f"a_nf4: {a_nf4}")
+
+
+# If GELU is not supported you'll get the following error:
+# NotImplementedError: NF4Tensor dispatch: attempting to run aten.gelu.default, this is not supported
+# torch.nn.functional.gelu(a_nf4)
+
+# Next you can add this function using the implements decorator.
+@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
+def gelu(func, *args, **kwargs):
+    # The torch dispatch convention is to pass all args and kwargs via the
+    # args input.
+    # args[0] here corresponds to the original *args
+    # args[1] here corresponds to the original **kwargs
+    # We're getting the first argument of the original args.
+    inp = args[0][0]
+    # Here's a very inefficient way to implement it: dequantize, apply the op, requantize.
+    return to_nf4(torch.nn.functional.gelu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)
+
+print(f"gelu(a): {torch.nn.functional.gelu(a)}")
+print(f"gelu(a_nf4): {torch.nn.functional.gelu(a_nf4)}")
+
+# We collect these implementations in torchao.dtypes.nf4tensor, but you can also
+# just roll your own.
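
A quick way to sanity-check an implementation like the gelu above is to compare its dequantized output against the full-precision reference. A minimal sketch, assuming the tutorial's a, a_nf4, and registered gelu handler are in scope, and using get_original_weight() to dequantize as the nf4tensor.py diff does; NF4 is a lossy 4-bit format, so only loose tolerances are meaningful:

import torch

# Full-precision reference.
expected = torch.nn.functional.gelu(a)

# The registered handler returns an NF4Tensor; dequantize before comparing.
actual = torch.nn.functional.gelu(a_nf4).get_original_weight()

# Quantization error dominates, so allow generous tolerances.
torch.testing.assert_close(actual, expected, atol=1e-1, rtol=1e-1)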

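
The closing comment notes that these implementations are collected in torchao.dtypes.nf4tensor, but that you can also roll your own registry. A minimal sketch of what such a decorator can look like; MY_OPS_TABLE and my_implements are hypothetical names, not torchao's actual internals, and the dispatch call at the end mirrors the (func, args, kwargs) convention the tutorial describes:

import torch

# Hypothetical table mapping aten ops to their NF4 handlers.
MY_OPS_TABLE = {}

def my_implements(aten_ops):
    # Register the decorated function as the handler for each listed op.
    def decorator(func):
        for op in aten_ops:
            MY_OPS_TABLE[op] = func
        return func
    return decorator

@my_implements([torch.ops.aten.gelu.default])
def my_gelu(func, *args, **kwargs):
    ...  # same body as the tutorial's gelu above

# A subclass's __torch_dispatch__ would then look up and call, roughly:
#     if func in MY_OPS_TABLE:
#         return MY_OPS_TABLE[func](func, args, kwargs)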