Skip to content

Commit 6a9ca4b

Browse files
committed
optimize code
1 parent 246feb3 commit 6a9ca4b

File tree

4 files changed

+57
-32
lines changed

4 files changed

+57
-32
lines changed

src/query/service/src/pipelines/processors/transforms/hash_join/desc.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use std::collections::HashMap;
16+
use std::collections::HashSet;
1617

1718
use common_exception::Result;
1819
use common_functions::scalars::FunctionFactory;
@@ -28,15 +29,15 @@ use crate::sql::plans::JoinType;
2829

2930
pub struct RightJoinDesc {
3031
/// Record rows in build side that are matched with rows in probe side.
31-
pub(crate) build_indexes: RwLock<Vec<RowPtr>>,
32+
pub(crate) build_indexes: RwLock<HashSet<RowPtr>>,
3233
/// Record row in build side that is matched how many rows in probe side.
3334
pub(crate) row_state: RwLock<HashMap<RowPtr, usize>>,
3435
}
3536

3637
impl RightJoinDesc {
3738
pub fn create() -> Self {
3839
RightJoinDesc {
39-
build_indexes: RwLock::new(vec![]),
40+
build_indexes: RwLock::new(HashSet::new()),
4041
row_state: RwLock::new(HashMap::new()),
4142
}
4243
}

src/query/service/src/pipelines/processors/transforms/hash_join/join_hash_table.rs

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,37 @@ impl JoinHashTable {
432432
}
433433
}
434434
}
435+
436+
fn find_unmatched_build_indexes(&self) -> Result<Vec<RowPtr>> {
437+
// For right join, build side will appear at lease once in the joined table
438+
// Find the unmatched rows in build side
439+
let mut unmatched_build_indexes = vec![];
440+
{
441+
let chunks = self.row_space.chunks.read().unwrap();
442+
for (chunk_index, chunk) in chunks.iter().enumerate() {
443+
for row_index in 0..chunk.num_rows() {
444+
let row_ptr = RowPtr {
445+
chunk_index: chunk_index as u32,
446+
row_index: row_index as u32,
447+
marker: None,
448+
};
449+
if !self
450+
.hash_join_desc
451+
.right_join_desc
452+
.build_indexes
453+
.read()
454+
.contains(&row_ptr)
455+
{
456+
let mut row_state = self.hash_join_desc.right_join_desc.row_state.write();
457+
row_state.entry(row_ptr).or_insert(0_usize);
458+
unmatched_build_indexes.push(row_ptr);
459+
}
460+
}
461+
}
462+
drop(chunks);
463+
}
464+
Ok(unmatched_build_indexes)
465+
}
435466
}
436467

437468
#[async_trait::async_trait]
@@ -689,32 +720,7 @@ impl HashJoinState for JoinHashTable {
689720
}
690721

691722
fn right_join_blocks(&self, blocks: &[DataBlock]) -> Result<Vec<DataBlock>> {
692-
// For right join, build side will appear at lease once in the joined table
693-
// Find the unmatched rows in build side
694-
let mut unmatched_build_indexes = vec![];
695-
{
696-
for (chunk_index, chunk) in self.row_space.chunks.read().unwrap().iter().enumerate() {
697-
for row_index in 0..chunk.num_rows() {
698-
let row_ptr = RowPtr {
699-
chunk_index: chunk_index as u32,
700-
row_index: row_index as u32,
701-
marker: None,
702-
};
703-
if !self
704-
.hash_join_desc
705-
.right_join_desc
706-
.build_indexes
707-
.read()
708-
.contains(&row_ptr)
709-
{
710-
let mut row_state = self.hash_join_desc.right_join_desc.row_state.write();
711-
row_state.entry(row_ptr).or_insert(0_usize);
712-
unmatched_build_indexes.push(row_ptr);
713-
}
714-
}
715-
}
716-
}
717-
723+
let unmatched_build_indexes = self.find_unmatched_build_indexes()?;
718724
if unmatched_build_indexes.is_empty() && self.hash_join_desc.other_predicate.is_none() {
719725
return Ok(blocks.to_vec());
720726
}
@@ -770,7 +776,7 @@ impl HashJoinState for JoinHashTable {
770776
// If build_indexes size will greater build table size, we need filter the redundant rows for build side.
771777
let mut build_indexes = self.hash_join_desc.right_join_desc.build_indexes.write();
772778
let mut row_state = self.hash_join_desc.right_join_desc.row_state.write();
773-
build_indexes.extend_from_slice(&unmatched_build_indexes);
779+
build_indexes.extend(&unmatched_build_indexes);
774780
if build_indexes.len() > self.row_space.rows_number() {
775781
let mut bm = validity.into_mut().right().unwrap();
776782
Self::filter_rows_for_right_join(&mut bm, &build_indexes, &mut row_state);

src/query/service/src/pipelines/processors/transforms/hash_join/result_blocks.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashSet;
1516
use std::iter::TrustedLen;
1617

1718
use common_arrow::arrow::bitmap::Bitmap;
@@ -546,7 +547,7 @@ impl JoinHashTable {
546547
{
547548
let mut build_indexes =
548549
self.hash_join_desc.right_join_desc.build_indexes.write();
549-
build_indexes.extend_from_slice(probe_result_ptrs);
550+
build_indexes.extend(probe_result_ptrs);
550551
local_build_indexes.extend_from_slice(probe_result_ptrs);
551552
}
552553
for row_ptr in probe_result_ptrs.iter() {
@@ -655,7 +656,7 @@ impl JoinHashTable {
655656

656657
pub(crate) fn filter_rows_for_right_join(
657658
bm: &mut MutableBitmap,
658-
build_indexes: &[RowPtr],
659+
build_indexes: &HashSet<RowPtr>,
659660
row_state: &mut std::collections::HashMap<RowPtr, usize>,
660661
) {
661662
for (index, row) in build_indexes.iter().enumerate() {

src/query/service/src/pipelines/processors/transforms/hash_join/row.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::hash::Hash;
16+
use std::hash::Hasher;
1517
use std::sync::RwLock;
1618

1719
use common_datablocks::DataBlock;
@@ -36,7 +38,7 @@ impl Chunk {
3638
}
3739
}
3840

39-
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
41+
#[derive(Clone, Copy, Debug)]
4042
pub struct RowPtr {
4143
pub chunk_index: u32,
4244
pub row_index: u32,
@@ -110,3 +112,18 @@ impl RowSpace {
110112
}
111113
}
112114
}
115+
116+
impl PartialEq for RowPtr {
117+
fn eq(&self, other: &Self) -> bool {
118+
self.chunk_index == other.chunk_index && self.row_index == other.row_index
119+
}
120+
}
121+
122+
impl Eq for RowPtr {}
123+
124+
impl Hash for RowPtr {
125+
fn hash<H: Hasher>(&self, state: &mut H) {
126+
self.chunk_index.hash(state);
127+
self.row_index.hash(state);
128+
}
129+
}

0 commit comments

Comments
 (0)