Skip to content

Commit 6463153

Browse files
authored
fix: Fix shuffle writing rows containing null struct fields (#1845)
1 parent e02b6cd commit 6463153

File tree

2 files changed

+81
-16
lines changed

2 files changed

+81
-16
lines changed

native/core/src/execution/shuffle/row.rs

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -444,25 +444,18 @@ pub(crate) fn append_field(
444444
// Appending value into struct field builder of Arrow struct builder.
445445
let field_builder = struct_builder.field_builder::<StructBuilder>(idx).unwrap();
446446

447-
if row.is_null_row() {
448-
// The row is null.
447+
let nested_row = if row.is_null_row() || row.is_null_at(idx) {
448+
// The row is null, or the field in the row is null, i.e., a null nested row.
449+
// Append a null value to the row builder.
449450
field_builder.append_null();
451+
SparkUnsafeRow::default()
450452
} else {
451-
let is_null = row.is_null_at(idx);
453+
field_builder.append(true);
454+
row.get_struct(idx, fields.len())
455+
};
452456

453-
let nested_row = if is_null {
454-
// The field in the row is null, i.e., a null nested row.
455-
// Append a null value to the row builder.
456-
field_builder.append_null();
457-
SparkUnsafeRow::default()
458-
} else {
459-
field_builder.append(true);
460-
row.get_struct(idx, fields.len())
461-
};
462-
463-
for (field_idx, field) in fields.into_iter().enumerate() {
464-
append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
465-
}
457+
for (field_idx, field) in fields.into_iter().enumerate() {
458+
append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
466459
}
467460
}
468461
DataType::Map(field, _) => {
@@ -3302,3 +3295,45 @@ fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> Result<RecordBatch, Ar
33023295
let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
33033296
RecordBatch::try_new_with_options(schema, arrays, &options)
33043297
}
3298+
3299+
#[cfg(test)]
3300+
mod test {
3301+
use arrow::datatypes::Fields;
3302+
3303+
use super::*;
3304+
3305+
#[test]
3306+
fn test_append_null_row_to_struct_builder() {
3307+
let data_type = DataType::Struct(Fields::from(vec![
3308+
Field::new("a", DataType::Boolean, true),
3309+
Field::new("b", DataType::Boolean, true),
3310+
]));
3311+
let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]);
3312+
let mut struct_builder = StructBuilder::from_fields(fields, 1);
3313+
let row = SparkUnsafeRow::default();
3314+
append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
3315+
struct_builder.append_null();
3316+
let struct_array = struct_builder.finish();
3317+
assert_eq!(struct_array.len(), 1);
3318+
assert!(struct_array.is_null(0));
3319+
}
3320+
3321+
#[test]
3322+
#[cfg_attr(miri, ignore)] // Unaligned memory access in SparkUnsafeRow
3323+
fn test_append_null_struct_field_to_struct_builder() {
3324+
let data_type = DataType::Struct(Fields::from(vec![
3325+
Field::new("a", DataType::Boolean, true),
3326+
Field::new("b", DataType::Boolean, true),
3327+
]));
3328+
let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]);
3329+
let mut struct_builder = StructBuilder::from_fields(fields, 1);
3330+
let mut row = SparkUnsafeRow::new_with_num_fields(1);
3331+
let data = [0; 8];
3332+
row.point_to_slice(&data);
3333+
append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
3334+
struct_builder.append_null();
3335+
let struct_array = struct_builder.finish();
3336+
assert_eq!(struct_array.len(), 1);
3337+
assert!(struct_array.is_null(0));
3338+
}
3339+
}

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
package org.apache.comet.exec
2121

22+
import java.nio.file.Files
23+
import java.nio.file.Paths
24+
2225
import scala.reflect.runtime.universe._
2326
import scala.util.Random
2427

@@ -820,6 +823,33 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
820823
}
821824
}
822825

826+
test("columnar shuffle on null struct fields") {
827+
withTempDir { dir =>
828+
val testData = "{}\n"
829+
val path = Paths.get(dir.toString, "test.json")
830+
Files.write(path, testData.getBytes)
831+
832+
// Define the nested struct schema
833+
val readSchema = StructType(
834+
Array(
835+
StructField(
836+
"metaData",
837+
StructType(
838+
Array(StructField(
839+
"format",
840+
StructType(Array(StructField("provider", StringType, nullable = true))),
841+
nullable = true))),
842+
nullable = true)))
843+
844+
// Read JSON with custom schema and repartition, this will repartition rows that contain
845+
// null struct fields.
846+
val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
847+
assert(df.count() == 1)
848+
val row = df.collect()(0)
849+
assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
850+
}
851+
}
852+
823853
/**
824854
* Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet
825855
* exchange operators.

0 commit comments

Comments
 (0)