@@ -320,6 +320,9 @@ def test_jagged_2d_to_dense_dynamic_shape(
320
320
dtype : torch .dtype ,
321
321
device_type : str ,
322
322
) -> None :
323
+ # Start a fresh compile for each parameter of the test case
324
+ torch ._dynamo .reset ()
325
+
323
326
D = D * 4
324
327
lengths_ = np .random .randint (low = 0 , high = max_sequence_length , size = B )
325
328
total_lengths = lengths_ .sum ()
@@ -523,6 +526,9 @@ def test_jagged_1d_to_dense_truncation(self) -> None:
523
526
def test_jagged_1d_to_dense_dynamic_shape (
524
527
self , B : int , max_sequence_length : int , padding_value : int , device_type : str
525
528
) -> None :
529
+ # Start a fresh compile for each parameter of the test case
530
+ torch ._dynamo .reset ()
531
+
526
532
def lengths_to_segment_ids (lengths : torch .Tensor ) -> torch .Tensor :
527
533
return torch .repeat_interleave (
528
534
torch ._dim_arange (lengths , 0 ).long (),
@@ -912,6 +918,9 @@ def test_dense_to_jagged_dynamic_shape(
912
918
dtype : torch .dtype ,
913
919
device_type : str ,
914
920
) -> None :
921
+ # Start a fresh compile for each parameter of the test case
922
+ torch ._dynamo .reset ()
923
+
915
924
values_2d , offsets , max_lengths = self ._generate_jagged_tensor (
916
925
num_jagged_dim ,
917
926
outer_dense_size ,
@@ -1248,6 +1257,9 @@ def test_jagged_elementwise_binary_dynamic_shape(
1248
1257
dtype : torch .dtype ,
1249
1258
device_type : str ,
1250
1259
) -> None :
1260
+ # Start a fresh compile for each parameter of the test case
1261
+ torch ._dynamo .reset ()
1262
+
1251
1263
device = torch .device (device_type )
1252
1264
1253
1265
x_values , x_offsets , max_lengths = self ._generate_jagged_tensor (
@@ -1514,6 +1526,9 @@ def test_jagged_dense_dense_elementwise_add_jagged_output_dynamic_shape(
1514
1526
dtype : torch .dtype ,
1515
1527
device_type : str ,
1516
1528
) -> None :
1529
+ # Start a fresh compile for each parameter of the test case
1530
+ torch ._dynamo .reset ()
1531
+
1517
1532
x_values , x_offsets , max_lengths = self ._generate_jagged_tensor (
1518
1533
num_jagged_dim ,
1519
1534
outer_dense_size ,
@@ -1720,6 +1735,9 @@ def test_batched_dense_vec_jagged_2d_mul_dynamic_shape(
1720
1735
dtype : torch .dtype ,
1721
1736
device_type : str ,
1722
1737
) -> None :
1738
+ # Start a fresh compile for each parameter of the test case
1739
+ torch ._dynamo .reset ()
1740
+
1723
1741
assume (H == 1 or B != 0 )
1724
1742
1725
1743
device = torch .device (device_type )
@@ -2405,6 +2423,9 @@ def test_jagged_dense_bmm_dynamic_shape(
2405
2423
dtype : torch .dtype ,
2406
2424
device_type : str ,
2407
2425
) -> None :
2426
+ # Start a fresh compile for each parameter of the test case
2427
+ torch ._dynamo .reset ()
2428
+
2408
2429
assume (B != 0 )
2409
2430
device = torch .device (device_type )
2410
2431
torch .backends .cuda .matmul .allow_tf32 = False
0 commit comments