Skip to content

Commit f177e76

Browse files
authored
handle requires_grad in torchax (#8992)
1 parent 8e6a5e5 commit f177e76

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

torchax/test/test_functions.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,23 @@ def forward(self, x: torch.Tensor):
2222
class TestTorchFunctions(parameterized.TestCase):
2323

2424
def setUp(self):
25-
self.env = torchax.tensor.Environment()
26-
self.env.config.use_torch_native_for_cpu_tensor = False
25+
torchax.enable_globally()
2726
torchax.enable_accuracy_mode()
27+
self.env = torchax.default_env()
2828

2929
@parameterized.named_parameters(
30-
('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
31-
('tensor_1d', lambda: torch.tensor([0, 1],)),
32-
('tensor_scalar', lambda: torch.tensor(3.14159,)),
33-
('tensor_empty', lambda: torch.tensor([],)),
34-
('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]],
35-
dtype=torch.float64)),
30+
('tensor_2d', [[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
31+
('tensor_1d', [0, 1]),
32+
('tensor_scalar', 3.14159),
33+
('tensor_empty', []),
34+
('tensor_dtype', [[0.11111, 0.222222, 0.3333333]], {'dtype': torch.float64})
3635
)
37-
def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
38-
expected = func()
36+
def test_tensor_constructor(self, arg, kwargs=None):
37+
kwargs = kwargs or {}
38+
expected = torch.tensor(arg, **kwargs)
3939

40-
with self.env:
41-
actual = func()
42-
self.assertIsInstance(actual, torchax.tensor.Tensor)
40+
actual = torch.tensor(arg, device='jax', **kwargs)
41+
self.assertIsInstance(actual, torchax.tensor.Tensor)
4342

4443
torch.testing.assert_close(torchax.tensor.j2t(actual._elem), expected)
4544

@@ -90,8 +89,9 @@ def test_rms_norm(self):
9089
self.assertTrue(
9190
torch.allclose(res, torchax.tensor.j2t(res2.jax())))
9291

93-
94-
92+
def test_randn_requires_grad(self):
93+
x = torch.randn((3, 3), requires_grad=True, device='jax')
94+
self.assertEqual(x.requires_grad, True)
9595

9696

9797
if __name__ == '__main__':

torchax/torchax/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,15 @@ def _handle_tensor_constructor(self, func, args, kwargs):
411411
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
412412
return func(*args, **kwargs)
413413
with jax.default_device(jax_device):
414+
requires_grad = kwargs.get('requires_grad', False)
414415
op = self._ops.get(func)
415416
if op is None and isinstance(func, torch._ops.OpOverload):
416417
op = self._ops.get(func.overloadpacket)
417418
res = op.func(*args, **kwargs)
418419
if isinstance(res, jax.Array):
419420
res = Tensor(res, self)
421+
if requires_grad:
422+
res.requires_grad = True
420423
return res
421424

422425
def _torch_Tensor_to(self, args, kwargs):

0 commit comments

Comments
 (0)