diff --git a/src/query/service/src/pipelines/executor/executor_graph.rs b/src/query/service/src/pipelines/executor/executor_graph.rs index d1eee550047cf..0914f197d4f36 100644 --- a/src/query/service/src/pipelines/executor/executor_graph.rs +++ b/src/query/service/src/pipelines/executor/executor_graph.rs @@ -32,7 +32,7 @@ use crate::pipelines::executor::executor_condvar::WorkersCondvar; use crate::pipelines::executor::executor_tasks::ExecutorTasksQueue; use crate::pipelines::executor::executor_worker_context::ExecutorTask; use crate::pipelines::executor::executor_worker_context::ExecutorWorkerContext; -use crate::pipelines::executor::processor_async_task::ProcessorAsyncTask; +use crate::pipelines::executor::processor_async_task::ProcessorAsyncFuture; use crate::pipelines::executor::PipelineExecutor; use crate::pipelines::pipeline::Pipeline; use crate::pipelines::processors::connect; @@ -337,7 +337,7 @@ impl ScheduleQueue { let process_future = proc.async_process(); executor .async_runtime - .spawn(TrackedFuture::create(ProcessorAsyncTask::create( + .spawn(TrackedFuture::create(ProcessorAsyncFuture::create( query_id, wakeup_worker_id, proc.clone(), diff --git a/src/query/service/src/pipelines/executor/executor_tasks.rs b/src/query/service/src/pipelines/executor/executor_tasks.rs index deb16bf9b59f1..c2184f27b325f 100644 --- a/src/query/service/src/pipelines/executor/executor_tasks.rs +++ b/src/query/service/src/pipelines/executor/executor_tasks.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashMap;
 use std::collections::VecDeque;
 use std::sync::atomic::AtomicBool;
 use std::sync::atomic::Ordering;
 use std::sync::Arc;
+use std::sync::Weak;
 
 use common_base::base::tokio::sync::Notify;
 use common_exception::Result;
@@ -26,6 +28,7 @@ use crate::pipelines::executor::executor_condvar::WorkersCondvar;
 use crate::pipelines::executor::executor_condvar::WorkersWaitingStatus;
 use crate::pipelines::executor::executor_worker_context::ExecutorTask;
 use crate::pipelines::executor::executor_worker_context::ExecutorWorkerContext;
+use crate::pipelines::executor::processor_async_task::ProcessorAsyncTask;
 use crate::pipelines::processors::processor::ProcessorPtr;
 
 pub struct ExecutorTasksQueue {
@@ -48,6 +51,7 @@ impl ExecutorTasksQueue {
         self.finished_notify.notify_waiters();
 
         let mut workers_tasks = self.workers_tasks.lock();
+
         let mut wakeup_workers =
             Vec::with_capacity(workers_tasks.workers_waiting_status.waiting_size());
 
@@ -56,6 +60,20 @@ impl ExecutorTasksQueue {
             wakeup_workers.push(worker_id);
         }
 
+        // Collect the pending async tasks that are still alive; they are woken below.
+        let mut pending_async_tasks = Vec::new();
+        for worker_pending_tasks in &mut workers_tasks.workers_pending_async_tasks {
+            pending_async_tasks.extend(
+                worker_pending_tasks
+                    .drain()
+                    .filter_map(|(_proc, async_task)| Weak::upgrade(&async_task)),
+            );
+        }
+
         drop(workers_tasks);
+        // Wakers may run arbitrary code, so invoke them only after the queue lock is released.
+        for async_task in pending_async_tasks {
+            async_task.wakeup();
+        }
         for wakeup_worker in wakeup_workers {
             workers_condvar.wakeup(wakeup_worker);
@@ -163,6 +181,10 @@ impl ExecutorTasksQueue {
         let mut workers_tasks = self.workers_tasks.lock();
 
         let mut worker_id = task.worker_id;
+        if task.has_pending {
+            workers_tasks.workers_pending_async_tasks[worker_id].remove(&task.id.index());
+        }
+
         workers_tasks.tasks_size += 1;
         workers_tasks.workers_completed_async_tasks[worker_id].push_back(task);
 
@@ -182,6 +204,14 @@ impl ExecutorTasksQueue {
         }
     }
 
+    /// Registers `task` as suspended so that `finish` can wake it up.
+    /// Only a weak handle is stored, keyed by the processor index of the task.
+    pub fn pending_async_task(&self, task: &Arc<ProcessorAsyncTask>) {
+        let mut workers_tasks = self.workers_tasks.lock();
+        workers_tasks.workers_pending_async_tasks[task.worker_id]
+            .insert(task.processor_id.index(), Arc::downgrade(task));
+    }
+
     pub fn get_finished_notify(&self) -> Arc<Notify> {
         self.finished_notify.clone()
     }
@@ -197,11 +227,19 @@ pub struct CompletedAsyncTask {
     pub id: NodeIndex,
     pub worker_id: usize,
     pub res: Result<()>,
+    pub has_pending: bool,
 }
 
 impl CompletedAsyncTask {
-    pub fn create(id: NodeIndex, worker_id: usize, res: Result<()>) -> Self {
-        CompletedAsyncTask { id, worker_id, res }
+    /// `pending` marks a task that may still have an entry in the pending map;
+    /// `completed_async_task` removes that entry when the flag is set.
+    pub fn create(id: NodeIndex, worker_id: usize, pending: bool, res: Result<()>) -> Self {
+        CompletedAsyncTask {
+            id,
+            worker_id,
+            res,
+            has_pending: pending,
+        }
     }
 }
 
@@ -210,6 +248,7 @@ struct ExecutorTasks {
     workers_waiting_status: WorkersWaitingStatus,
     workers_sync_tasks: Vec<VecDeque<ProcessorPtr>>,
     workers_completed_async_tasks: Vec<VecDeque<CompletedAsyncTask>>,
+    workers_pending_async_tasks: Vec<HashMap<usize, Weak<ProcessorAsyncTask>>>,
 }
 
 unsafe impl Send for ExecutorTasks {}
@@ -217,16 +256,19 @@ unsafe impl Send for ExecutorTasks {}
 impl ExecutorTasks {
     pub fn create(workers_size: usize) -> ExecutorTasks {
         let mut workers_sync_tasks = Vec::with_capacity(workers_size);
+        let mut workers_pending_async_tasks = Vec::with_capacity(workers_size);
         let mut workers_completed_async_tasks = Vec::with_capacity(workers_size);
 
         for _index in 0..workers_size {
             workers_sync_tasks.push(VecDeque::new());
+            workers_pending_async_tasks.push(HashMap::new());
             workers_completed_async_tasks.push(VecDeque::new());
         }
 
         ExecutorTasks {
             tasks_size: 0,
             workers_sync_tasks,
+            workers_pending_async_tasks,
             workers_completed_async_tasks,
             workers_waiting_status: WorkersWaitingStatus::create(workers_size),
         }
diff --git a/src/query/service/src/pipelines/executor/processor_async_task.rs b/src/query/service/src/pipelines/executor/processor_async_task.rs
index 4108878f0e930..335078f6adc34 100644
--- a/src/query/service/src/pipelines/executor/processor_async_task.rs
+++ b/src/query/service/src/pipelines/executor/processor_async_task.rs
@@ -13,21 +13,19 @@
 // limitations under the License.
 
use std::future::Future;
+use std::mem::ManuallyDrop;
+use std::ops::DerefMut;
 use std::pin::Pin;
+use std::sync::atomic::AtomicPtr;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use std::task::Context;
 use std::task::Poll;
-use std::time::Duration;
-use std::time::Instant;
+use std::task::Waker;
 
-use common_base::base::tokio::time::sleep;
 use common_base::runtime::catch_unwind;
-use common_exception::ErrorCode;
 use common_exception::Result;
 use common_pipeline_core::processors::processor::ProcessorPtr;
-use futures_util::future::BoxFuture;
-use futures_util::future::Either;
-use futures_util::FutureExt;
 use petgraph::prelude::NodeIndex;
 
 use crate::pipelines::executor::executor_condvar::WorkersCondvar;
@@ -35,101 +33,150 @@ use crate::pipelines::executor::executor_tasks::CompletedAsyncTask;
 use crate::pipelines::executor::executor_tasks::ExecutorTasksQueue;
 
+/// Executor-side handle of one in-flight asynchronous processor task.
+///
+/// `pending_waker` is either null or a pointer produced by `Box::into_raw`; whoever swaps a
+/// non-null pointer out of the slot becomes the sole owner of that allocation.
 pub struct ProcessorAsyncTask {
-    worker_id: usize,
-    processor_id: NodeIndex,
+    pub worker_id: usize,
+    pub processor_id: NodeIndex,
     queue: Arc<ExecutorTasksQueue>,
     workers_condvar: Arc<WorkersCondvar>,
-    inner: BoxFuture<'static, Result<()>>,
+    pending_waker: AtomicPtr<Waker>,
 }
 
 impl ProcessorAsyncTask {
-    pub fn create<Inner: Future<Output = Result<()>> + Send + 'static>(
-        query_id: Arc<String>,
+    /// Creates the shared task state with no waker registered.
+    pub fn create(
         worker_id: usize,
-        processor: ProcessorPtr,
+        processor_id: NodeIndex,
         queue: Arc<ExecutorTasksQueue>,
         workers_condvar: Arc<WorkersCondvar>,
-        inner: Inner,
-    ) -> ProcessorAsyncTask {
-        let finished_notify = queue.get_finished_notify();
-
-        let inner = async move {
-            let left = Box::pin(inner);
-            let right = Box::pin(finished_notify.notified());
-            match futures::future::select(left, right).await {
-                Either::Left((res, _)) => res,
-                Either::Right((_, _)) => Err(ErrorCode::AbortedQuery(
-                    "Aborted query, because the server is shutting down or the query was killed.",
-                )),
-            }
-        };
-
-        let processor_id = unsafe { processor.id() };
-        let processor_name = unsafe { processor.name() };
-        let queue_clone = queue.clone();
-        let inner = async move {
-            let start = Instant::now();
-            let mut inner = inner.boxed();
-
-            loop {
-                let interval = Box::pin(sleep(Duration::from_secs(5)));
-                match futures::future::select(interval, inner).await {
-                    Either::Left((_, right)) => {
-                        inner = right;
-                        let active_workers = queue_clone.active_workers();
-                        tracing::warn!(
-                            "Very slow processor async task, query_id:{:?}, processor id: {:?}, name: {:?}, elapsed: {:?}, active sync workers: {:?}",
-                            query_id,
-                            processor_id,
-                            processor_name,
-                            start.elapsed(),
-                            active_workers,
-                        );
-                    }
-                    Either::Right((res, _)) => {
-                        return res;
-                    }
-                }
-            }
-        };
+    ) -> Arc<ProcessorAsyncTask> {
+        Arc::new(ProcessorAsyncTask {
+            worker_id,
+            processor_id,
+            queue,
+            workers_condvar,
+            pending_waker: AtomicPtr::new(std::ptr::null_mut()),
+        })
+    }
+
+    /// Forwards to the executor queue's `is_finished`.
+    #[inline(always)]
+    pub fn is_finished(&self) -> bool {
+        self.queue.is_finished()
+    }
+
+    /// Hands the task result back to the queue; always returns `Poll::Ready(())`.
+    #[inline(always)]
+    pub fn finish(self: Arc<Self>, res: Result<()>) -> Poll<()> {
+        self.queue.completed_async_task(
+            self.workers_condvar.clone(),
+            CompletedAsyncTask::create(
+                self.processor_id,
+                self.worker_id,
+                !self.pending_waker.load(Ordering::SeqCst).is_null(),
+                res,
+            ),
+        );
+
+        Poll::Ready(())
+    }
+
+    /// Wakes the suspended future if a waker is currently registered.
+    pub fn wakeup(&self) {
+        if let Some(waker) = self.take_pending_waker() {
+            waker.wake();
+        }
+    }
+
+    /// Stores `waker` for a later `wakeup` and returns `Poll::Pending`.
+    #[inline(always)]
+    pub fn set_pending_watcher(self: Arc<Self>, waker: Waker) -> Poll<()> {
+        let desired = Box::into_raw(Box::new(waker));
+
+        // Swap instead of inspecting the stored waker through the shared pointer, so that a
+        // waker is only ever dereferenced or freed by the thread that owns it.
+        let previous = self.pending_waker.swap(desired, Ordering::SeqCst);
+
+        if previous.is_null() {
+            // No waker was stored before, so the task is not in the queue's pending map yet.
+            self.queue.pending_async_task(&self);
+        } else {
+            // SAFETY: non-null values of `pending_waker` always come from `Box::into_raw`,
+            // and the swap above made this thread the sole owner of `previous`.
+            unsafe { drop(Box::from_raw(previous)) };
+        }
+
+        Poll::Pending
+    }
+
+    /// Atomically takes ownership of the registered waker, leaving the slot empty.
+    fn take_pending_waker(&self) -> Option<Waker> {
+        let waker = self
+            .pending_waker
+            .swap(std::ptr::null_mut(), Ordering::SeqCst);
+        if waker.is_null() {
+            return None;
+        }
+
+        // SAFETY: non-null values of `pending_waker` always come from `Box::into_raw`, and the
+        // swap above made this thread the sole owner. Re-boxing also frees the allocation.
+        Some(unsafe { *Box::from_raw(waker) })
+    }
+}
+
+impl Drop for ProcessorAsyncTask {
+    fn drop(&mut self) {
+        // Release a waker that was registered but never woken.
+        drop(self.take_pending_waker());
+    }
+}
 
-        ProcessorAsyncTask {
-            worker_id,
-            processor_id,
-            queue,
-            workers_condvar,
-            inner: inner.boxed(),
+/// Future spawned on the async runtime for one processor: it polls `inner` and reports the
+/// outcome (result or caught panic) to the executor queue through its `ProcessorAsyncTask`.
+pub struct ProcessorAsyncFuture<Inner: Future<Output = Result<()>> + Send + 'static> {
+    inner: Inner,
+    task: Arc<ProcessorAsyncTask>,
+}
+
+impl<Inner: Future<Output = Result<()>> + Send + 'static> ProcessorAsyncFuture<Inner> {
+    pub fn create(
+        _query_id: Arc<String>,
+        worker_id: usize,
+        processor: ProcessorPtr,
+        queue: Arc<ExecutorTasksQueue>,
+        workers_condvar: Arc<WorkersCondvar>,
+        inner: Inner,
+    ) -> ProcessorAsyncFuture<Inner> {
+        ProcessorAsyncFuture {
+            inner,
+            task: ProcessorAsyncTask::create(
+                worker_id,
+                unsafe { processor.id() },
+                queue,
+                workers_condvar,
+            ),
         }
     }
 }
 
-impl Future for ProcessorAsyncTask {
+impl<Inner: Future<Output = Result<()>> + Send + 'static> Future for ProcessorAsyncFuture<Inner> {
     type Output = ();
 
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        if self.queue.is_finished() {
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        if self.task.is_finished() {
             return Poll::Ready(());
         }
 
-        let inner = self.inner.as_mut();
-
-        match catch_unwind(move || inner.poll(cx)) {
-            Ok(Poll::Pending) => Poll::Pending,
-            Ok(Poll::Ready(res)) => {
-                self.queue.completed_async_task(
-                    self.workers_condvar.clone(),
-                    CompletedAsyncTask::create(self.processor_id, self.worker_id, res),
-                );
-                Poll::Ready(())
-            }
-            Err(cause) => {
-                self.queue.completed_async_task(
-                    self.workers_condvar.clone(),
-                    CompletedAsyncTask::create(self.processor_id, self.worker_id, Err(cause)),
-                );
+        let task = self.task.clone();
+        // SAFETY: `inner` is only ever polled in place here and is never moved out of `self`.
+        let inner = unsafe { self.map_unchecked_mut(|x| &mut x.inner) };
 
-                Poll::Ready(())
-            }
+        match catch_unwind(move || (inner.poll(cx), cx)) {
+            Ok((Poll::Ready(res), _)) => task.finish(res),
+            Err(cause) => task.finish(Err(cause)),
+            Ok((Poll::Pending, cx)) => task.set_pending_watcher(cx.waker().clone()),
         }
     }
 }
diff --git a/src/query/service/src/table_functions/mod.rs b/src/query/service/src/table_functions/mod.rs
index d81ef4ccf0cdd..b443cd2b8521b 100644
--- a/src/query/service/src/table_functions/mod.rs
+++ b/src/query/service/src/table_functions/mod.rs
@@ -17,6 +17,7 @@ mod infer_schema;
 mod list_stage;
 mod
numbers;
 mod openai;
+mod slow_async;
 mod srf;
 mod sync_crash_me;
 mod table_function;
diff --git a/src/query/service/src/table_functions/slow_async.rs b/src/query/service/src/table_functions/slow_async.rs
new file mode 100755
index 0000000000000..19a352a456fc9
--- /dev/null
+++ b/src/query/service/src/table_functions/slow_async.rs
@@ -0,0 +1,233 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+use std::time::Duration;
+
+use chrono::NaiveDateTime;
+use chrono::TimeZone;
+use chrono::Utc;
+use common_base::base::tokio;
+use common_catalog::plan::DataSourcePlan;
+use common_catalog::plan::PartStatistics;
+use common_catalog::plan::Partitions;
+use common_catalog::plan::PushDownInfo;
+use common_catalog::table::Table;
+use common_catalog::table_args::TableArgs;
+use common_catalog::table_context::TableContext;
+use common_catalog::table_function::TableFunction;
+use common_exception::ErrorCode;
+use common_exception::Result;
+use common_expression::types::NumberDataType;
+use common_expression::types::NumberScalar;
+use common_expression::types::UInt64Type;
+use common_expression::DataBlock;
+use common_expression::FromData;
+use common_expression::Scalar;
+use common_expression::TableDataType;
+use common_expression::TableField;
+use common_expression::TableSchemaRefExt;
+use common_meta_app::schema::TableIdent;
+use common_meta_app::schema::TableInfo;
+use common_meta_app::schema::TableMeta;
+use common_pipeline_core::processors::port::OutputPort;
+use common_pipeline_core::processors::processor::ProcessorPtr;
+use common_pipeline_core::Pipeline;
+use common_pipeline_sources::AsyncSource;
+use common_pipeline_sources::AsyncSourcer;
+
+/// Table function `slow_async(delay_seconds, [block_num])`.
+///
+/// Its source sleeps `delay_seconds` before emitting each of `block_num` single-row blocks.
+pub struct SlowAsyncTable {
+    table_info: TableInfo,
+    delay_seconds: usize,
+    read_blocks_num: usize,
+}
+
+impl SlowAsyncTable {
+    /// Creates the table function from its positional arguments.
+    ///
+    /// `block_num` defaults to 1 when omitted. Returns `BadArguments` when the number of
+    /// arguments is not 1 or 2, or when an argument does not parse as an unsigned integer.
+    pub fn create(
+        database_name: &str,
+        table_func_name: &str,
+        table_id: u64,
+        table_args: TableArgs,
+    ) -> Result<Arc<dyn TableFunction>> {
+        let args = table_args.expect_all_positioned(table_func_name, None)?;
+
+        if args.len() != 1 && args.len() != 2 {
+            return Err(ErrorCode::BadArguments(
+                "unexpected arguments for slow_async(delay_seconds, [block_num])",
+            ));
+        }
+
+        let delay_seconds = args[0]
+            .to_string()
+            .parse::<usize>()
+            .map_err(|_| ErrorCode::BadArguments("Expected uint64 argument"))?;
+
+        let read_blocks_num = match args.get(1) {
+            Some(arg) => arg
+                .to_string()
+                .parse::<usize>()
+                .map_err(|_| ErrorCode::BadArguments("Expected uint64 argument"))?,
+            None => 1,
+        };
+
+        let table_info = TableInfo {
+            ident: TableIdent::new(table_id, 0),
+            desc: format!("'{}'.'{}'", database_name, table_func_name),
+            name: String::from("slow_async"),
+            meta: TableMeta {
+                schema: TableSchemaRefExt::create(vec![TableField::new(
+                    "number",
+                    TableDataType::Number(NumberDataType::UInt64),
+                )]),
+                engine: String::from(table_func_name),
+                // Assuming that created_on is unnecessary for function table,
+                // we could make created_on fixed to pass test_shuffle_action_try_into.
+                created_on: Utc
+                    .from_utc_datetime(&NaiveDateTime::from_timestamp_opt(0, 0).unwrap()),
+                updated_on: Utc
+                    .from_utc_datetime(&NaiveDateTime::from_timestamp_opt(0, 0).unwrap()),
+                ..Default::default()
+            },
+            ..Default::default()
+        };
+
+        Ok(Arc::new(SlowAsyncTable {
+            table_info,
+            delay_seconds,
+            read_blocks_num,
+        }))
+    }
+}
+
+#[async_trait::async_trait]
+impl Table for SlowAsyncTable {
+    fn is_local(&self) -> bool {
+        true
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn get_table_info(&self) -> &TableInfo {
+        &self.table_info
+    }
+
+    #[async_backtrace::framed]
+    async fn read_partitions(
+        &self,
+        _ctx: Arc<dyn TableContext>,
+        _push_downs: Option<PushDownInfo>,
+        _dry_run: bool,
+    ) -> Result<(PartStatistics, Partitions)> {
+        // dummy statistics
+        Ok((
+            PartStatistics::new_exact(
+                self.read_blocks_num,
+                self.read_blocks_num * 8,
+                self.read_blocks_num,
+                self.read_blocks_num,
+            ),
+            Partitions::default(),
+        ))
+    }
+
+    fn table_args(&self) -> Option<TableArgs> {
+        Some(TableArgs::new_positioned(vec![
+            Scalar::Number(NumberScalar::UInt64(self.delay_seconds as u64)),
+            Scalar::Number(NumberScalar::UInt64(self.read_blocks_num as u64)),
+        ]))
+    }
+
+    fn read_data(
+        &self,
+        ctx: Arc<dyn TableContext>,
+        _: &DataSourcePlan,
+        pipeline: &mut Pipeline,
+    ) -> Result<()> {
+        pipeline.add_source(
+            |output| {
+                SlowAsyncDataSource::create(
+                    ctx.clone(),
+                    output,
+                    self.delay_seconds,
+                    self.read_blocks_num,
+                )
+            },
+            1,
+        )?;
+
+        Ok(())
+    }
+}
+
+/// Async source that sleeps `delay_seconds` before emitting each remaining block.
+struct SlowAsyncDataSource {
+    delay_seconds: usize,
+    read_blocks_num: usize,
+}
+
+impl SlowAsyncDataSource {
+    pub fn create(
+        ctx: Arc<dyn TableContext>,
+        output: Arc<OutputPort>,
+        delay_seconds: usize,
+        read_blocks_num: usize,
+    ) -> Result<ProcessorPtr> {
+        AsyncSourcer::create(ctx, output, SlowAsyncDataSource {
+            delay_seconds,
+            read_blocks_num,
+        })
+    }
+}
+
+#[async_trait::async_trait]
+impl AsyncSource for SlowAsyncDataSource {
+    const NAME: &'static str = "slow_async";
+
+    #[async_trait::unboxed_simple]
+    #[async_backtrace::framed]
+    async fn generate(&mut self) -> Result<Option<DataBlock>> {
+        match self.read_blocks_num {
+            0 => Ok(None),
+            _ => {
+                self.read_blocks_num -= 1;
+                // The delay is applied before every block, including the first one.
+                tokio::time::sleep(Duration::from_secs(self.delay_seconds as u64)).await;
+                Ok(Some(DataBlock::new_from_columns(vec![
+                    UInt64Type::from_data(vec![1]),
+                ])))
+            }
+        }
+    }
+}
+
+impl TableFunction for SlowAsyncTable {
+    fn function_name(&self) -> &str {
+        self.name()
+    }
+
+    fn as_table<'a>(self: Arc<Self>) -> Arc<dyn Table + 'a>
+    where Self: 'a {
+        self
+    }
+}
diff --git a/src/query/service/src/table_functions/table_function_factory.rs b/src/query/service/src/table_functions/table_function_factory.rs
index 4d38db74fdefc..0beb7260130ea 100644
--- a/src/query/service/src/table_functions/table_function_factory.rs
+++ b/src/query/service/src/table_functions/table_function_factory.rs
@@ -33,6 +33,7 @@ use crate::table_functions::async_crash_me::AsyncCrashMeTable;
 use crate::table_functions::infer_schema::InferSchemaTable;
 use crate::table_functions::list_stage::ListStageTable;
 use crate::table_functions::numbers::NumbersTable;
+use crate::table_functions::slow_async::SlowAsyncTable;
 use crate::table_functions::srf::RangeTable;
 use crate::table_functions::sync_crash_me::SyncCrashMeTable;
 use crate::table_functions::GPT2SQLTable;
@@ -135,6 +136,11 @@ impl TableFunctionFactory {
             (next_id(), Arc::new(AsyncCrashMeTable::create)),
         );
 
+        creators.insert(
+            "slow_async".to_string(),
+            (next_id(), Arc::new(SlowAsyncTable::create)),
+        );
+
         creators.insert(
             "infer_schema".to_string(),
             (next_id(), Arc::new(InferSchemaTable::create)),
diff --git a/tests/suites/1_stateful/02_query/02_0000_kill_query.py b/tests/suites/1_stateful/02_query/02_0000_kill_query.py
index ee211a4d749b8..2ef788cce86ad 100755
--- a/tests/suites/1_stateful/02_query/02_0000_kill_query.py
+++ b/tests/suites/1_stateful/02_query/02_0000_kill_query.py
@@ -37,12 +37,28 @@
     res = mycursor.fetchone()
     kill_query = "kill query " + str(res[0]) + ";"
     mycursor.execute(kill_query)
-    time.sleep(0.1)
+    time.sleep(0.5)
+
+    ## TODO NEW EXPRESSION
+    ## assert res is None
+    client1.expect(prompt)
+
+with NativeClient(name="client2>") as client2:
+    client2.expect(prompt)
+    client2.expect("")
+
+    client2.send("SELECT * FROM slow_async(5, 100);")
+    time.sleep(0.5)
+
+    mycursor = mydb.cursor()
     mycursor.execute(
-        "SELECT * FROM system.processes WHERE extra_info LIKE '%SELECT max(number)%' AND extra_info NOT LIKE '%system.processes%';"
+        "SELECT mysql_connection_id FROM system.processes WHERE extra_info LIKE '%slow_async%' AND extra_info NOT LIKE '%system.processes%';"
     )
     res = mycursor.fetchone()
+    kill_query = "kill query " + str(res[0]) + ";"
+    mycursor.execute(kill_query)
+    time.sleep(0.5)
 
     ## TODO NEW EXPRESSION
     ## assert res is None
-    client1.expect(prompt)
+    client2.expect(prompt)