Commit a7670be

Use .to() instead of get_original_weight in linear_nf4 backward (#90)
Co-authored-by: cpuhrsch <cpuhrsch@googlemail.com>
1 parent 7a52c00 commit a7670be

1 file changed, with 2 additions and 2 deletions

torchao/dtypes/nf4tensor.py

Lines changed: 2 additions & 2 deletions
@@ -569,9 +569,9 @@ def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
     # inconsistently.
 
     def backward(ctx, grad_output):
-        """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
+        """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
         weight: NF4Tensor = ctx.nf4_weight
-        return grad_output @ weight.get_original_weight(), None
+        return grad_output @ weight.to(grad_output.dtype), None
 
 
 def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
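
The pattern in this hunk can be sketched outside torchao: a custom autograd function keeps a frozen weight on `ctx`, and `backward` casts that weight to the dtype of `grad_output` before the matmul. The sketch below is illustrative only. `FrozenLinear` and `frozen_weight` are hypothetical names, a plain `torch.Tensor` stands in for `NF4Tensor`, and the float32/float64 dtypes are an assumption chosen to make the cast visible; the commit itself does not state its motivation.

```python
import torch
import torch.nn.functional as F


class FrozenLinear(torch.autograd.Function):
    """Hypothetical stand-in for the NF4 linear function, using a plain tensor as the weight."""

    @staticmethod
    def forward(ctx, input, weight):
        # The weight never requires grad, so it is kept directly on ctx.
        ctx.frozen_weight = weight
        return F.linear(input, weight.to(input.dtype))

    @staticmethod
    def backward(ctx, grad_output):
        weight = ctx.frozen_weight
        # Cast the weight to the gradient's dtype so both matmul operands agree.
        # The second return value is None because the weight receives no gradient.
        return grad_output @ weight.to(grad_output.dtype), None


# Assumed setup: weight stored in float32, activations in float64.
weight = torch.randn(3, 4, dtype=torch.float32)
x = torch.randn(2, 4, dtype=torch.float64, requires_grad=True)

out = FrozenLinear.apply(x, weight)
out.sum().backward()
```

After `backward()`, `x.grad` has the same dtype as `x`, because the weight was cast to match `grad_output` rather than used in its stored dtype.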
