Skip to content

Commit f3e1a5b

Browse files
authored
Merge pull request #9030 from andylokandy/fnctx
refactor(expr): rename FunctionContext
2 parents be6e4f2 + bc893a1 commit f3e1a5b

File tree

13 files changed

+160
-183
lines changed

13 files changed

+160
-183
lines changed

src/query/codegen/src/writes/register.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub fn codegen_register() {
3535
use std::sync::Arc;
3636
3737
use crate::Function;
38-
use crate::FunctionContext;
38+
use crate::EvalContext;
3939
use crate::FunctionDomain;
4040
use crate::FunctionRegistry;
4141
use crate::FunctionSignature;
@@ -82,7 +82,7 @@ pub fn codegen_register() {
8282
func: G,
8383
) where
8484
F: Fn({arg_f_closure_sig}) -> FunctionDomain<O> + 'static + Clone + Copy + Send + Sync,
85-
G: Fn({arg_g_closure_sig} FunctionContext) -> O::Scalar + 'static + Clone + Copy + Send + Sync,
85+
G: Fn({arg_g_closure_sig} EvalContext) -> O::Scalar + 'static + Clone + Copy + Send + Sync,
8686
{{
8787
self.register_passthrough_nullable_{n_args}_arg::<{arg_generics} O, _, _>(
8888
name,
@@ -154,7 +154,7 @@ pub fn codegen_register() {
154154
func: G,
155155
) where
156156
F: Fn({arg_f_closure_sig}) -> FunctionDomain<O> + 'static + Clone + Copy + Send + Sync,
157-
G: for<'a> Fn({arg_g_closure_sig} FunctionContext) -> Result<Value<O>, String> + 'static + Clone + Copy + Send + Sync,
157+
G: for<'a> Fn({arg_g_closure_sig} EvalContext) -> Result<Value<O>, String> + 'static + Clone + Copy + Send + Sync,
158158
{{
159159
let has_nullable = &[{arg_sig_type} O::data_type()]
160160
.iter()
@@ -257,7 +257,7 @@ pub fn codegen_register() {
257257
func: G,
258258
) where
259259
F: Fn({arg_f_closure_sig}) -> FunctionDomain<NullableType<O>> + 'static + Clone + Copy + Send + Sync,
260-
G: for<'a> Fn({arg_g_closure_sig} FunctionContext) -> Result<Value<NullableType<O>>, String> + 'static + Clone + Copy + Send + Sync,
260+
G: for<'a> Fn({arg_g_closure_sig} EvalContext) -> Result<Value<NullableType<O>>, String> + 'static + Clone + Copy + Send + Sync,
261261
{{
262262
let has_nullable = &[{arg_sig_type} O::data_type()]
263263
.iter()
@@ -340,7 +340,7 @@ pub fn codegen_register() {
340340
func: G,
341341
) where
342342
F: Fn({arg_f_closure_sig}) -> FunctionDomain<O> + 'static + Clone + Copy + Send + Sync,
343-
G: for <'a> Fn({arg_g_closure_sig} FunctionContext) -> Result<Value<O>, String> + 'static + Clone + Copy + Send + Sync,
343+
G: for <'a> Fn({arg_g_closure_sig} EvalContext) -> Result<Value<O>, String> + 'static + Clone + Copy + Send + Sync,
344344
{{
345345
self.funcs
346346
.entry(name.to_string())
@@ -441,8 +441,8 @@ pub fn codegen_register() {
441441
source,
442442
"
443443
pub fn vectorize_{n_args}_arg<{arg_generics_bound} O: ArgType>(
444-
func: impl Fn({arg_input_closure_sig} FunctionContext) -> O::Scalar + Copy + Send + Sync,
445-
) -> impl Fn({arg_output_closure_sig} FunctionContext) -> Result<Value<O>, String> + Copy + Send + Sync {{
444+
func: impl Fn({arg_input_closure_sig} EvalContext) -> O::Scalar + Copy + Send + Sync,
445+
) -> impl Fn({arg_output_closure_sig} EvalContext) -> Result<Value<O>, String> + Copy + Send + Sync {{
446446
move |{func_args} ctx| match ({args_tuple}) {{
447447
({arg_scalar}) => Ok(Value::Scalar(func({func_args} ctx))),
448448
{match_arms}
@@ -535,8 +535,8 @@ pub fn codegen_register() {
535535
source,
536536
"
537537
pub fn vectorize_with_builder_{n_args}_arg<{arg_generics_bound} O: ArgType>(
538-
func: impl Fn({arg_input_closure_sig} &mut O::ColumnBuilder, FunctionContext) -> Result<(), String> + Copy + Send + Sync,
539-
) -> impl Fn({arg_output_closure_sig} FunctionContext) -> Result<Value<O>, String> + Copy + Send + Sync {{
538+
func: impl Fn({arg_input_closure_sig} &mut O::ColumnBuilder, EvalContext) -> Result<(), String> + Copy + Send + Sync,
539+
) -> impl Fn({arg_output_closure_sig} EvalContext) -> Result<Value<O>, String> + Copy + Send + Sync {{
540540
move |{func_args} ctx| match ({args_tuple}) {{
541541
({arg_scalar}) => {{
542542
let mut builder = O::create_builder(1, ctx.generics);
@@ -640,8 +640,8 @@ pub fn codegen_register() {
640640
source,
641641
"
642642
pub fn passthrough_nullable_{n_args}_arg<{arg_generics_bound} O: ArgType>(
643-
func: impl for <'a> Fn({arg_input_closure_sig} FunctionContext) -> Result<Value<O>, String> + Copy + Send + Sync,
644-
) -> impl for <'a> Fn({arg_output_closure_sig} FunctionContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync {{
643+
func: impl for <'a> Fn({arg_input_closure_sig} EvalContext) -> Result<Value<O>, String> + Copy + Send + Sync,
644+
) -> impl for <'a> Fn({arg_output_closure_sig} EvalContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync {{
645645
move |{closure_args} ctx| match ({args_tuple}) {{
646646
{scalar_nones_pats} => Ok(Value::Scalar(None)),
647647
({arg_scalar}) => Ok(Value::Scalar(Some(
@@ -747,8 +747,8 @@ pub fn codegen_register() {
747747
source,
748748
"
749749
pub fn combine_nullable_{n_args}_arg<{arg_generics_bound} O: ArgType>(
750-
func: impl for <'a> Fn({arg_input_closure_sig} FunctionContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync,
751-
) -> impl for <'a> Fn({arg_output_closure_sig} FunctionContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync {{
750+
func: impl for <'a> Fn({arg_input_closure_sig} EvalContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync,
751+
) -> impl for <'a> Fn({arg_output_closure_sig} EvalContext) -> Result<Value<NullableType<O>>, String> + Copy + Send + Sync {{
752752
move |{closure_args} ctx| match ({args_tuple}) {{
753753
{scalar_nones_pats} => Ok(Value::Scalar(None)),
754754
({arg_scalar}) => Ok(Value::Scalar(
@@ -825,8 +825,8 @@ pub fn codegen_register() {
825825
source,
826826
"
827827
fn erase_function_generic_{n_args}_arg<{arg_generics_bound} O: ArgType>(
828-
func: impl for <'a> Fn({arg_g_closure_sig} FunctionContext) -> Result<Value<O>, String>,
829-
) -> impl Fn(&[ValueRef<AnyType>], FunctionContext) -> Result<Value<AnyType>, String> {{
828+
func: impl for <'a> Fn({arg_g_closure_sig} EvalContext) -> Result<Value<O>, String>,
829+
) -> impl Fn(&[ValueRef<AnyType>], EvalContext) -> Result<Value<AnyType>, String> {{
830830
move |args, ctx| {{
831831
{let_args}
832832
func({func_args} ctx).map(Value::upcast)

src/query/expression/src/evaluator.rs

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@ use std::collections::HashMap;
1616
#[cfg(debug_assertions)]
1717
use std::sync::Mutex;
1818

19-
use chrono_tz::Tz;
2019
use common_arrow::arrow::bitmap;
2120
use itertools::Itertools;
2221

2322
use crate::chunk::Chunk;
2423
use crate::expression::Expr;
2524
use crate::expression::Span;
26-
use crate::function::FunctionContext;
25+
use crate::function::EvalContext;
2726
use crate::property::Domain;
2827
use crate::type_check::check_simple_cast;
2928
use crate::types::any::AnyType;
@@ -39,21 +38,26 @@ use crate::values::ColumnBuilder;
3938
use crate::values::Scalar;
4039
use crate::values::Value;
4140
use crate::ColumnIndex;
41+
use crate::FunctionContext;
4242
use crate::FunctionDomain;
4343
use crate::FunctionRegistry;
4444
use crate::Result;
4545

4646
pub struct Evaluator<'a> {
4747
input_columns: &'a Chunk,
48-
tz: Tz,
48+
fn_ctx: FunctionContext,
4949
fn_registry: &'a FunctionRegistry,
5050
}
5151

5252
impl<'a> Evaluator<'a> {
53-
pub fn new(input_columns: &'a Chunk, tz: Tz, fn_registry: &'a FunctionRegistry) -> Self {
53+
pub fn new(
54+
input_columns: &'a Chunk,
55+
fn_ctx: FunctionContext,
56+
fn_registry: &'a FunctionRegistry,
57+
) -> Self {
5458
Evaluator {
5559
input_columns,
56-
tz,
60+
fn_ctx,
5761
fn_registry,
5862
}
5963
}
@@ -82,10 +86,10 @@ impl<'a> Evaluator<'a> {
8286
.all_equal()
8387
);
8488
let cols_ref = cols.iter().map(Value::as_ref).collect::<Vec<_>>();
85-
let ctx = FunctionContext {
89+
let ctx = EvalContext {
8690
generics,
8791
num_rows: self.input_columns.num_rows(),
88-
tz: self.tz,
92+
tz: self.fn_ctx.tz,
8993
};
9094
(function.eval)(cols_ref.as_slice(), ctx).map_err(|msg| (span.clone(), msg))
9195
}
@@ -110,9 +114,13 @@ impl<'a> Evaluator<'a> {
110114
if !*RECURSING.lock().unwrap() {
111115
*RECURSING.lock().unwrap() = true;
112116
assert_eq!(
113-
ConstantFolder::new(self.input_columns.domains(), self.tz, self.fn_registry)
114-
.fold(expr)
115-
.1,
117+
ConstantFolder::new(
118+
self.input_columns.domains(),
119+
self.fn_ctx,
120+
self.fn_registry
121+
)
122+
.fold(expr)
123+
.1,
116124
None,
117125
"domain calculation should not return any domain for expressions that are possible to fail"
118126
);
@@ -401,7 +409,7 @@ impl<'a> Evaluator<'a> {
401409
span,
402410
cast_fn,
403411
[(value, src_type.clone())],
404-
self.tz,
412+
self.fn_ctx,
405413
num_rows,
406414
self.fn_registry,
407415
)?;
@@ -412,19 +420,19 @@ impl<'a> Evaluator<'a> {
412420

413421
pub struct ConstantFolder<'a, Index: ColumnIndex> {
414422
input_domains: HashMap<Index, Domain>,
415-
tz: Tz,
423+
fn_ctx: FunctionContext,
416424
fn_registry: &'a FunctionRegistry,
417425
}
418426

419427
impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
420428
pub fn new(
421429
input_domains: HashMap<Index, Domain>,
422-
tz: Tz,
430+
fn_ctx: FunctionContext,
423431
fn_registry: &'a FunctionRegistry,
424432
) -> Self {
425433
ConstantFolder {
426434
input_domains,
427-
tz,
435+
fn_ctx,
428436
fn_registry,
429437
}
430438
}
@@ -485,7 +493,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
485493

486494
if inner_expr.as_constant().is_some() {
487495
let chunk = Chunk::empty();
488-
let evaluator = Evaluator::new(&chunk, self.tz, self.fn_registry);
496+
let evaluator = Evaluator::new(&chunk, self.fn_ctx, self.fn_registry);
489497
// Since we know the expression is constant, it'll be safe to change its column index type.
490498
let cast_expr = cast_expr.project_column_ref(|_| unreachable!());
491499
if let Ok(Value::Scalar(scalar)) = evaluator.run(&cast_expr) {
@@ -561,7 +569,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
561569

562570
if all_args_is_scalar {
563571
let chunk = Chunk::empty();
564-
let evaluator = Evaluator::new(&chunk, self.tz, self.fn_registry);
572+
let evaluator = Evaluator::new(&chunk, self.fn_ctx, self.fn_registry);
565573
// Since we know the expression is constant, it'll be safe to change its column index type.
566574
let func_expr = func_expr.project_column_ref(|_| unreachable!());
567575
if let Ok(Value::Scalar(scalar)) = evaluator.run(&func_expr) {
@@ -753,7 +761,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
753761
span,
754762
cast_fn,
755763
[(domain.clone(), src_type.clone())],
756-
self.tz,
764+
self.fn_ctx,
757765
self.fn_registry,
758766
)?;
759767
assert_eq!(&ty, dest_type);

src/query/expression/src/function.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ pub struct FunctionSignature {
4242
}
4343

4444
#[derive(Clone, Copy)]
45-
pub struct FunctionContext<'a> {
45+
pub struct FunctionContext {
46+
pub tz: Tz,
47+
}
48+
49+
#[derive(Clone, Copy)]
50+
pub struct EvalContext<'a> {
4651
pub generics: &'a GenericMap,
4752
pub num_rows: usize,
4853
pub tz: Tz,
@@ -70,9 +75,7 @@ pub struct Function {
7075
pub calc_domain: Box<dyn Fn(&[Domain]) -> FunctionDomain<AnyType> + Send + Sync>,
7176
#[allow(clippy::type_complexity)]
7277
pub eval: Box<
73-
dyn Fn(&[ValueRef<AnyType>], FunctionContext) -> Result<Value<AnyType>, String>
74-
+ Send
75-
+ Sync,
78+
dyn Fn(&[ValueRef<AnyType>], EvalContext) -> Result<Value<AnyType>, String> + Send + Sync,
7679
>,
7780
}
7881

@@ -204,8 +207,8 @@ impl FunctionRegistry {
204207

205208
pub fn wrap_nullable<F>(
206209
f: F,
207-
) -> impl Fn(&[ValueRef<AnyType>], FunctionContext) -> Result<Value<AnyType>, String> + Copy
208-
where F: Fn(&[ValueRef<AnyType>], FunctionContext) -> Result<Value<AnyType>, String> + Copy {
210+
) -> impl Fn(&[ValueRef<AnyType>], EvalContext) -> Result<Value<AnyType>, String> + Copy
211+
where F: Fn(&[ValueRef<AnyType>], EvalContext) -> Result<Value<AnyType>, String> + Copy {
209212
move |args, ctx| {
210213
type T = NullableType<AnyType>;
211214
type Result = AnyType;

0 commit comments

Comments
 (0)