fix: prevent crash on KEYLOCK_ACQUIRED check for NO_KEY_TRANSACTIONAL commands #5185

Merged
merged 10 commits into from
May 29, 2025
77 changes: 77 additions & 0 deletions src/server/search/search_family_test.cc
@@ -2767,4 +2767,81 @@ TEST_F(SearchFamilyTest, JsonSetIndexesBug) {
resp = Run({"FT.AGGREGATE", "index", "*", "GROUPBY", "1", "@text"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("text", "some text")));
}

TEST_F(SearchFamilyTest, SearchReindexWriteSearchRace) {
const std::string kIndexName = "myRaceIdx";
const int kWriterOps = 2000; // Number of documents to write
const int kSearcherOps = 1500; // Number of search operations
const int kReindexerOps = 25; // Number of times to drop/recreate the index

// 1. Initial FT.CREATE for the index
// Schema from the issue: content TEXT SORTABLE, tags TAG SORTABLE, numeric_field NUMERIC SORTABLE
Run({"ft.create", kIndexName, "ON", "HASH", "PREFIX", "1", "doc:", "SCHEMA", "content", "TEXT",
"SORTABLE", "tags", "TAG", "SORTABLE", "numeric_field", "NUMERIC", "SORTABLE"});

// 2. writer_fiber
auto writer_fiber = pp_->at(0)->LaunchFiber([&] {
for (int i = 1; i <= kWriterOps; ++i) {
std::string doc_key = absl::StrCat("doc:", i);
std::string content = absl::StrCat("text data item ", i, " for race condition test");
std::string tags_val = absl::StrCat("tagA,tagB,", (i % 10));
std::string numeric_field_val = std::to_string(i);
try {
Run({"hset", doc_key, "content", content, "tags", tags_val, "numeric_field",
numeric_field_val});
} catch (const std::exception& e) {
}
if (i % 100 == 0)
ThisFiber::SleepFor(std::chrono::microseconds(100)); // Brief yield
}
});

// 3. searcher_fiber
auto searcher_fiber = pp_->at(1)->LaunchFiber([&] {
for (int i = 1; i <= kSearcherOps; ++i) {
int random_val_content = 1 + (i % kWriterOps);
int random_tag_val = i % 10;
int random_val_numeric = 1 + (i % kWriterOps);

std::string query_content = absl::StrCat("@content:item", random_val_content);
std::string query_tags = absl::StrCat("@tags:{tagA|tagB|tag", random_tag_val, "}");
std::string query_numeric = absl::StrCat("@numeric_field:[", random_val_numeric, " ",
(random_val_numeric + 100), "]");
try {
Run({"ft.search", kIndexName, query_content});
Run({"ft.search", kIndexName, query_tags});
Run({"ft.search", kIndexName, query_numeric});
} catch (const std::exception& e) {
}
if (i % 50 == 0)
ThisFiber::SleepFor(std::chrono::microseconds(200 * (1 + i % 2)));
Collaborator: I suspect that a constant will suffice here

Contributor Author: I removed all sleeps. The bug is reproducible without them.

Collaborator: even better.

}
});

// 4. reindexer_fiber
auto reindexer_fiber = pp_->at(2)->LaunchFiber([&] {
for (int i = 1; i <= kReindexerOps; ++i) {
try {
Run({"ft.create", kIndexName, "ON", "HASH", "PREFIX", "1", "doc:", "SCHEMA", "content",
"TEXT", "SORTABLE", "tags", "TAG", "SORTABLE", "numeric_field", "NUMERIC",
"SORTABLE"});
} catch (const std::exception& e) {
}
ThisFiber::SleepFor(std::chrono::milliseconds(10 + (i % 5 * 5)));
try {
Run({"ft.dropindex", kIndexName});
} catch (const std::exception& e) {
}
ThisFiber::SleepFor(std::chrono::microseconds(500 * (1 + i % 2)));
Collaborator: are you sure you need reindexer_fiber at all?

Contributor Author: I tried removing every part; this test is as minimal as I could make it while still reproducing the bug.


}
});

// Join fibers
writer_fiber.Join();
searcher_fiber.Join();
reindexer_fiber.Join();

ASSERT_FALSE(service_->IsShardSetLocked());
}

} // namespace dfly
26 changes: 7 additions & 19 deletions src/server/transaction.cc
@@ -1247,25 +1247,13 @@ bool Transaction::CancelShardCb(EngineShard* shard) {
auto lock_args = GetLockArgs(shard->shard_id());
DCHECK(sd.local_mask & KEYLOCK_ACQUIRED);

// In multi-transactions, especially with search indexing operations like
// FT.CREATE/FT.DROPINDEX, kv_fp_ can be cleared during MultiSwitchCmd, resulting in empty
// fingerprints.
//
// Example race condition scenario:
// 1. FT.CREATE/FT.DROPINDEX use CO::GLOBAL_TRANS flag, triggering multi-transaction mode
// 2. During concurrent operations (multiple redis-cli scripts running FT.DROPINDEX + HSET +
// FT.SEARCH)
// 3. MultiSwitchCmd is called via ConnectionContext::SwitchTxCmd() or MultiCommandSquasher
// 4. MultiSwitchCmd calls kv_fp_.clear()
// 5. Later, CancelShardCb tries to release locks but GetLockArgs returns empty fps
//
// This is a valid state in multi-transactions and we should skip lock release in such cases.
if (!lock_args.fps.empty()) {
GetDbSlice(shard->shard_id()).Release(LockMode(), lock_args);
} else {
VLOG(1) << "Skipping lock release for transaction " << DebugId()
<< " due to empty fingerprints";
}
// if (!lock_args.fps.empty()) {
DCHECK(!lock_args.fps.empty());
GetDbSlice(shard->shard_id()).Release(LockMode(), lock_args);
//} else {
// VLOG(1) << "Skipping lock release for transaction " << DebugId()
// << " due to empty fingerprints";
//}

sd.local_mask &= ~KEYLOCK_ACQUIRED;
}
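The change above replaces the defensive empty-fingerprint branch with a DCHECK: after this PR, KEYLOCK_ACQUIRED is only expected to be set when the transaction actually recorded key fingerprints on the shard, so commands without keys (NO_KEY_TRANSACTIONAL ones such as FT.CREATE/FT.DROPINDEX) should never hit the release path at all. Below is a minimal, self-contained sketch of that invariant, not Dragonfly's real code — ShardData, AcquireKeyLocks, and CancelShard are hypothetical stand-ins for the transaction/shard machinery:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Fingerprint = uint64_t;

constexpr uint16_t KEYLOCK_ACQUIRED = 1 << 0;

struct ShardData {
  uint16_t local_mask = 0;
  std::vector<Fingerprint> fps;  // key fingerprints locked on this shard
};

// Lock acquisition: a command with no keys never supplies fingerprints,
// so the flag and the fingerprint list always appear together.
void AcquireKeyLocks(ShardData& sd, std::vector<Fingerprint> fps) {
  if (fps.empty())
    return;  // nothing to lock, KEYLOCK_ACQUIRED stays clear
  sd.fps = std::move(fps);
  sd.local_mask |= KEYLOCK_ACQUIRED;
}

// Cancel path, mirroring the shape of CancelShardCb after this PR:
// if the flag is set, the fingerprint list must be non-empty.
void CancelShard(ShardData& sd) {
  if (sd.local_mask & KEYLOCK_ACQUIRED) {
    assert(!sd.fps.empty());  // stand-in for the DCHECK added in transaction.cc
    sd.fps.clear();           // stand-in for DbSlice::Release(LockMode(), lock_args)
    sd.local_mask &= ~KEYLOCK_ACQUIRED;
  }
}

int main() {
  ShardData with_keys, no_keys;
  AcquireKeyLocks(with_keys, {0xA1, 0xB2});
  AcquireKeyLocks(no_keys, {});  // e.g. FT.CREATE / FT.DROPINDEX: no key locks taken
  CancelShard(with_keys);
  CancelShard(no_keys);  // no-op: the flag was never set, so no release is attempted
  return 0;
}
```

Under this assumption the cancel callback can assert a non-empty fingerprint list instead of silently skipping the release, which is what the DCHECK in CancelShardCb expresses.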
242 changes: 0 additions & 242 deletions tests/dragonfly/search_test.py
@@ -3,10 +3,6 @@
Search correctness should be ensured with unit tests.
"""
import pytest
import random
import redis
import threading
import time
from redis import asyncio as aioredis
from .utility import *
from . import dfly_args
@@ -562,241 +558,3 @@ def make_car(producer, description, speed):

for index in client.execute_command("FT._LIST"):
client.ft(index.decode()).dropindex()


@dfly_args({"proactor_threads": 4})
async def test_search_race_condition_threaded_issue_5173(async_client: aioredis.Redis, df_server):
"""
Alternative test using threading for true parallelism to reproduce race condition.
This version uses actual threads instead of asyncio tasks for maximum concurrency.
"""
import threading
import time
import redis # Sync redis client for threading

# Index name for the test
index_name = "myRaceIdx"

# Shared flags
server_crashed = threading.Event()
stop_threads = threading.Event()

def writer_thread():
"""Thread that writes documents"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
i = 1
while not stop_threads.is_set() and not server_crashed.is_set():
try:
client.hset(
f"doc:{i}",
mapping={
"content": f"text data item {i} for race condition test",
"tags": f"tagA,tagB,{i % 10}",
"numeric_field": i,
},
)
i += 1
# Reset counter to avoid huge numbers
if i > 100000:
i = 1
# Small delay to prevent CPU overload
time.sleep(0.001)
except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
# Expected during index operations
time.sleep(0.001)
finally:
client.close()

def reindexer_thread():
"""Thread that drops and recreates index"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
while not stop_threads.is_set() and not server_crashed.is_set():
try:
# Drop index
client.execute_command("FT.DROPINDEX", index_name)

# Recreate index
client.execute_command(
"FT.CREATE",
index_name,
"ON",
"HASH",
"PREFIX",
"1",
"doc:",
"SCHEMA",
"content",
"TEXT",
"SORTABLE",
"tags",
"TAG",
"SORTABLE",
"numeric_field",
"NUMERIC",
"SORTABLE",
)
# Small delay between index operations
time.sleep(0.01)
except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
# Expected during concurrent operations
time.sleep(0.01)
finally:
client.close()

def searcher_thread():
"""Thread that performs search operations"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
while not stop_threads.is_set() and not server_crashed.is_set():
try:
random_val = random.randint(1, 20000)
random_tag_val = random.randint(0, 20)

# Various search operations
client.execute_command("FT.SEARCH", index_name, f"@content:item{random_val}")

client.execute_command(
"FT.SEARCH", index_name, f"@tags:{{tagA|tagC|tag{random_tag_val}}}"
)

client.execute_command(
"FT.SEARCH", index_name, f"@numeric_field:[{random_val} {random_val + 100}]"
)

# Small delay between search operations
time.sleep(0.005)

except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
# Expected during index operations
time.sleep(0.005)
finally:
client.close()

def writer2_thread():
"""Thread that writes additional documents (like writer2.sh)"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
i = 50001 # Start from different range to avoid conflicts
while not stop_threads.is_set() and not server_crashed.is_set():
try:
client.hset(
f"doc:{i}",
mapping={
"content": f"additional text data item {i} for race condition test",
"tags": f"tagC,tagD,{i % 15}",
"numeric_field": i + 1000,
},
)
i += 1
# Reset counter to avoid huge numbers
if i > 150000:
i = 50001
# Small delay to prevent CPU overload
time.sleep(0.001)
except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
# Expected during index operations
time.sleep(0.001)
finally:
client.close()

def deleter_thread():
"""Thread that deletes documents (like deleter.sh)"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
while not stop_threads.is_set() and not server_crashed.is_set():
try:
# Delete random documents
doc_id = random.randint(1, 100000)
client.delete(f"doc:{doc_id}")

# Small delay between deletions
time.sleep(0.01)
except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
# Expected during index operations
time.sleep(0.01)
finally:
client.close()

def health_monitor():
"""Monitor server health"""
client = redis.Redis(port=df_server.port, decode_responses=True)
try:
while not stop_threads.is_set() and not server_crashed.is_set():
try:
client.ping()
time.sleep(1.0)
except (redis.ConnectionError, OSError):
server_crashed.set()
break
except Exception:
time.sleep(0.5)
finally:
client.close()

# Start all threads
threads = [
threading.Thread(target=health_monitor, name="health"),
threading.Thread(target=writer_thread, name="writer"),
threading.Thread(target=writer2_thread, name="writer2"),
threading.Thread(target=deleter_thread, name="deleter"),
threading.Thread(target=reindexer_thread, name="reindexer"),
threading.Thread(target=searcher_thread, name="searcher"),
]

for thread in threads:
thread.start()

try:
# Wait for threads to complete or server to crash
start_time = time.time()
last_progress_time = start_time

while time.time() - start_time < 30: # 30 seconds
if server_crashed.is_set():
break

# Check if any thread died unexpectedly
alive_threads = [t for t in threads if t.is_alive()]
if len(alive_threads) < len(threads):
break

current_time = time.time()
if current_time - last_progress_time >= 30:
elapsed = current_time - start_time
remaining = 600 - elapsed
last_progress_time = current_time

time.sleep(0.1)

# Stop all threads
stop_threads.set()

# Wait for threads to finish
for thread in threads:
thread.join(timeout=5.0)

except Exception as e:
stop_threads.set()

# Check if server crashed
if server_crashed.is_set():
pytest.fail(
"Dragonfly server crashed during threaded race condition test - issue #5173 reproduced!"
)