Skip to content

Make dict ID only an IPC concern #7929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion arrow-flight/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ impl FlightDataDecoder {

self.state = Some(FlightStreamState {
schema: Arc::clone(&schema),
schema_message: data.clone(),
dictionaries_by_field,
});
Ok(Some(DecodedFlightData::new_schema(data, schema)))
Expand All @@ -296,10 +297,15 @@ impl FlightDataDecoder {
)
})?;

let ipc_schema = arrow_ipc::root_as_message(&state.schema_message.data_header)
.unwrap()
.header_as_schema()
.unwrap();

arrow_ipc::reader::read_dictionary(
&buffer,
dictionary_batch,
&state.schema,
ipc_schema,
&mut state.dictionaries_by_field,
&message.version(),
)
Expand All @@ -319,8 +325,14 @@ impl FlightDataDecoder {
));
};

let ipc_schema = arrow_ipc::root_as_message(&state.schema_message.data_header)
.unwrap()
.header_as_schema()
.unwrap();

let batch = flight_data_to_arrow_batch(
&data,
ipc_schema,
Arc::clone(&state.schema),
&state.dictionaries_by_field,
)
Expand Down Expand Up @@ -376,6 +388,7 @@ impl futures::Stream for FlightDataDecoder {
/// Per-stream state the decoder stores once a schema message has been handled.
#[derive(Debug)]
struct FlightStreamState {
/// Shared handle to the Arrow schema associated with this stream.
schema: SchemaRef,
/// Copy of the schema `FlightData` message; its `data_header` is re-parsed
/// into the IPC schema when dictionary and record batch messages are decoded.
schema_message: FlightData,
/// Dictionary arrays keyed by an `i64` id; passed mutably to `read_dictionary`
/// and read when converting flight data into record batches.
dictionaries_by_field: HashMap<i64, ArrayRef>,
}

Expand Down
29 changes: 8 additions & 21 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,15 +535,13 @@ fn prepare_field_for_flight(
)
.with_metadata(field.metadata().clone())
} else {
#[allow(deprecated)]
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());

dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
dict_id,
0,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
Expand Down Expand Up @@ -585,14 +583,13 @@ fn prepare_schema_for_flight(
)
.with_metadata(field.metadata().clone())
} else {
#[allow(deprecated)]
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
dict_id,
0,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
Expand Down Expand Up @@ -654,16 +651,10 @@ struct FlightIpcEncoder {

impl FlightIpcEncoder {
fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
#[allow(deprecated)]
let preserve_dict_id = options.preserve_dict_id();
Self {
options,
data_gen: IpcDataGenerator::default(),
#[allow(deprecated)]
dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
error_on_replacement,
preserve_dict_id,
),
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
}
}

Expand Down Expand Up @@ -1547,9 +1538,8 @@ mod tests {
async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

#[allow(deprecated)]
let encoder = FlightDataEncoderBuilder::default()
.with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
.with_options(IpcWriteOptions::default())
.with_dictionary_handling(DictionaryHandling::Resend)
.build(futures::stream::iter(batches.clone().into_iter().map(Ok)));

Expand All @@ -1575,8 +1565,7 @@ mod tests {
HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
);

#[allow(deprecated)]
let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
let mut dictionary_tracker = DictionaryTracker::new(false);

let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
assert!(got.metadata().contains_key("some_key"));
Expand Down Expand Up @@ -1606,9 +1595,7 @@ mod tests {
options: &IpcWriteOptions,
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
#[allow(deprecated)]
let mut dictionary_tracker =
DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let mut dictionary_tracker = DictionaryTracker::new(false);

let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, options)
Expand Down
4 changes: 1 addition & 3 deletions arrow-flight/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ pub struct IpcMessage(pub Bytes);

fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData {
let data_gen = writer::IpcDataGenerator::default();
#[allow(deprecated)]
let mut dict_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let mut dict_tracker = writer::DictionaryTracker::new(false);
data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options)
}

Expand Down
2 changes: 2 additions & 0 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ pub enum ArrowFlightData {
pub fn arrow_data_from_flight_data(
flight_data: FlightData,
arrow_schema_ref: &SchemaRef,
ipc_schema: arrow_ipc::Schema,
) -> Result<ArrowFlightData, ArrowError> {
let ipc_message = root_as_message(&flight_data.data_header[..])
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
Expand All @@ -723,6 +724,7 @@ pub fn arrow_data_from_flight_data(
let record_batch = read_record_batch(
&Buffer::from(flight_data.data_body),
ipc_record_batch,
ipc_schema,
arrow_schema_ref.clone(),
&dictionaries_by_field,
None,
Expand Down
9 changes: 5 additions & 4 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result<Vec<RecordBa
let mut batches = vec![];
let dictionaries_by_id = HashMap::new();
for datum in flight_data[1..].iter() {
let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?;
let batch =
flight_data_to_arrow_batch(datum, ipc_schema, schema.clone(), &dictionaries_by_id)?;
batches.push(batch);
}
Ok(batches)
Expand All @@ -53,6 +54,7 @@ pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result<Vec<RecordBa
/// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`.
pub fn flight_data_to_arrow_batch(
data: &FlightData,
ipc_schema: arrow_ipc::Schema,
schema: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
) -> Result<RecordBatch, ArrowError> {
Expand All @@ -71,6 +73,7 @@ pub fn flight_data_to_arrow_batch(
reader::read_record_batch(
&Buffer::from(data.data_body.as_ref()),
batch,
ipc_schema,
schema,
dictionaries_by_id,
None,
Expand All @@ -90,9 +93,7 @@ pub fn batches_to_flight_data(
let mut flight_data = vec![];

let data_gen = writer::IpcDataGenerator::default();
#[allow(deprecated)]
let mut dictionary_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let mut dictionary_tracker = writer::DictionaryTracker::new(false);

for batch in batches.iter() {
let (encoded_dictionaries, encoded_batch) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use arrow_flight::{
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{Request, Streaming};

use arrow::datatypes::Schema;
use std::sync::Arc;

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
Expand Down Expand Up @@ -72,9 +71,7 @@ async fn upload_data(
let (mut upload_tx, upload_rx) = mpsc::channel(10);

let options = arrow::ipc::writer::IpcWriteOptions::default();
#[allow(deprecated)]
let mut dict_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let mut dict_tracker = writer::DictionaryTracker::new(false);
let data_gen = writer::IpcDataGenerator::default();
let data = IpcMessage(
data_gen
Expand Down Expand Up @@ -217,33 +214,40 @@ async fn consume_flight_location(
let resp = client.do_get(ticket).await?;
let mut resp = resp.into_inner();

let flight_schema = receive_schema_flight_data(&mut resp)
let data = resp
.next()
.await
.unwrap_or_else(|| panic!("Failed to receive flight schema"));
let actual_schema = Arc::new(flight_schema);
.ok_or_else(|| Error::from("No data received from Flight server"))??;
let message =
arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");

// message header is a Schema, so read it
let ipc_schema: ipc::Schema = message
.header_as_schema()
.expect("Unable to read IPC message as schema");
let schema = Arc::new(ipc::convert::fb_to_schema(ipc_schema));

let mut dictionaries_by_id = HashMap::new();

for (counter, expected_batch) in expected_data.iter().enumerate() {
let data =
receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
.await
.unwrap_or_else(|| {
panic!(
"Got fewer batches than expected, received so far: {} expected: {}",
counter,
expected_data.len(),
)
});
let data = receive_batch_flight_data(&mut resp, ipc_schema, &mut dictionaries_by_id)
.await
.unwrap_or_else(|| {
panic!(
"Got fewer batches than expected, received so far: {} expected: {}",
counter,
expected_data.len(),
)
});

let metadata = counter.to_string().into_bytes();
assert_eq!(metadata, data.app_metadata);

let actual_batch =
flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
flight_data_to_arrow_batch(&data, ipc_schema, schema.clone(), &dictionaries_by_id)
.expect("Unable to convert flight data to Arrow batch");

assert_eq!(actual_schema, actual_batch.schema());
assert_eq!(schema, actual_batch.schema());
assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
let schema = expected_batch.schema();
Expand All @@ -267,23 +271,9 @@ async fn consume_flight_location(
Ok(())
}

/// Reads the next `FlightData` item from `resp` and converts its header into a `Schema`.
///
/// Returns `None` when the stream yields no further item or the next item is an error.
///
/// # Panics
///
/// Panics if `data_header` cannot be parsed as an IPC message, or if the
/// message header is not a schema.
async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
// Both the end of the stream and a transport error collapse into `None` here.
let data = resp.next().await?.ok()?;
let message =
arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");

// message header is a Schema, so read it
let ipc_schema: ipc::Schema = message
.header_as_schema()
.expect("Unable to read IPC message as schema");
let schema = ipc::convert::fb_to_schema(ipc_schema);

Some(schema)
}

async fn receive_batch_flight_data(
resp: &mut Streaming<FlightData>,
schema: SchemaRef,
ipc_schema: arrow::ipc::Schema<'_>,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Option<FlightData> {
let mut data = resp.next().await?.ok()?;
Expand All @@ -296,7 +286,7 @@ async fn receive_batch_flight_data(
message
.header_as_dictionary_batch()
.expect("Error parsing dictionary"),
&schema,
ipc_schema,
dictionaries_by_id,
&message.version(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ impl FlightService for FlightServiceImpl {
.ok_or_else(|| Status::not_found(format!("Could not find flight. {key}")))?;

let options = arrow::ipc::writer::IpcWriteOptions::default();
#[allow(deprecated)]
let mut dictionary_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
let data_gen = writer::IpcDataGenerator::default();
let data = IpcMessage(
data_gen
Expand Down Expand Up @@ -268,6 +266,7 @@ impl FlightService for FlightServiceImpl {
if let Err(e) = save_uploaded_chunks(
uploaded_chunks,
schema_ref,
flight_data,
input_stream,
response_tx,
schema,
Expand Down Expand Up @@ -319,6 +318,7 @@ async fn record_batch_from_message(
message: ipc::Message<'_>,
data_body: &Buffer,
schema_ref: SchemaRef,
ipc_schema: ipc::Schema<'_>,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
) -> Result<RecordBatch, Status> {
let ipc_batch = message
Expand All @@ -328,6 +328,7 @@ async fn record_batch_from_message(
let arrow_batch_result = reader::read_record_batch(
data_body,
ipc_batch,
ipc_schema,
schema_ref,
dictionaries_by_id,
None,
Expand All @@ -341,7 +342,7 @@ async fn record_batch_from_message(
async fn dictionary_from_message(
message: ipc::Message<'_>,
data_body: &Buffer,
schema_ref: SchemaRef,
ipc_schema: ipc::Schema<'_>,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Result<(), Status> {
let ipc_batch = message
Expand All @@ -351,7 +352,7 @@ async fn dictionary_from_message(
let dictionary_batch_result = reader::read_dictionary(
data_body,
ipc_batch,
&schema_ref,
ipc_schema,
dictionaries_by_id,
&message.version(),
);
Expand All @@ -362,6 +363,7 @@ async fn dictionary_from_message(
async fn save_uploaded_chunks(
uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
schema_ref: Arc<Schema>,
schema_flight_data: FlightData,
mut input_stream: Streaming<FlightData>,
mut response_tx: mpsc::Sender<Result<PutResult, Status>>,
schema: Schema,
Expand All @@ -372,6 +374,11 @@ async fn save_uploaded_chunks(

let mut dictionaries_by_id = HashMap::new();

let ipc_schema = arrow::ipc::root_as_message(&schema_flight_data.data_header[..])
.map_err(|e| Status::invalid_argument(format!("Could not parse message: {e:?}")))?
.header_as_schema()
.ok_or_else(|| Status::invalid_argument("Could not parse message header as schema"))?;

while let Some(Ok(data)) = input_stream.next().await {
let message = arrow::ipc::root_as_message(&data.data_header[..])
.map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?;
Expand All @@ -389,6 +396,7 @@ async fn save_uploaded_chunks(
message,
&Buffer::from(data.data_body.as_ref()),
schema_ref.clone(),
ipc_schema,
&dictionaries_by_id,
)
.await?;
Expand All @@ -399,7 +407,7 @@ async fn save_uploaded_chunks(
dictionary_from_message(
message,
&Buffer::from(data.data_body.as_ref()),
schema_ref.clone(),
ipc_schema,
&mut dictionaries_by_id,
)
.await?;
Expand Down
Loading
Loading