@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
1402
1402
assert spec2 .zero ().shape == spec2 .shape
1403
1403
1404
1404
def test_non_tensor (self ):
1405
- spec = NonTensor ((3 , 4 ), device = "cpu" )
1405
+ spec = NonTensor ((3 , 4 ), device = "cpu" , example_data = "example_data" )
1406
1406
assert (
1407
1407
spec .expand (2 , 3 , 4 )
1408
1408
== spec .expand ((2 , 3 , 4 ))
1409
- == NonTensor ((2 , 3 , 4 ), device = "cpu" )
1409
+ == NonTensor ((2 , 3 , 4 ), device = "cpu" , example_data = "example_data" )
1410
1410
)
1411
+ assert spec .expand (2 , 3 , 4 ).example_data == "example_data"
1411
1412
1412
1413
@pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
1413
1414
@pytest .mark .parametrize ("shape2" , [(), (10 ,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
1607
1608
assert spec is not spec .clone ()
1608
1609
1609
1610
def test_non_tensor (self ):
1610
- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
1611
+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
1611
1612
assert spec .clone () == spec
1612
1613
assert spec .clone () is not spec
1614
+ assert spec .clone ().example_data == "example_data"
1613
1615
1614
1616
@pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
1615
1617
def test_onehot (
@@ -1840,9 +1842,10 @@ def test_multionehot(
1840
1842
spec .unbind (- 1 )
1841
1843
1842
1844
def test_non_tensor (self ):
1843
- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
1845
+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
1844
1846
assert spec .unbind (1 )[0 ] == spec [:, 0 ]
1845
1847
assert spec .unbind (1 )[0 ] is not spec [:, 0 ]
1848
+ assert spec .unbind (1 )[0 ].example_data == "example_data"
1846
1849
1847
1850
@pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
1848
1851
def test_onehot (
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
2001
2004
assert spec .to (device ).device == device
2002
2005
2003
2006
def test_non_tensor (self , device ):
2004
- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
2007
+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
2005
2008
assert spec .to (device ).device == device
2009
+ assert spec .to (device ).example_data == "example_data"
2006
2010
2007
2011
@pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
2008
2012
def test_onehot (self , shape1 , device ):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
2262
2266
assert r .shape == c .shape
2263
2267
2264
2268
def test_stack_non_tensor (self , shape , stack_dim ):
2265
- spec0 = NonTensor (shape = shape , device = "cpu" )
2266
- spec1 = NonTensor (shape = shape , device = "cpu" )
2269
+ spec0 = NonTensor (shape = shape , device = "cpu" , example_data = "example_data" )
2270
+ spec1 = NonTensor (shape = shape , device = "cpu" , example_data = "example_data" )
2267
2271
new_spec = torch .stack ([spec0 , spec1 ], stack_dim )
2268
2272
shape_insert = list (shape )
2269
2273
shape_insert .insert (stack_dim , 2 )
2270
2274
assert new_spec .shape == torch .Size (shape_insert )
2271
2275
assert new_spec .device == torch .device ("cpu" )
2276
+ assert new_spec .example_data == "example_data"
2272
2277
2273
2278
def test_stack_onehot (self , shape , stack_dim ):
2274
2279
n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
3642
3647
3643
3648
class TestNonTensorSpec :
3644
3649
def test_sample (self ):
3645
- nts = NonTensor (shape = (3 , 4 ))
3650
+ nts = NonTensor (shape = (3 , 4 ), example_data = "example_data" )
3646
3651
assert nts .one ((2 ,)).shape == (2 , 3 , 4 )
3647
3652
assert nts .rand ((2 ,)).shape == (2 , 3 , 4 )
3648
3653
assert nts .zero ((2 ,)).shape == (2 , 3 , 4 )
3654
+ assert nts .one ((2 ,)).data == "example_data"
3655
+ assert nts .rand ((2 ,)).data == "example_data"
3656
+ assert nts .zero ((2 ,)).data == "example_data"
3657
+
3658
+ def test_example_data_ineq (self ):
3659
+ nts0 = NonTensor (shape = (3 , 4 ), example_data = "example_data" )
3660
+ nts1 = NonTensor (shape = (3 , 4 ), example_data = "example_data 2" )
3661
+ assert nts0 != nts1
3649
3662
3650
3663
3651
3664
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "not cuda device" )
0 commit comments