
Commit db49edd

Fix the source param sharding for GradAcc API (#8999)

1 parent 366f248

File tree

1 file changed (+1, -1)

torch_xla/experimental/gradient_accumulation.py

Lines changed: 1 addition & 1 deletion
@@ -404,7 +404,7 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     if param.grad is None:
       param.grad = torch.zeros(param.size()).to(
           param.device).requires_grad_(False)
-    param_sharding = torch_xla._XLAC._get_xla_op_sharding(param.grad)
+    param_sharding = torch_xla._XLAC._get_xla_op_sharding(param)
     if param_sharding:
       # Match the gradient sharding to the parameter's.
       torch_xla._XLAC._xla_mark_sharding(param.grad, param_sharding)
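
The fix queries the OpSharding from the parameter itself rather than from param.grad: a gradient buffer that was just allocated with torch.zeros carries no sharding annotation yet, so the old call returned an empty sharding and the subsequent _xla_mark_sharding was never applied. Below is a minimal sketch (not part of the commit) of the before/after behavior, assuming an SPMD-enabled torch_xla environment; the mesh shape and the ('data', None) partition spec are illustrative only.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

# Build a 1-D device mesh over all addressable devices (illustrative).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))

# A sharded parameter, as the gradient accumulation loop would see it.
param = torch.nn.Parameter(torch.randn(128, 64).to(xm.xla_device()))
xs.mark_sharding(param, mesh, ('data', None))

# Freshly allocated gradient buffer: it has no sharding annotation yet.
param.grad = torch.zeros(param.size()).to(
    param.device).requires_grad_(False)

# Old behavior: querying the unannotated buffer yields an empty sharding,
# so the gradient was never re-sharded to match the parameter.
stale_sharding = torch_xla._XLAC._get_xla_op_sharding(param.grad)  # falsy

# Fixed behavior: query the parameter, then stamp its sharding onto the grad.
param_sharding = torch_xla._XLAC._get_xla_op_sharding(param)
if param_sharding:
  torch_xla._XLAC._xla_mark_sharding(param.grad, param_sharding)

The design choice is that the parameter is the source of truth for the layout: its gradient should always inherit the parameter's sharding, regardless of how or when the grad buffer was materialized.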
