Skip to content

Commit f8b3e30

Browse files
committed
Fix jagged_test_index_select_2d that hangs in OSS and revert skip tests (#2036)
Summary: Pull Request resolved: #2036. Before C++20, a default-constructed `std::atomic_flag` is initialized to an unspecified state, so the spin loop `while (lock.test_and_set(std::memory_order_acquire))` may never exit, which causes the test to hang in OSS. This diff properly initializes each `std::atomic_flag` by calling `clear()` on it before use. Reviewed By: q10, sryap. Differential Revision: D49528661. fbshipit-source-id: ba2213cb9bf8c0abbd1e169db03f0e32dd2a7ebb
1 parent a3446ae commit f8b3e30

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,13 @@ void jagged_index_add_2d_kernel(
11511151
const auto num_cols = input.size(1);
11521152
// Allocate one lock per row
11531153
std::atomic_flag* locks = new std::atomic_flag[output.size(0)];
1154+
// Initialize all locks since before c++20 std::atomic_flag is initialized to
1155+
// an unspecified state.
1156+
// https://en.cppreference.com/w/cpp/atomic/atomic_flag/atomic_flag
1157+
for (auto i = 0; i < output.size(0); i++) {
1158+
locks[i].clear();
1159+
}
1160+
11541161
at::parallel_for(0, num_dense_input_rows, 0, [&](int64_t start, int64_t end) {
11551162
for (const auto dense_input_offset : c10::irange(start, end)) {
11561163
int index_pos;

fbgemm_gpu/test/jagged_tensor_ops_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
gpu_available,
2828
gpu_unavailable,
2929
on_arm_platform,
30-
running_on_github,
3130
symint_vector_unsupported,
3231
TEST_WITH_ROCM,
3332
)
@@ -38,7 +37,6 @@
3837
gpu_available,
3938
gpu_unavailable,
4039
on_arm_platform,
41-
running_on_github,
4240
symint_vector_unsupported,
4341
TEST_WITH_ROCM,
4442
)
@@ -1805,7 +1803,6 @@ def jagged_index_select_2d_ref(
18051803
new_embeddings = torch.index_select(values, 0, all_indices)
18061804
return new_embeddings
18071805

1808-
@unittest.skipIf(*running_on_github)
18091806
@given(
18101807
max_seq_length=st.integers(5, 10),
18111808
batch_size=st.integers(1, 128),
@@ -1826,7 +1823,7 @@ def jagged_index_select_2d_ref(
18261823
if (gpu_available and TEST_WITH_ROCM)
18271824
else st.just(True),
18281825
)
1829-
@settings(max_examples=20, deadline=None)
1826+
@settings(max_examples=20, deadline=None, verbosity=Verbosity.verbose)
18301827
def test_jagged_index_select_2d(
18311828
self,
18321829
max_seq_length: int,
@@ -1899,7 +1896,6 @@ def test_jagged_index_select_2d(
18991896
atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
19001897
)
19011898

1902-
@unittest.skipIf(*running_on_github)
19031899
@given(
19041900
max_seq_length=st.integers(5, 10),
19051901
batch_size=st.integers(1, 128),

0 commit comments

Comments
 (0)