Skip to content

Commit 7927220

Browse files
anijain2305 authored and spcyppt committed
Clear torch.compile cache between tests (#1992)
Summary: Pull Request resolved: #1992. Preparing for reducing cache size work. Reviewed By: jspark1105. Differential Revision: D48940544. fbshipit-source-id: e0047c784918d886a935f2ec059a05fbc5064e6f
1 parent 52a96be commit 7927220

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

fbgemm_gpu/test/jagged_tensor_ops_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,9 @@ def test_jagged_2d_to_dense_dynamic_shape(
320320
dtype: torch.dtype,
321321
device_type: str,
322322
) -> None:
323+
# Start a fresh compile for each parameter of the test case
324+
torch._dynamo.reset()
325+
323326
D = D * 4
324327
lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B)
325328
total_lengths = lengths_.sum()
@@ -523,6 +526,9 @@ def test_jagged_1d_to_dense_truncation(self) -> None:
523526
def test_jagged_1d_to_dense_dynamic_shape(
524527
self, B: int, max_sequence_length: int, padding_value: int, device_type: str
525528
) -> None:
529+
# Start a fresh compile for each parameter of the test case
530+
torch._dynamo.reset()
531+
526532
def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
527533
return torch.repeat_interleave(
528534
torch._dim_arange(lengths, 0).long(),
@@ -912,6 +918,9 @@ def test_dense_to_jagged_dynamic_shape(
912918
dtype: torch.dtype,
913919
device_type: str,
914920
) -> None:
921+
# Start a fresh compile for each parameter of the test case
922+
torch._dynamo.reset()
923+
915924
values_2d, offsets, max_lengths = self._generate_jagged_tensor(
916925
num_jagged_dim,
917926
outer_dense_size,
@@ -1248,6 +1257,9 @@ def test_jagged_elementwise_binary_dynamic_shape(
12481257
dtype: torch.dtype,
12491258
device_type: str,
12501259
) -> None:
1260+
# Start a fresh compile for each parameter of the test case
1261+
torch._dynamo.reset()
1262+
12511263
device = torch.device(device_type)
12521264

12531265
x_values, x_offsets, max_lengths = self._generate_jagged_tensor(
@@ -1514,6 +1526,9 @@ def test_jagged_dense_dense_elementwise_add_jagged_output_dynamic_shape(
15141526
dtype: torch.dtype,
15151527
device_type: str,
15161528
) -> None:
1529+
# Start a fresh compile for each parameter of the test case
1530+
torch._dynamo.reset()
1531+
15171532
x_values, x_offsets, max_lengths = self._generate_jagged_tensor(
15181533
num_jagged_dim,
15191534
outer_dense_size,
@@ -1720,6 +1735,9 @@ def test_batched_dense_vec_jagged_2d_mul_dynamic_shape(
17201735
dtype: torch.dtype,
17211736
device_type: str,
17221737
) -> None:
1738+
# Start a fresh compile for each parameter of the test case
1739+
torch._dynamo.reset()
1740+
17231741
assume(H == 1 or B != 0)
17241742

17251743
device = torch.device(device_type)
@@ -2405,6 +2423,9 @@ def test_jagged_dense_bmm_dynamic_shape(
24052423
dtype: torch.dtype,
24062424
device_type: str,
24072425
) -> None:
2426+
# Start a fresh compile for each parameter of the test case
2427+
torch._dynamo.reset()
2428+
24082429
assume(B != 0)
24092430
device = torch.device(device_type)
24102431
torch.backends.cuda.matmul.allow_tf32 = False

0 commit comments

Comments (0)