
Commit 9982ad9 (parent: a707d55)

Improve: RoundRobinVec code nesting

1 file changed: +95 -103 lines changed

rust/lib.rs

Lines changed: 95 additions & 103 deletions
@@ -2290,21 +2290,22 @@ impl<T> RoundRobinVec<T> {
         let pool_ptr = SafePtr(pool as *const ThreadPool as *mut ThreadPool);
 
         pool.for_threads(move |thread_index, colocation_index| {
-            if colocation_index < colocations_count {
-                // Get the specific pinned vector for this NUMA node
-                let node_vec = safe_ptr.get_mut_at(colocation_index);
-                let pool = pool_ptr.get_mut();
-
-                let threads_in_colocation = pool.count_threads_in(colocation_index);
-                let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
-                let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
-                let range = split.get(thread_local_index);
+            if colocation_index >= colocations_count {
+                return;
+            }
 
-                // Fill the assigned range of this thread
-                for idx in range {
-                    if let Some(element) = node_vec.get_mut(idx) {
-                        *element = value.clone();
-                    }
+            let node_vec = safe_ptr.get_mut_at(colocation_index);
+            let pool = pool_ptr.get_mut();
+
+            let threads_in_colocation = pool.count_threads_in(colocation_index);
+            let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
+            let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
+            let range = split.get(thread_local_index);
+
+            // Fill the assigned range of this thread
+            for idx in range {
+                if let Some(element) = node_vec.get_mut(idx) {
+                    *element = value.clone();
                 }
             }
         });
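This hunk (and the two that follow) swap the body-wide `if colocation_index < colocations_count { ... }` for an early `return`, so the per-thread work sits one nesting level shallower. Below is a minimal, self-contained sketch of the same guard-clause shape; it uses plain `std::thread` and a hypothetical `process` helper in place of the crate's `ThreadPool::for_threads`, so names and signatures here are illustrative assumptions, not the library's API.

// A minimal sketch of the guard-clause pattern, with a hypothetical
// `process` function standing in for the closure passed to `for_threads`.
fn process(thread_index: usize, colocation_index: usize, colocations_count: usize) {
    // Early return replaces wrapping the whole body in
    // `if colocation_index < colocations_count { ... }`.
    if colocation_index >= colocations_count {
        return;
    }

    // The main work now sits at a single indentation level.
    println!("thread {thread_index} handles colocation {colocation_index}");
}

fn main() {
    let colocations_count = 2;
    std::thread::scope(|s| {
        for thread_index in 0..4 {
            // One deliberately out-of-range index to exercise the guard.
            let colocation_index = thread_index % 3;
            s.spawn(move || process(thread_index, colocation_index, colocations_count));
        }
    });
}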
@@ -2346,22 +2347,23 @@ impl<T> RoundRobinVec<T> {
         let pool_ptr = SafePtr(pool as *const ThreadPool as *mut ThreadPool);
 
         pool.for_threads(move |thread_index, colocation_index| {
-            if colocation_index < colocations_count {
-                // Get the specific pinned vector for this NUMA node
-                let node_vec = safe_ptr.get_mut_at(colocation_index);
-                let f_ref = f_ptr.get_mut();
-                let pool = pool_ptr.get_mut();
-
-                let threads_in_colocation = pool.count_threads_in(colocation_index);
-                let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
-                let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
-                let range = split.get(thread_local_index);
+            if colocation_index >= colocations_count {
+                return;
+            }
 
-                // Fill the assigned range of this thread
-                for idx in range {
-                    if let Some(element) = node_vec.get_mut(idx) {
-                        *element = f_ref();
-                    }
+            let node_vec = safe_ptr.get_mut_at(colocation_index);
+            let f_ref = f_ptr.get_mut();
+            let pool = pool_ptr.get_mut();
+
+            let threads_in_colocation = pool.count_threads_in(colocation_index);
+            let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
+            let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
+            let range = split.get(thread_local_index);
+
+            // Fill the assigned range of this thread
+            for idx in range {
+                if let Some(element) = node_vec.get_mut(idx) {
+                    *element = f_ref();
                 }
             }
         });
@@ -2378,22 +2380,23 @@ impl<T> RoundRobinVec<T> {
         let pool_ptr = SafePtr(pool as *const ThreadPool as *mut ThreadPool);
 
         pool.for_threads(move |thread_index, colocation_index| {
-            if colocation_index < colocations_count {
-                // Get the specific pinned vector for this NUMA node
-                let node_vec = safe_ptr.get_mut_at(colocation_index);
-                let pool = pool_ptr.get_mut();
-
-                let threads_in_colocation = pool.count_threads_in(colocation_index);
-                let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
-                let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
-                let range = split.get(thread_local_index);
+            if colocation_index >= colocations_count {
+                return;
+            }
 
-                // Drop elements in the assigned range
-                unsafe {
-                    let ptr = node_vec.as_mut_ptr();
-                    for idx in range {
-                        core::ptr::drop_in_place(ptr.add(idx));
-                    }
+            let node_vec = safe_ptr.get_mut_at(colocation_index);
+            let pool = pool_ptr.get_mut();
+
+            let threads_in_colocation = pool.count_threads_in(colocation_index);
+            let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
+            let split = IndexedSplit::new(node_vec.len(), threads_in_colocation);
+            let range = split.get(thread_local_index);
+
+            // Drop elements in the assigned range
+            unsafe {
+                let ptr = node_vec.as_mut_ptr();
+                for idx in range {
+                    core::ptr::drop_in_place(ptr.add(idx));
                 }
             }
         });
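All three closures above carve the node's elements into contiguous per-thread ranges via `IndexedSplit::new(len, threads).get(thread_local_index)`. Here is a standalone sketch of how such a split could be computed, using a hypothetical `split_range` function; the remainder-goes-to-the-first-chunks rule is an assumption for illustration, not necessarily the crate's actual `IndexedSplit` behavior.

// Hypothetical stand-in for IndexedSplit::new(len, parts).get(part_index):
// splits `len` items into `parts` near-equal contiguous ranges, giving the
// first `len % parts` ranges one extra element.
fn split_range(len: usize, parts: usize, part_index: usize) -> std::ops::Range<usize> {
    assert!(parts > 0 && part_index < parts);
    let base = len / parts;
    let extra = len % parts;
    // Ranges before `extra` hold `base + 1` items, the rest hold `base`.
    let start = part_index * base + part_index.min(extra);
    let this_len = base + usize::from(part_index < extra);
    start..start + this_len
}

fn main() {
    // 10 elements split across 3 threads: 4 + 3 + 3, contiguous and disjoint.
    for t in 0..3 {
        println!("thread {t} -> {:?}", split_range(10, 3, t));
    }
}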
@@ -2435,18 +2438,21 @@ impl<T> RoundRobinVec<T> {
         let elements_per_node = new_len / colocations_count;
         let extra_elements = new_len % colocations_count;
 
-        // Step 1: Centrally handle reallocation for each NUMA node
-        for i in 0..colocations_count {
-            let node_len = if i < extra_elements {
+        // Helper to calculate target length for a colocation
+        let node_len = |col_idx: usize| -> usize {
+            if col_idx < extra_elements {
                 elements_per_node + 1
             } else {
                 elements_per_node
-            };
+            }
+        };
 
+        // Step 1: Centrally handle reallocation for each NUMA node
+        for i in 0..colocations_count {
+            let target_len = node_len(i);
             let current_len = self.colocations[i].len();
-            if node_len > current_len {
-                // Need to reserve more capacity
-                self.colocations[i].reserve(node_len - current_len)?;
+            if target_len > current_len {
+                self.colocations[i].reserve(target_len - current_len)?;
             }
         }
 
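The new `node_len` closure assigns every colocation `new_len / colocations_count` elements and hands the `new_len % colocations_count` leftovers to the first colocations, so the per-node targets always sum back to `new_len`. The same arithmetic, pulled out into a standalone function purely for illustration:

// Same arithmetic as the `node_len` closure above, as a free function:
// distribute `new_len` elements round-robin across `colocations_count` nodes.
fn node_len(new_len: usize, colocations_count: usize, col_idx: usize) -> usize {
    let elements_per_node = new_len / colocations_count;
    let extra_elements = new_len % colocations_count;
    if col_idx < extra_elements {
        elements_per_node + 1
    } else {
        elements_per_node
    }
}

fn main() {
    let (new_len, colocations_count) = (11, 4);
    let lens: Vec<usize> = (0..colocations_count)
        .map(|i| node_len(new_len, colocations_count, i))
        .collect();
    assert_eq!(lens, vec![3, 3, 3, 2]); // 11 = 3 + 3 + 3 + 2
    assert_eq!(lens.iter().sum::<usize>(), new_len);
    println!("{lens:?}");
}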
@@ -2455,69 +2461,55 @@ impl<T> RoundRobinVec<T> {
         let pool_ptr = SafePtr(pool as *const ThreadPool as *mut ThreadPool);
 
         pool.for_threads(move |thread_index, colocation_index| {
-            if colocation_index < colocations_count {
-                // Get the specific pinned vector for this NUMA node
-                let node_vec = safe_ptr.get_mut_at(colocation_index);
-                let pool = pool_ptr.get_mut();
-
-                let node_len = if colocation_index < extra_elements {
-                    elements_per_node + 1
-                } else {
-                    elements_per_node
-                };
-
-                let current_len = node_vec.len();
-                let threads_in_colocation = pool.count_threads_in(colocation_index);
-                let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
-
-                match node_len.cmp(&current_len) {
-                    std::cmp::Ordering::Greater => {
-                        // Growing: construct new elements in parallel
-                        let new_elements = node_len - current_len;
-                        let split = IndexedSplit::new(new_elements, threads_in_colocation);
-                        let range = split.get(thread_local_index);
-
-                        unsafe {
-                            let ptr = node_vec.as_mut_ptr();
-                            for i in range {
-                                let idx = current_len + i;
-                                core::ptr::write(ptr.add(idx), value.clone());
-                            }
-                        }
-                    }
-                    std::cmp::Ordering::Less => {
-                        // Shrinking: drop elements in parallel
-                        let elements_to_drop = current_len - node_len;
-                        let split = IndexedSplit::new(elements_to_drop, threads_in_colocation);
-                        let range = split.get(thread_local_index);
-
-                        unsafe {
-                            let ptr = node_vec.as_mut_ptr();
-                            for i in range {
-                                let idx = node_len + i;
-                                core::ptr::drop_in_place(ptr.add(idx));
-                            }
-                        }
+            if colocation_index >= colocations_count {
+                return;
+            }
+
+            let node_vec = safe_ptr.get_mut_at(colocation_index);
+            let pool = pool_ptr.get_mut();
+            let target_len = node_len(colocation_index);
+            let current_len = node_vec.len();
+            if target_len == current_len {
+                return;
+            }
+
+            let threads_in_colocation = pool.count_threads_in(colocation_index);
+            let thread_local_index = pool.locate_thread_in(thread_index, colocation_index);
+
+            if target_len > current_len {
+                // Growing: construct new elements in parallel
+                let new_elements = target_len - current_len;
+                let split = IndexedSplit::new(new_elements, threads_in_colocation);
+                let range = split.get(thread_local_index);
+
+                unsafe {
+                    let ptr = node_vec.as_mut_ptr();
+                    for i in range {
+                        core::ptr::write(ptr.add(current_len + i), value.clone());
                     }
-                    std::cmp::Ordering::Equal => {
-                        // No change needed
+                }
+            } else {
+                // Shrinking: drop elements in parallel
+                let elements_to_drop = current_len - target_len;
+                let split = IndexedSplit::new(elements_to_drop, threads_in_colocation);
+                let range = split.get(thread_local_index);
+
+                unsafe {
+                    let ptr = node_vec.as_mut_ptr();
+                    for i in range {
+                        core::ptr::drop_in_place(ptr.add(target_len + i));
                     }
                 }
             }
         });
 
         // Step 3: Update lengths after parallel operations
         for i in 0..colocations_count {
-            let node_len = if i < extra_elements {
-                elements_per_node + 1
-            } else {
-                elements_per_node
-            };
-            self.colocations[i].len = node_len;
+            self.colocations[i].len = node_len(i);
        }
 
         self.total_length = new_len;
-        self.total_capacity = self.capacity(); // Recalculate total capacity
+        self.total_capacity = self.capacity();
         Ok(())
     }
 }
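The resize path relies on capacity being reserved up front (Step 1), `core::ptr::write` constructing new elements in place when growing, `core::ptr::drop_in_place` destroying the tail when shrinking, and the vector lengths being updated only afterwards (Step 3). Below is a single-threaded sketch of those mechanics on a plain `Vec<String>`, standing in for the crate's pinned per-NUMA vectors; it illustrates the pattern under those assumptions and is not the library's implementation.

// Single-threaded sketch of the grow/shrink mechanics (plain Vec<String>,
// not the crate's pinned per-NUMA vectors).
fn resize_with_value(v: &mut Vec<String>, target_len: usize, value: &str) {
    let current_len = v.len();
    if target_len > current_len {
        // Growing: reserve first, construct new elements in place, then
        // publish the new length.
        v.reserve(target_len - current_len);
        unsafe {
            let ptr = v.as_mut_ptr();
            for i in 0..target_len - current_len {
                core::ptr::write(ptr.add(current_len + i), value.to_string());
            }
            v.set_len(target_len);
        }
    } else if target_len < current_len {
        // Shrinking: shrink the visible length first, then drop the tail
        // elements that are no longer reachable through the Vec.
        unsafe {
            v.set_len(target_len);
            let ptr = v.as_mut_ptr();
            for i in 0..current_len - target_len {
                core::ptr::drop_in_place(ptr.add(target_len + i));
            }
        }
    }
}

fn main() {
    let mut v = vec!["a".to_string(); 3];
    resize_with_value(&mut v, 5, "b");
    assert_eq!(v, ["a", "a", "a", "b", "b"]);
    resize_with_value(&mut v, 2, "unused");
    assert_eq!(v, ["a", "a"]);
    println!("{v:?}");
}

In the commit itself the length bookkeeping is deferred to Step 3, once every thread has finished its assigned range.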
