@@ -22,24 +22,23 @@ def forward(self, x: torch.Tensor):
22
22
class TestTorchFunctions (parameterized .TestCase ):
23
23
24
24
def setUp (self ):
25
- self .env = torchax .tensor .Environment ()
26
- self .env .config .use_torch_native_for_cpu_tensor = False
25
+ torchax .enable_globally ()
27
26
torchax .enable_accuracy_mode ()
27
+ self .env = torchax .default_env ()
28
28
29
29
@parameterized .named_parameters (
30
- ('tensor_2d' , lambda : torch .tensor ([[0.1 , 1.2 ], [2.2 , 3.1 ], [4.9 , 5.2 ]])),
31
- ('tensor_1d' , lambda : torch .tensor ([0 , 1 ],)),
32
- ('tensor_scalar' , lambda : torch .tensor (3.14159 ,)),
33
- ('tensor_empty' , lambda : torch .tensor ([],)),
34
- ('tensor_dtype' , lambda : torch .tensor ([[0.11111 , 0.222222 , 0.3333333 ]],
35
- dtype = torch .float64 )),
30
+ ('tensor_2d' , [[0.1 , 1.2 ], [2.2 , 3.1 ], [4.9 , 5.2 ]]),
31
+ ('tensor_1d' , [0 , 1 ]),
32
+ ('tensor_scalar' , 3.14159 ),
33
+ ('tensor_empty' , []),
34
+ ('tensor_dtype' , [[0.11111 , 0.222222 , 0.3333333 ]], {'dtype' : torch .float64 })
36
35
)
37
- def test_tensor_constructor (self , func : Callable [[], torch .Tensor ]):
38
- expected = func ()
36
+ def test_tensor_constructor (self , arg , kwargs = None ):
37
+ kwargs = kwargs or {}
38
+ expected = torch .tensor (arg , ** kwargs )
39
39
40
- with self .env :
41
- actual = func ()
42
- self .assertIsInstance (actual , torchax .tensor .Tensor )
40
+ actual = torch .tensor (arg , device = 'jax' , ** kwargs )
41
+ self .assertIsInstance (actual , torchax .tensor .Tensor )
43
42
44
43
torch .testing .assert_close (torchax .tensor .j2t (actual ._elem ), expected )
45
44
@@ -90,8 +89,9 @@ def test_rms_norm(self):
90
89
self .assertTrue (
91
90
torch .allclose (res , torchax .tensor .j2t (res2 .jax ())))
92
91
93
-
94
-
92
+ def test_randn_requires_grad (self ):
93
+ x = torch .randn ((3 , 3 ), requires_grad = True , device = 'jax' )
94
+ self .assertEqual (x .requires_grad , True )
95
95
96
96
97
97
if __name__ == '__main__' :
0 commit comments