diff --git a/bin/installcheck b/bin/installcheck index 23111ed0..9ab8e846 100755 --- a/bin/installcheck +++ b/bin/installcheck @@ -35,17 +35,16 @@ createdb contrib_regression # Tests # ######### TESTDIR="test" -PGXS=$(dirname `pg_config --pgxs`) +PGXS=$(dirname $(pg_config --pgxs)) REGRESS="${PGXS}/../test/regress/pg_regress" # Test names can be passed as parameters to this script. # If any test names are passed run only those tests. # Otherwise run all tests. -if [ "$#" -ne 0 ] -then - TESTS=$@ +if [ "$#" -ne 0 ]; then + TESTS=$@ else - TESTS=$(ls ${TESTDIR}/sql | sed -e 's/\..*$//' | sort ) + TESTS=$(ls ${TESTDIR}/sql | sed -e 's/\..*$//' | sort) fi # Execute the test fixtures diff --git a/docs/api.md b/docs/api.md index 7f4ad124..66a6450f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,4 +1,3 @@ - In our API, each SQL table is reflected as a set of GraphQL types. At a high level, tables become types and columns/foreign keys become fields on those types. @@ -167,6 +166,9 @@ Connections wrap a result set with some additional metadata. # Result set edges: [BlogEdge!]! + # Aggregate functions + aggregate: BlogAggregate + } ``` @@ -264,8 +266,176 @@ Connections wrap a result set with some additional metadata. The `totalCount` field is disabled by default because it can be expensive on large tables. To enable it use a [comment directive](configuration.md#totalcount) +#### Aggregates + +Aggregate functions are available on the collection's `aggregate` field when enabled via [comment directive](configuration.md#aggregate). These allow you to perform calculations on the collection of records that match your filter criteria. 
+ +The supported aggregate operations are: + +- **count**: Always available, returns the number of records matching the query +- **sum**: Available for numeric fields, returns the sum of values +- **avg**: Available for numeric fields, returns the average (mean) of values +- **min**: Available for numeric, string, boolean, and date/time fields, returns the minimum value +- **max**: Available for numeric, string, boolean, and date/time fields, returns the maximum value + +**Example** + +=== "Query" + + ```graphql + { + blogCollection( + filter: { rating: { gt: 3 } } + ) { + aggregate { + count + sum { + rating + visits + } + avg { + rating + } + min { + createdAt + title + } + max { + rating + updatedAt + } + } + } + } + ``` + +=== "Response" + + ```json + { + "data": { + "blogCollection": { + "aggregate": { + "count": 5, + "sum": { + "rating": 23, + "visits": 1250 + }, + "avg": { + "rating": 4.6 + }, + "min": { + "createdAt": "2022-01-15T08:30:00Z", + "title": "A Blog Post" + }, + "max": { + "rating": 5, + "updatedAt": "2023-04-22T14:15:30Z" + } + } + } + } + } + ``` + +**GraphQL Types** +=== "BlogAggregate" + + ```graphql + """Aggregate results for `Blog`""" + type BlogAggregate { + """The number of records matching the query""" + count: Int! + + """Summation aggregates for `Blog`""" + sum: BlogSumAggregateResult + + """Average aggregates for `Blog`""" + avg: BlogAvgAggregateResult + + """Minimum aggregates for comparable fields""" + min: BlogMinAggregateResult + + """Maximum aggregates for comparable fields""" + max: BlogMaxAggregateResult + } + ``` + +=== "BlogSumAggregateResult" + + ```graphql + """Result of summation aggregation for `Blog`""" + type BlogSumAggregateResult { + """Sum of rating values""" + rating: BigFloat + + """Sum of visits values""" + visits: BigInt + + # Other numeric fields... 
+ } + ``` + +=== "BlogAvgAggregateResult" + + ```graphql + """Result of average aggregation for `Blog`""" + type BlogAvgAggregateResult { + """Average of rating values""" + rating: BigFloat + + """Average of visits values""" + visits: BigFloat + + # Other numeric fields... + } + ``` + +=== "BlogMinAggregateResult" + ```graphql + """Result of minimum aggregation for `Blog`""" + type BlogMinAggregateResult { + """Minimum rating value""" + rating: Float + + """Minimum title value""" + title: String + + """Minimum createdAt value""" + createdAt: Datetime + + # Other comparable fields... + } + ``` + +=== "BlogMaxAggregateResult" + + ```graphql + """Result of maximum aggregation for `Blog`""" + type BlogMaxAggregateResult { + """Maximum rating value""" + rating: Float + + """Maximum title value""" + title: String + + """Maximum updatedAt value""" + updatedAt: Datetime + + # Other comparable fields... + } + ``` + +!!! note + + - The return type for `sum` depends on the input type: integer fields return `BigInt`, while other numeric fields return `BigFloat`. + - The return type for `avg` is always `BigFloat`. + - The return types for `min` and `max` match the original field types. + +!!! note + The `aggregate` field is disabled by default because it can be expensive on large tables. 
To enable it use a [comment directive](configuration.md#aggregate) #### Pagination diff --git a/docs/changelog.md b/docs/changelog.md index 5b585180..f55f68c7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -95,3 +95,4 @@ - bugfix: qualify schema refs ## master +- feature: Add support for aggregate functions (count, sum, avg, min, max) on collection types diff --git a/docs/configuration.md b/docs/configuration.md index b40479fd..395df1f0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -95,6 +95,41 @@ create table "BlogPost"( comment on table "BlogPost" is e'@graphql({"totalCount": {"enabled": true}})'; ``` +### Aggregate + +The `aggregate` field is an opt-in field that extends a table's Connection type. It provides various aggregate functions like count, sum, avg, min, and max that operate on the collection of records that match the query's filters. + +```graphql +type BlogPostConnection { + edges: [BlogPostEdge!]! + pageInfo: PageInfo! + + """Aggregate functions calculated on the collection of `BlogPost`""" + aggregate: BlogPostAggregate # this field +} +``` + +To enable the `aggregate` field for a table, use the directive: + +```sql +comment on table "BlogPost" is e'@graphql({"aggregate": {"enabled": true}})'; +``` + +For example: +```sql +create table "BlogPost"( + id serial primary key, + title varchar(255) not null, + rating int not null +); +comment on table "BlogPost" is e'@graphql({"aggregate": {"enabled": true}})'; +``` + +You can combine both totalCount and aggregate directives: + +```sql +comment on table "BlogPost" is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; +``` ### Renaming diff --git a/sql/load_sql_context.sql b/sql/load_sql_context.sql index 6f92c567..b6040788 100644 --- a/sql/load_sql_context.sql +++ b/sql/load_sql_context.sql @@ -256,6 +256,14 @@ select false ) ), + 'aggregate', jsonb_build_object( + 'enabled', coalesce( + ( + d.directive -> 'aggregate' ->> 'enabled' = 'true' + ), 
+ false + ) + ), 'primary_key_columns', d.directive -> 'primary_key_columns', 'foreign_keys', d.directive -> 'foreign_keys' ) diff --git a/src/builder.rs b/src/builder.rs index 73800466..adb23a60 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -10,6 +10,39 @@ use std::ops::Deref; use std::str::FromStr; use std::sync::Arc; +#[derive(Clone, Debug)] +pub struct AggregateBuilder { + pub alias: String, + pub selections: Vec, +} + +#[derive(Clone, Debug)] +pub enum AggregateSelection { + Count { + alias: String, + }, + Sum { + alias: String, + column_builders: Vec, + }, + Avg { + alias: String, + column_builders: Vec, + }, + Min { + alias: String, + column_builders: Vec, + }, + Max { + alias: String, + column_builders: Vec, + }, + Typename { + alias: String, + typename: String, + }, +} + #[derive(Clone, Debug)] pub struct InsertBuilder { // args @@ -758,7 +791,6 @@ pub struct ConnectionBuilder { //fields pub selections: Vec, - pub max_rows: u64, } @@ -927,6 +959,7 @@ pub enum ConnectionSelection { Edge(EdgeBuilder), PageInfo(PageInfoBuilder), Typename { alias: String, typename: String }, + Aggregate(AggregateBuilder), } #[derive(Clone, Debug)] @@ -1439,7 +1472,15 @@ where for selection_field in selection_fields { match field_map.get(selection_field.name.as_ref()) { - None => return Err("unknown field in connection".to_string()), + None => { + let error = if selection_field.name.as_ref() == "aggregate" { + "enable the aggregate directive to use aggregates" + } else { + "unknown field in connection" + } + .to_string(); + return Err(error); + } Some(f) => builder_fields.push(match &f.type_.unmodified_type() { __Type::Edge(_) => ConnectionSelection::Edge(to_edge_builder( f, @@ -1454,20 +1495,51 @@ where fragment_definitions, variables, )?), - - _ => match f.name().as_ref() { - "totalCount" => ConnectionSelection::TotalCount { - alias: alias_or_name(&selection_field), - }, - "__typename" => ConnectionSelection::Typename { - alias: alias_or_name(&selection_field), - 
typename: xtype.name().expect("connection type should have a name"), - }, - _ => return Err("unexpected field type on connection".to_string()), - }, + __Type::Aggregate(_) => { + ConnectionSelection::Aggregate(to_aggregate_builder( + f, + &selection_field, + fragment_definitions, + variables, + )?) + } + __Type::Scalar(Scalar::Int) => { + if selection_field.name.as_ref() == "totalCount" { + ConnectionSelection::TotalCount { + alias: alias_or_name(&selection_field), + } + } else { + return Err(format!( + "Unsupported field type for connection field {}", + selection_field.name.as_ref() + )); + } + } + __Type::Scalar(Scalar::String(None)) => { + if selection_field.name.as_ref() == "__typename" { + ConnectionSelection::Typename { + alias: alias_or_name(&selection_field), + typename: xtype + .name() + .expect("connection type should have a name"), + } + } else { + return Err(format!( + "Unsupported field type for connection field {}", + selection_field.name.as_ref() + )); + } + } + _ => { + return Err(format!( + "unknown field type on connection: {}", + selection_field.name.as_ref() + )) + } }), } } + Ok(ConnectionBuilder { alias, source: ConnectionBuilderSource { @@ -1492,6 +1564,146 @@ where } } +fn to_aggregate_builder<'a, T>( + field: &__Field, + query_field: &graphql_parser::query::Field<'a, T>, + fragment_definitions: &Vec>, + variables: &serde_json::Value, +) -> Result +where + T: Text<'a> + Eq + AsRef + Clone, + T::Value: Hash, +{ + let type_ = field.type_().unmodified_type(); + let __Type::Aggregate(ref _agg_type) = type_ else { + return Err("Internal Error: Expected AggregateType in to_aggregate_builder".to_string()); + }; + + let alias = alias_or_name(query_field); + let mut selections = Vec::new(); + let field_map = field_map(&type_); // Get fields of the AggregateType (count, sum, avg, etc.) 
+ + let type_name = type_.name().ok_or("Aggregate type has no name")?; + + let selection_fields = normalize_selection_set( + &query_field.selection_set, + fragment_definitions, + &type_name, + variables, + )?; + + for selection_field in selection_fields { + let field_name = selection_field.name.as_ref(); + let sub_field = field_map.get(field_name).ok_or(format!( + "Unknown field \"{}\" selected on type \"{}\"", + field_name, type_name + ))?; + let sub_alias = alias_or_name(&selection_field); + + let col_selections = if field_name == "sum" + || field_name == "avg" + || field_name == "min" + || field_name == "max" + { + to_aggregate_column_builders( + sub_field, + &selection_field, + fragment_definitions, + variables, + )? + } else { + vec![] + }; + + selections.push(match field_name { + "count" => AggregateSelection::Count { alias: sub_alias }, + "sum" => AggregateSelection::Sum { + alias: sub_alias, + column_builders: col_selections, + }, + "avg" => AggregateSelection::Avg { + alias: sub_alias, + column_builders: col_selections, + }, + "min" => AggregateSelection::Min { + alias: sub_alias, + column_builders: col_selections, + }, + "max" => AggregateSelection::Max { + alias: sub_alias, + column_builders: col_selections, + }, + "__typename" => AggregateSelection::Typename { + alias: sub_alias, + typename: field + .type_() + .name() + .ok_or("Name for aggregate field's type not found")? 
+ .to_string(), + }, + _ => return Err(format!("Unknown aggregate field: {}", field_name)), + }) + } + + Ok(AggregateBuilder { alias, selections }) +} + +fn to_aggregate_column_builders<'a, T>( + field: &__Field, + query_field: &graphql_parser::query::Field<'a, T>, + fragment_definitions: &Vec>, + variables: &serde_json::Value, +) -> Result, String> +where + T: Text<'a> + Eq + AsRef + Clone, + T::Value: Hash, +{ + let type_ = field.type_().unmodified_type(); + let __Type::AggregateNumeric(_) = type_ else { + return Err("Internal Error: Expected AggregateNumericType".to_string()); + }; + let mut column_builers = Vec::new(); + let field_map = field_map(&type_); + let type_name = type_.name().ok_or("AggregateNumeric type has no name")?; + let selection_fields = normalize_selection_set( + &query_field.selection_set, + fragment_definitions, + &type_name, + variables, + )?; + + for selection_field in selection_fields { + let col_name = selection_field.name.as_ref(); + let sub_field = field_map.get(col_name).ok_or_else(|| { + format!( + "Unknown or invalid field \"{}\" selected on type \"{}\"", + col_name, type_name + ) + })?; + + let __Type::Scalar(_) = sub_field.type_().unmodified_type() else { + return Err(format!( + "Field \"{}\" on type \"{}\" is not a scalar column", + col_name, type_name + )); + }; + let Some(NodeSQLType::Column(column)) = &sub_field.sql_type else { + return Err(format!( + "Internal error: Missing column info for aggregate field '{}'", + col_name + )); + }; + + let alias = alias_or_name(&selection_field); + + column_builers.push(ColumnBuilder { + alias, + column: Arc::clone(column), + }); + } + Ok(column_builers) +} + fn to_page_info_builder<'a, T>( field: &__Field, query_field: &graphql_parser::query::Field<'a, T>, diff --git a/src/graphql.rs b/src/graphql.rs index 1e2d5639..06147774 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -531,6 +531,8 @@ pub enum __Type { // Modifiers List(ListType), NonNull(NonNullType), + Aggregate(AggregateType), 
+ AggregateNumeric(AggregateNumericType), } #[cached( @@ -605,6 +607,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.kind(), Self::List(x) => x.kind(), Self::NonNull(x) => x.kind(), + Self::Aggregate(x) => x.kind(), + Self::AggregateNumeric(x) => x.kind(), } } @@ -640,6 +644,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.name(), Self::List(x) => x.name(), Self::NonNull(x) => x.name(), + Self::Aggregate(x) => x.name(), + Self::AggregateNumeric(x) => x.name(), } } @@ -675,6 +681,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.description(), Self::List(x) => x.description(), Self::NonNull(x) => x.description(), + Self::Aggregate(x) => x.description(), + Self::AggregateNumeric(x) => x.description(), } } @@ -711,6 +719,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.fields(_include_deprecated), Self::List(x) => x.fields(_include_deprecated), Self::NonNull(x) => x.fields(_include_deprecated), + Self::Aggregate(x) => x.fields(_include_deprecated), + Self::AggregateNumeric(x) => x.fields(_include_deprecated), } } @@ -747,6 +757,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.interfaces(), Self::List(x) => x.interfaces(), Self::NonNull(x) => x.interfaces(), + Self::Aggregate(x) => x.interfaces(), + Self::AggregateNumeric(x) => x.interfaces(), } } @@ -792,6 +804,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.enum_values(_include_deprecated), Self::List(x) => x.enum_values(_include_deprecated), Self::NonNull(x) => x.enum_values(_include_deprecated), + Self::Aggregate(x) => x.enum_values(_include_deprecated), + Self::AggregateNumeric(x) => x.enum_values(_include_deprecated), } } @@ -828,6 +842,8 @@ impl ___Type for __Type { Self::__Directive(x) => x.input_fields(), Self::List(x) => x.input_fields(), Self::NonNull(x) => x.input_fields(), + Self::Aggregate(x) => x.input_fields(), + Self::AggregateNumeric(x) => x.input_fields(), } } @@ -1676,39 +1692,44 @@ impl ___Type for ConnectionType { } fn fields(&self, 
_include_deprecated: bool) -> Option> { - let mut fields = vec![ - __Field { - name_: "edges".to_string(), - type_: __Type::NonNull(NonNullType { - type_: Box::new(__Type::List(ListType { - type_: Box::new(__Type::NonNull(NonNullType { - type_: Box::new(__Type::Edge(EdgeType { - table: Arc::clone(&self.table), - schema: Arc::clone(&self.schema), - })), - })), + let table_base_type_name = &self.schema.graphql_table_base_type_name(&self.table); + let edge_type = __Type::Edge(EdgeType { + table: Arc::clone(&self.table), + schema: self.schema.clone(), + }); + + let edge = __Field { + name_: "edges".to_string(), + type_: __Type::NonNull(NonNullType { + type_: Box::new(__Type::List(ListType { + type_: Box::new(__Type::NonNull(NonNullType { + type_: Box::new(edge_type), })), - }), - args: vec![], - description: None, - deprecation_reason: None, - sql_type: None, - }, - __Field { - name_: "pageInfo".to_string(), - type_: __Type::NonNull(NonNullType { - type_: Box::new(__Type::PageInfo(PageInfoType)), - }), - args: vec![], - description: None, - deprecation_reason: None, - sql_type: None, - }, - ]; + })), + }), + args: vec![], + description: None, + deprecation_reason: None, + sql_type: None, + }; + + let page_info = __Field { + name_: "pageInfo".to_string(), + type_: __Type::NonNull(NonNullType { + type_: Box::new(__Type::PageInfo(PageInfoType)), + }), + args: vec![], + description: None, + deprecation_reason: None, + sql_type: None, + }; - if let Some(total_count) = self.table.directives.total_count.as_ref() { - if total_count.enabled { - let total_count_field = __Field { + let mut fields = vec![edge, page_info]; + + // Conditionally add totalCount based on the directive + if let Some(total_count_directive) = self.table.directives.total_count.as_ref() { + if total_count_directive.enabled { + let total_count = __Field { name_: "totalCount".to_string(), type_: __Type::NonNull(NonNullType { type_: Box::new(__Type::Scalar(Scalar::Int)), @@ -1720,9 +1741,30 @@ impl ___Type for 
ConnectionType { deprecation_reason: None, sql_type: None, }; - fields.push(total_count_field); + fields.push(total_count); } } + + // Conditionally add aggregate based on the directive + if let Some(aggregate_directive) = self.table.directives.aggregate.as_ref() { + if aggregate_directive.enabled { + let aggregate = __Field { + name_: "aggregate".to_string(), + type_: __Type::Aggregate(AggregateType { + table: Arc::clone(&self.table), + schema: self.schema.clone(), + }), + args: vec![], + description: Some(format!( + "Aggregate functions calculated on the collection of `{table_base_type_name}`" + )), + deprecation_reason: None, + sql_type: None, + }; + fields.push(aggregate); + } + } + Some(fields) } } @@ -4182,6 +4224,53 @@ impl __Schema { schema: Arc::clone(&schema_rc), })); } + + // Add Aggregate types if the table is selectable + if self.graphql_table_select_types_are_valid(table) { + // Only add aggregate types if the directive is enabled + if let Some(aggregate_directive) = table.directives.aggregate.as_ref() { + if aggregate_directive.enabled { + types_.push(__Type::Aggregate(AggregateType { + table: Arc::clone(table), + schema: Arc::clone(&schema_rc), + })); + // Check if there are any columns aggregatable by sum/avg + if table + .columns + .iter() + .any(|c| is_aggregatable(c, &AggregateOperation::Sum)) + { + types_.push(__Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(table), + schema: Arc::clone(&schema_rc), + aggregate_op: AggregateOperation::Sum, + })); + types_.push(__Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(table), + schema: Arc::clone(&schema_rc), + aggregate_op: AggregateOperation::Avg, + })); + } + // Check if there are any columns aggregatable by min/max + if table + .columns + .iter() + .any(|c| is_aggregatable(c, &AggregateOperation::Min)) + { + types_.push(__Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(table), + schema: Arc::clone(&schema_rc), + aggregate_op: 
AggregateOperation::Min, + })); + types_.push(__Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(table), + schema: Arc::clone(&schema_rc), + aggregate_op: AggregateOperation::Max, + })); + } + } + } + } } for (_, enum_) in self @@ -4299,3 +4388,331 @@ impl __Schema { ] } } + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct AggregateType { + pub table: Arc, + pub schema: Arc<__Schema>, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct AggregateNumericType { + pub table: Arc
, + pub schema: Arc<__Schema>, + pub aggregate_op: AggregateOperation, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum AggregateOperation { + Sum, + Avg, + Min, + Max, + // Count is handled directly in AggregateType +} + +impl AggregateOperation { + // Helper for descriptive terms used in descriptions + fn descriptive_term(&self) -> &str { + match self { + AggregateOperation::Sum => "summation", + AggregateOperation::Avg => "average", + AggregateOperation::Min => "minimum", + AggregateOperation::Max => "maximum", + } + } + + // Helper for capitalized descriptive terms used in field descriptions + fn capitalized_descriptive_term(&self) -> &str { + match self { + AggregateOperation::Sum => "Sum", + AggregateOperation::Avg => "Average", + AggregateOperation::Min => "Minimum", + AggregateOperation::Max => "Maximum", + } + } +} + +impl std::fmt::Display for AggregateOperation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AggregateOperation::Sum => write!(f, "Sum"), + AggregateOperation::Avg => write!(f, "Avg"), // GraphQL schema uses "Avg" for the type name part + AggregateOperation::Min => write!(f, "Min"), + AggregateOperation::Max => write!(f, "Max"), + } + } +} + +/// Determines if a column's type is suitable for a given aggregate operation. +fn is_aggregatable(column: &Column, op: &AggregateOperation) -> bool { + let Some(ref type_) = column.type_ else { + return false; + }; + + // Removed duplicated closures, will use helper functions below + + match op { + // Sum/Avg only make sense for numeric types + AggregateOperation::Sum | AggregateOperation::Avg => { + // Check category first for arrays/enums, then check name for base types + match type_.category { + TypeCategory::Other => is_pg_numeric_type(&type_.name), + _ => false, // Only allow sum/avg on base numeric types for now + } + } + // Min/Max can work on more types (numeric, string, date/time, etc.) 
+ AggregateOperation::Min | AggregateOperation::Max => { + match type_.category { + TypeCategory::Other => { + is_pg_numeric_type(&type_.name) + || is_pg_string_type(&type_.name) + || is_pg_datetime_type(&type_.name) + || is_pg_boolean_type(&type_.name) + } + _ => false, // Don't allow min/max on composites, arrays, tables, pseudo + } + } + } +} + +/// Returns the appropriate GraphQL scalar type for an aggregate result. +fn aggregate_result_type(column: &Column, op: &AggregateOperation) -> Option { + let type_ = column.type_.as_ref()?; + + match op { + AggregateOperation::Sum => { + // SUM of integers often results in bigint + // SUM of float/numeric results in bigfloat + // Let's simplify and return BigInt for int-like, BigFloat otherwise + if is_pg_small_integer_type(&type_.name) { + Some(Scalar::BigInt) + } else if is_pg_numeric_type(&type_.name) { + Some(Scalar::BigFloat) + } else { + None + } + } + AggregateOperation::Avg => { + if is_pg_numeric_type(&type_.name) { + Some(Scalar::BigFloat) + } else { + None + } + } + AggregateOperation::Min | AggregateOperation::Max => { + if is_pg_numeric_type(&type_.name) { + sql_type_to_scalar(&type_.name, column.max_characters) + } else if is_pg_string_type(&type_.name) { + Some(Scalar::String(column.max_characters)) + } else if is_pg_datetime_type(&type_.name) { + sql_type_to_scalar(&type_.name, column.max_characters) + } else if is_pg_boolean_type(&type_.name) { + Some(Scalar::Boolean) + } else { + None + } + } + } +} + +impl ___Type for AggregateType { + fn kind(&self) -> __TypeKind { + __TypeKind::OBJECT + } + + fn name(&self) -> Option { + let table_base_type_name = &self.schema.graphql_table_base_type_name(&self.table); + Some(format!("{table_base_type_name}Aggregate")) + } + + fn description(&self) -> Option { + let table_base_type_name = &self.schema.graphql_table_base_type_name(&self.table); + Some(format!("Aggregate results for `{table_base_type_name}`")) + } + + fn fields(&self, _include_deprecated: bool) -> 
Option> { + let mut fields = Vec::new(); + + // Count field (always present) + fields.push(__Field { + name_: "count".to_string(), + type_: __Type::NonNull(NonNullType { + type_: Box::new(__Type::Scalar(Scalar::Int)), + }), + args: vec![], + description: Some("The number of records matching the query".to_string()), + deprecation_reason: None, + sql_type: None, + }); + + // Add fields for Sum, Avg, Min, Max if there are any aggregatable columns + let has_sum_avgable = self + .table + .columns + .iter() + .any(|c| is_aggregatable(c, &AggregateOperation::Sum)); + let has_min_maxable = self + .table + .columns + .iter() + .any(|c| is_aggregatable(c, &AggregateOperation::Min)); + + if has_sum_avgable { + fields.push(__Field { + name_: "sum".to_string(), + type_: __Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(&self.table), + schema: Arc::clone(&self.schema), + aggregate_op: AggregateOperation::Sum, + }), + args: vec![], + description: Some("Summation aggregates for numeric fields".to_string()), + deprecation_reason: None, + sql_type: None, + }); + fields.push(__Field { + name_: "avg".to_string(), + type_: __Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(&self.table), + schema: Arc::clone(&self.schema), + aggregate_op: AggregateOperation::Avg, + }), + args: vec![], + description: Some("Average aggregates for numeric fields".to_string()), + deprecation_reason: None, + sql_type: None, + }); + } + + if has_min_maxable { + fields.push(__Field { + name_: "min".to_string(), + type_: __Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(&self.table), + schema: Arc::clone(&self.schema), + aggregate_op: AggregateOperation::Min, + }), + args: vec![], + description: Some("Minimum aggregates for comparable fields".to_string()), + deprecation_reason: None, + sql_type: None, + }); + fields.push(__Field { + name_: "max".to_string(), + type_: __Type::AggregateNumeric(AggregateNumericType { + table: Arc::clone(&self.table), + schema: 
Arc::clone(&self.schema), + aggregate_op: AggregateOperation::Max, + }), + args: vec![], + description: Some("Maximum aggregates for comparable fields".to_string()), + deprecation_reason: None, + sql_type: None, + }); + } + Some(fields) + } +} + +impl ___Type for AggregateNumericType { + fn kind(&self) -> __TypeKind { + __TypeKind::OBJECT + } + + fn name(&self) -> Option { + let table_base_type_name = &self.schema.graphql_table_base_type_name(&self.table); + // Use Display trait for op_name + Some(format!( + "{table_base_type_name}{}AggregateResult", + self.aggregate_op + )) + } + + fn description(&self) -> Option { + let table_base_type_name = &self.schema.graphql_table_base_type_name(&self.table); + Some(format!( + "Result of {} aggregation for `{table_base_type_name}`", + self.aggregate_op.descriptive_term() + )) + } + + fn fields(&self, _include_deprecated: bool) -> Option> { + let mut fields = Vec::new(); + + for col in self.table.columns.iter() { + if is_aggregatable(col, &self.aggregate_op) { + if let Some(scalar_type) = aggregate_result_type(col, &self.aggregate_op) { + let field_name = self.schema.graphql_column_field_name(col); + fields.push(__Field { + name_: field_name.clone(), + type_: __Type::Scalar(scalar_type), + args: vec![], + description: Some(format!( + "{} of {} across all matching records", + self.aggregate_op.capitalized_descriptive_term(), + field_name + )), + deprecation_reason: None, + sql_type: Some(NodeSQLType::Column(Arc::clone(col))), + }); + } + } + } + if fields.is_empty() { + None + } else { + Some(fields) + } + } +} + +// Converts SQL type name to a GraphQL Scalar, needed for aggregate_result_type +// This function might already exist or needs to be created/adapted. 
+// Placeholder implementation: +fn sql_type_to_scalar(sql_type_name: &str, typmod: Option) -> Option { + // Simplified mapping - adapt based on existing logic in sql_types.rs or elsewhere + match sql_type_name { + "int2" | "int4" => Some(Scalar::Int), + "int8" => Some(Scalar::BigInt), + "float4" | "float8" | "numeric" | "decimal" => Some(Scalar::BigFloat), // Use BigFloat for precision + "text" | "varchar" | "char" | "bpchar" | "name" => Some(Scalar::String(typmod)), + "bool" => Some(Scalar::Boolean), + "date" => Some(Scalar::Date), + "time" | "timetz" => Some(Scalar::Time), + "timestamp" | "timestamptz" => Some(Scalar::Datetime), + "uuid" => Some(Scalar::UUID), + "json" | "jsonb" => Some(Scalar::JSON), + _ => Some(Scalar::Opaque), // Fallback for unknown types + } +} + +// Helper functions for PostgreSQL type checking (extracted to deduplicate) +fn is_pg_numeric_type(name: &str) -> bool { + matches!( + name, + "int2" | "int4" | "int8" | "float4" | "float8" | "numeric" | "decimal" | "money" + ) +} + +fn is_pg_string_type(name: &str) -> bool { + matches!( + name, + "text" | "varchar" | "char" | "bpchar" | "name" | "citext" + ) +} + +fn is_pg_datetime_type(name: &str) -> bool { + matches!( + name, + "date" | "time" | "timetz" | "timestamp" | "timestamptz" + ) +} + +fn is_pg_boolean_type(name: &str) -> bool { + matches!(name, "bool") +} + +fn is_pg_small_integer_type(name: &str) -> bool { + matches!(name, "int2" | "int4" | "int8") +} diff --git a/src/sql_types.rs b/src/sql_types.rs index 1645feec..9faff16b 100644 --- a/src/sql_types.rs +++ b/src/sql_types.rs @@ -447,6 +447,11 @@ pub struct TableDirectiveTotalCount { pub enabled: bool, } +#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)] +pub struct TableDirectiveAggregate { + pub enabled: bool, +} + #[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)] pub struct TableDirectiveForeignKey { // Equivalent to ForeignKeyDirectives.local_name @@ -471,6 +476,9 @@ pub struct TableDirectives { // 
@graphql({"totalCount": { "enabled": true } }) pub total_count: Option, + // @graphql({"aggregate": { "enabled": true } }) + pub aggregate: Option, + // @graphql({"primary_key_columns": ["id"]}) pub primary_key_columns: Option>, diff --git a/src/transpile.rs b/src/transpile.rs index efe981da..062f3a80 100644 --- a/src/transpile.rs +++ b/src/transpile.rs @@ -813,12 +813,6 @@ pub struct FromFunction { } impl ConnectionBuilder { - fn requested_total(&self) -> bool { - self.selections - .iter() - .any(|x| matches!(&x, ConnectionSelection::TotalCount { alias: _ })) - } - fn page_selections(&self) -> Vec { self.selections .iter() @@ -875,13 +869,13 @@ impl ConnectionBuilder { let frags: Vec = self .selections .iter() - .map(|x| { + .filter_map(|x| { x.to_sql( quoted_block_name, &self.order_by, &self.source.table, param_context, - ) + ).transpose() }) .collect::, _>>()?; @@ -917,6 +911,96 @@ impl ConnectionBuilder { } } + // Generates the *contents* of the aggregate jsonb_build_object + fn aggregate_select_list(&self, quoted_block_name: &str) -> Result, String> { + let Some(agg_builder) = self.selections.iter().find_map(|sel| match sel { + ConnectionSelection::Aggregate(builder) => Some(builder), + _ => None, + }) else { + return Ok(None); + }; + + let mut agg_selections = vec![]; + + for selection in &agg_builder.selections { + match selection { + AggregateSelection::Count { alias } => { + // Produces: 'count_alias', count(*) + agg_selections.push(format!("{}, count(*)", quote_literal(alias))); + } + AggregateSelection::Sum { + alias, + column_builders: selections, + } + | AggregateSelection::Avg { + alias, + column_builders: selections, + } + | AggregateSelection::Min { + alias, + column_builders: selections, + } + | AggregateSelection::Max { + alias, + column_builders: selections, + } => { + let pg_func = match selection { + AggregateSelection::Sum { .. } => "sum", + AggregateSelection::Avg { .. } => "avg", + AggregateSelection::Min { .. 
} => "min", + AggregateSelection::Max { .. } => "max", + AggregateSelection::Count { .. } => { + unreachable!("Count should be handled by its own arm") + } + AggregateSelection::Typename { .. } => { + unreachable!("Typename should be handled by its own arm") + } + }; + + let mut field_selections = vec![]; + for col_builder in selections { + let col_sql = col_builder.to_sql(quoted_block_name)?; + let col_alias = &col_builder.alias; + + // Always cast avg input to numeric for precision + let col_sql_casted = if pg_func == "avg" { + format!("{}::numeric", col_sql) + } else { + col_sql + }; + // Produces: 'col_alias', agg_func(col) + field_selections.push(format!( + "{}, {}({})", + quote_literal(col_alias), + pg_func, + col_sql_casted + )); + } + // Produces: 'agg_alias', jsonb_build_object('col_alias', agg_func(col), ...) + agg_selections.push(format!( + "{}, jsonb_build_object({})", + quote_literal(alias), + field_selections.join(", ") + )); + } + AggregateSelection::Typename { alias, typename } => { + // Produces: '__typename', 'AggregateTypeName' + agg_selections.push(format!( + "{}, {}", + quote_literal(alias), + quote_literal(typename) + )); + } + } + } + + if agg_selections.is_empty() { + Ok(None) + } else { + Ok(Some(agg_selections.join(", "))) + } + } + pub fn to_sql( &self, quoted_parent_block_name: Option<&str>, @@ -946,7 +1030,6 @@ impl ConnectionBuilder { false => &order_by_clause, }; - let requested_total = self.requested_total(); let requested_next_page = self.requested_next_page(); let requested_previous_page = self.requested_previous_page(); @@ -955,6 +1038,7 @@ impl ConnectionBuilder { let cursor = &self.before.clone().or_else(|| self.after.clone()); let object_clause = self.object_clause("ed_block_name, param_context)?; + let aggregate_select_list = self.aggregate_select_list("ed_block_name)?; let selectable_columns_clause = self.source.table.to_selectable_columns_clause(); @@ -985,6 +1069,9 @@ impl ConnectionBuilder { let limit = 
self.limit_clause(); let offset = self.offset.unwrap_or(0); + // Determine if aggregates are requested based on if we generated a select list + let requested_aggregates = aggregate_select_list.is_some(); + // initialized assuming forwards pagination let mut has_next_page_query = format!( " @@ -1006,7 +1093,7 @@ impl ConnectionBuilder { " ); - let mut has_prev_page_query = format!(" + let mut has_prev_page_query = format!(" with page_minus_1 as ( select not ({pkey_tuple_clause_from_block} = any( __records.seen )) is_pkey_in_records @@ -1035,8 +1122,59 @@ impl ConnectionBuilder { has_prev_page_query = "select null".to_string() } + // Build aggregate CTE if requested + let aggregate_cte = if requested_aggregates { + let select_list_str = aggregate_select_list.unwrap_or_default(); + format!( + r#" + ,__aggregates(agg_result) as ( + select + jsonb_build_object({select_list_str}) + from + {from_clause} + where + {join_clause} + and {where_clause} + ) + "# + ) + } else { + r#" + ,__aggregates(agg_result) as (select null::jsonb) + "# + .to_string() + }; + + // Add helper cte to set page info correctly for empty collections + let has_records_cte = r#" + ,__has_records(has_records) as (select exists(select 1 from __records)) + "#; + + // Clause containing selections *not* including the aggregate + let base_object_clause = object_clause; // Renamed original object_clause + + // Clause to merge the aggregate result if requested + let aggregate_merge_clause = if requested_aggregates { + let agg_alias = self + .selections + .iter() + .find_map(|sel| match sel { + ConnectionSelection::Aggregate(builder) => Some(builder.alias.clone()), + _ => None, + }) + .ok_or( + "Internal Error: Aggregate builder not found when requested_aggregates is true", + )?; + format!( + "|| jsonb_build_object({}, coalesce(__aggregates.agg_result, '{{}}'::jsonb))", + quote_literal(&agg_alias) + ) + } else { + "".to_string() + }; + Ok(format!( - " + r#" ( with __records as ( select @@ -1061,25 +1199,39 @@ 
impl ConnectionBuilder { from {from_clause} where - {requested_total} -- skips total when not requested - and {join_clause} + {join_clause} and {where_clause} ), __has_next_page(___has_next_page) as ( {has_next_page_query} - ), __has_previous_page(___has_previous_page) as ( {has_prev_page_query} ) + {has_records_cte} + {aggregate_cte}, + __base_object as ( + select jsonb_build_object({base_object_clause}) as obj + from + __total_count + cross join __has_next_page + cross join __has_previous_page + cross join __has_records + left join __records {quoted_block_name} on true + group by + __total_count.___total_count, + __has_next_page.___has_next_page, + __has_previous_page.___has_previous_page, + __has_records.has_records + ) select - jsonb_build_object({object_clause}) -- sorted within edge + coalesce(__base_object.obj, '{{}}'::jsonb) {aggregate_merge_clause} from - __records {quoted_block_name}, - __total_count, - __has_next_page, - __has_previous_page - )" + (select 1) as __dummy_for_left_join + left join __base_object on true + cross join __aggregates + ) + "# )) } } @@ -1124,13 +1276,13 @@ impl PageInfoSelection { Ok(match self { Self::StartCursor { alias } => { format!( - "{}, (array_agg({cursor_clause} order by {order_by_clause}))[1]", + "{}, case when __has_records.has_records then (array_agg({cursor_clause} order by {order_by_clause}))[1] else null end", quote_literal(alias) ) } Self::EndCursor { alias } => { format!( - "{}, (array_agg({cursor_clause} order by {order_by_clause_reversed}))[1]", + "{}, case when __has_records.has_records then (array_agg({cursor_clause} order by {order_by_clause_reversed}))[1] else null end", quote_literal(alias) ) } @@ -1160,31 +1312,30 @@ impl ConnectionSelection { order_by: &OrderByBuilder, table: &Table, param_context: &mut ParamContext, - ) -> Result { + ) -> Result, String> { Ok(match self { - Self::Edge(x) => { - format!( - "{}, {}", - quote_literal(&x.alias), - x.to_sql(block_name, order_by, table, param_context)? 
- ) - } - Self::PageInfo(x) => { - format!( - "{}, {}", - quote_literal(&x.alias), - x.to_sql(block_name, order_by, table)? - ) - } - Self::TotalCount { alias } => { - format!( - "{}, coalesce(min(__total_count.___total_count), 0)", - quote_literal(alias) - ) - } - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } + Self::Edge(x) => Some(format!( + "{}, {}", + quote_literal(&x.alias), + x.to_sql(block_name, order_by, table, param_context)? + )), + Self::PageInfo(x) => Some(format!( + "{}, {}", + quote_literal(&x.alias), + x.to_sql(block_name, order_by, table)? + )), + Self::TotalCount { alias } => Some(format!( + "{}, coalesce(__total_count.___total_count, 0)", + quote_literal(alias), + )), + Self::Typename { alias, typename } => Some(format!( + "{}, {}", + quote_literal(alias), + quote_literal(typename) + )), + // SQL generation is handled by ConnectionBuilder::aggregate_select_list + // and the results are merged in later in the process + Self::Aggregate(_) => None, }) } } @@ -1206,12 +1357,26 @@ impl EdgeBuilder { let x = frags.join(", "); let order_by_clause = order_by.to_order_by_clause(block_name); + // Get the first primary key column name to use in the filter + let first_pk_col = table.primary_key_columns().first().map(|col| &col.name); + + // Create a filter clause that checks if any primary key column is not NULL + let filter_clause = if let Some(pk_col) = first_pk_col { + format!( + "filter (where {}.{} is not null)", + block_name, + quote_ident(pk_col) + ) + } else { + "".to_string() // Fallback if no primary key columns (should be rare) + }; + Ok(format!( "coalesce( jsonb_agg( jsonb_build_object({x}) order by {order_by_clause} - ), + ) {filter_clause}, jsonb_build_array() )" )) diff --git a/test/expected/aggregate.out b/test/expected/aggregate.out new file mode 100644 index 00000000..a1cd654a --- /dev/null +++ b/test/expected/aggregate.out @@ -0,0 +1,503 @@ +begin; + create table account( + 
id serial primary key, + email varchar(255) not null, + created_at timestamp not null + ); + create table blog( + id serial primary key, + owner_id integer not null references account(id) on delete cascade, + name varchar(255) not null, + description varchar(255), + created_at timestamp not null + ); + create type blog_post_status as enum ('PENDING', 'RELEASED'); + create table blog_post( + id uuid not null default gen_random_uuid() primary key, + blog_id integer not null references blog(id) on delete cascade, + title varchar(255) not null, + body varchar(10000), + tags TEXT[], + status blog_post_status not null, + created_at timestamp not null + ); + -- 5 Accounts + insert into public.account(email, created_at) + values + ('aardvark@x.com', '2025-04-27 12:00:00'), + ('bat@x.com', '2025-04-28 12:00:00'), + ('cat@x.com', '2025-04-29 12:00:00'), + ('dog@x.com', '2025-04-30 12:00:00'), + ('elephant@x.com', '2025-05-01 12:00:00'); + insert into blog(owner_id, name, description, created_at) + values + ((select id from account where email ilike 'a%'), 'A: Blog 1', 'a desc1', '2025-04-22 12:00:00'), + ((select id from account where email ilike 'a%'), 'A: Blog 2', 'a desc2', '2025-04-23 12:00:00'), + ((select id from account where email ilike 'a%'), 'A: Blog 3', 'a desc3', '2025-04-24 12:00:00'), + ((select id from account where email ilike 'b%'), 'B: Blog 3', 'b desc1', '2025-04-25 12:00:00'); + insert into blog_post (blog_id, title, body, tags, status, created_at) + values + ((SELECT id FROM blog WHERE name = 'A: Blog 1'), 'Post 1 in A Blog 1', 'Content for post 1 in A Blog 1', '{"tech", "update"}', 'RELEASED', '2025-04-02 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 1'), 'Post 2 in A Blog 1', 'Content for post 2 in A Blog 1', '{"announcement", "tech"}', 'PENDING', '2025-04-07 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 2'), 'Post 1 in A Blog 2', 'Content for post 1 in A Blog 2', '{"personal"}', 'RELEASED', '2025-04-12 12:00:00'), + ((SELECT id 
FROM blog WHERE name = 'A: Blog 2'), 'Post 2 in A Blog 2', 'Content for post 2 in A Blog 2', '{"update"}', 'RELEASED', '2025-04-17 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 3'), 'Post 1 in A Blog 3', 'Content for post 1 in A Blog 3', '{"travel", "adventure"}', 'PENDING', '2025-04-22 12:00:00'), + ((SELECT id FROM blog WHERE name = 'B: Blog 3'), 'Post 1 in B Blog 3', 'Content for post 1 in B Blog 3', '{"tech", "review"}', 'RELEASED', '2025-04-27 12:00:00'), + ((SELECT id FROM blog WHERE name = 'B: Blog 3'), 'Post 2 in B Blog 3', 'Content for post 2 in B Blog 3', '{"coding", "tutorial"}', 'PENDING', '2025-05-02 12:00:00'); + comment on table account is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + comment on table blog is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + comment on table blog_post is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + -- Test Case 1: Basic Count on accountCollection + select graphql.resolve($$ + query { + accountCollection { + aggregate { + count + } + } + } + $$); + resolve +-------------------------------------------------------------- + {"data": {"accountCollection": {"aggregate": {"count": 5}}}} +(1 row) + + -- Test Case 2: Filtered Count on accountCollection + select graphql.resolve($$ + query { + accountCollection(filter: { id: { gt: 3 } }) { + aggregate { + count + } + } + } + $$); + resolve +-------------------------------------------------------------- + {"data": {"accountCollection": {"aggregate": {"count": 2}}}} +(1 row) + + -- Test Case 3: Sum, Avg, Min, Max on blogCollection.id + select graphql.resolve($$ + query { + blogCollection { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + resolve +-------------------------------------------------------------------------------------------------------------------------------------- + {"data": 
{"blogCollection": {"aggregate": {"avg": {"id": 2.5}, "max": {"id": 4}, "min": {"id": 1}, "sum": {"id": 10}, "count": 4}}}} +(1 row) + + -- Test Case 4: Aggregates with Filter on blogCollection.id + select graphql.resolve($$ + query { + blogCollection(filter: { ownerId: { lt: 2 } }) { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + resolve +------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": {"aggregate": {"avg": {"id": 2.0}, "max": {"id": 3}, "min": {"id": 1}, "sum": {"id": 6}, "count": 3}}}} +(1 row) + + -- Test Case 5: Aggregates with Pagination on blogCollection (should ignore pagination for aggregates) + select graphql.resolve($$ + query { + blogCollection(first: 1) { + edges { + node { + id + name + } + } + aggregate { + count + sum { + id + } + } + } + } + $$); + resolve +----------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": {"edges": [{"node": {"id": 1, "name": "A: Blog 1"}}], "aggregate": {"sum": {"id": 10}, "count": 4}}}} +(1 row) + + -- Test Case 7: Aggregates with empty result set on accountCollection + select graphql.resolve($$ + query { + accountCollection(filter: { id: { gt: 1000 } }) { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + resolve +-------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"accountCollection": {"aggregate": {"avg": {"id": null}, "max": {"id": null}, "min": {"id": null}, "sum": {"id": null}, "count": 0}}}} +(1 row) + + -- Test Case 8: Aggregates on table with null values (using blog.description) + -- Count where description is not null + select graphql.resolve($$ + query { + 
blogCollection(filter: { description: { is: NOT_NULL }}) { + aggregate { + count + } + } + } + $$); + resolve +----------------------------------------------------------- + {"data": {"blogCollection": {"aggregate": {"count": 4}}}} +(1 row) + + -- Count where description is null + select graphql.resolve($$ + query { + blogCollection(filter: { description: { is: NULL }}) { + aggregate { + count + } + } + } + $$); + resolve +----------------------------------------------------------- + {"data": {"blogCollection": {"aggregate": {"count": 0}}}} +(1 row) + + -- Test Case 9: Basic Count on blogPostCollection + select graphql.resolve($$ + query { + blogPostCollection { + aggregate { + count + } + } + } + $$); + resolve +--------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"count": 7}}}} +(1 row) + + -- Test Case 10: Min/Max on non-numeric fields (string, datetime) + select graphql.resolve($$ + query { + blogCollection { + aggregate { + min { + name + description + createdAt + } + max { + name + description + createdAt + } + } + } + } + $$); + resolve +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": {"aggregate": {"max": {"name": "B: Blog 3", "createdAt": "2025-04-25T12:00:00", "description": "b desc1"}, "min": {"name": "A: Blog 1", "createdAt": "2025-04-22T12:00:00", "description": "a desc1"}}}}} +(1 row) + + -- Test Case 11: Aggregation with relationships (nested queries) + select graphql.resolve($$ + query { + accountCollection { + edges { + node { + email + blogCollection { + aggregate { + count + sum { + id + } + } + } + } + } + } + } + $$); + resolve 
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"accountCollection": {"edges": [{"node": {"email": "aardvark@x.com", "blogCollection": {"aggregate": {"sum": {"id": 6}, "count": 3}}}}, {"node": {"email": "bat@x.com", "blogCollection": {"aggregate": {"sum": {"id": 4}, "count": 1}}}}, {"node": {"email": "cat@x.com", "blogCollection": {"aggregate": {"sum": {"id": null}, "count": 0}}}}, {"node": {"email": "dog@x.com", "blogCollection": {"aggregate": {"sum": {"id": null}, "count": 0}}}}, {"node": {"email": "elephant@x.com", "blogCollection": {"aggregate": {"sum": {"id": null}, "count": 0}}}}]}}} +(1 row) + + -- Test Case 12: Combination of aggregates in a complex query + select graphql.resolve($$ + query { + blogCollection { + edges { + node { + name + blogPostCollection { + aggregate { + count + min { + createdAt + } + max { + createdAt + } + } + } + } + } + aggregate { + count + min { + id + createdAt + } + max { + id + createdAt + } + sum { + id + } + avg { + id + } + } + } + } + $$); + resolve 
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": {"edges": [{"node": {"name": "A: Blog 1", "blogPostCollection": {"aggregate": {"max": {"createdAt": "2025-04-07T12:00:00"}, "min": {"createdAt": "2025-04-02T12:00:00"}, "count": 2}}}}, {"node": {"name": "A: Blog 2", "blogPostCollection": {"aggregate": {"max": {"createdAt": "2025-04-17T12:00:00"}, "min": {"createdAt": "2025-04-12T12:00:00"}, "count": 2}}}}, {"node": {"name": "A: Blog 3", "blogPostCollection": {"aggregate": {"max": {"createdAt": "2025-04-22T12:00:00"}, "min": {"createdAt": "2025-04-22T12:00:00"}, "count": 1}}}}, {"node": {"name": "B: Blog 3", "blogPostCollection": {"aggregate": {"max": {"createdAt": "2025-05-02T12:00:00"}, "min": {"createdAt": "2025-04-27T12:00:00"}, "count": 2}}}}], "aggregate": {"avg": {"id": 2.5}, "max": {"id": 4, "createdAt": "2025-04-25T12:00:00"}, "min": {"id": 1, "createdAt": "2025-04-22T12:00:00"}, "sum": {"id": 10}, "count": 4}}}} +(1 row) + + -- Test Case 13: Complex filters with aggregates using AND/OR/NOT + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + or: [ + {status: 
{eq: RELEASED}}, + {title: {startsWith: "Post"}} + ] + } + ) { + aggregate { + count + } + } + } + $$); + resolve +--------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"count": 7}}}} +(1 row) + + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + and: [ + {status: {eq: PENDING}}, + {not: {blogId: {eq: 4}}} + ] + } + ) { + aggregate { + count + } + } + } + $$); + resolve +--------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"count": 2}}}} +(1 row) + + -- Test Case 14: Array field aggregation (on tags array) + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + tags: {contains: "tech"} + } + ) { + aggregate { + count + } + } + } + $$); + resolve +--------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"count": 3}}}} +(1 row) + + -- Test Case 15: UUID field aggregation + -- This test verifies that UUID fields are intentionally excluded from min/max aggregation. + -- UUIDs don't have a meaningful natural ordering for aggregation purposes, so they're explicitly + -- excluded from the list of types that can be aggregated with min/max. 
+ select graphql.resolve($$ + query { + blogPostCollection { + aggregate { + min { + id + } + max { + id + } + } + } + } + $$); + resolve +---------------------------------------------------------------------------------------------------------------------------- + {"data": null, "errors": [{"message": "Unknown or invalid field \"id\" selected on type \"BlogPostMinAggregateResult\""}]} +(1 row) + + -- Test Case 16: Edge case - Empty result set with aggregates + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + title: {eq: "This title does not exist"} + } + ) { + aggregate { + count + min { + createdAt + } + max { + createdAt + } + } + } + } + $$); + resolve +----------------------------------------------------------------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"max": {"createdAt": null}, "min": {"createdAt": null}, "count": 0}}}} +(1 row) + + -- Test Case 17: Filtering on aggregate results (verify all posts with RELEASED status) + select graphql.resolve($$ + query { + blogPostCollection( + filter: {status: {eq: RELEASED}} + ) { + aggregate { + count + } + } + } + $$); + resolve +--------------------------------------------------------------- + {"data": {"blogPostCollection": {"aggregate": {"count": 4}}}} +(1 row) + + -- Test Case 18: Aggregates on filtered relationships + select graphql.resolve($$ + query { + blogCollection { + edges { + node { + name + blogPostCollection( + filter: {status: {eq: RELEASED}} + ) { + aggregate { + count + } + } + } + } + } + } + $$); + resolve +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": 
{"edges": [{"node": {"name": "A: Blog 1", "blogPostCollection": {"aggregate": {"count": 1}}}}, {"node": {"name": "A: Blog 2", "blogPostCollection": {"aggregate": {"count": 2}}}}, {"node": {"name": "A: Blog 3", "blogPostCollection": {"aggregate": {"count": 0}}}}, {"node": {"name": "B: Blog 3", "blogPostCollection": {"aggregate": {"count": 1}}}}]}}} +(1 row) + + -- Test Case 19: aliases test case + select graphql.resolve($$ + query { + blogCollection { + agg: aggregate { + cnt: count + total: sum { + identifier: id + } + average: avg { + identifier: id + } + minimum: min { + identifier: id + } + maximum: max { + identifier: id + } + } + } + } + $$); + resolve +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"blogCollection": {"agg": {"cnt": 4, "total": {"identifier": 10}, "average": {"identifier": 2.5}, "maximum": {"identifier": 4}, "minimum": {"identifier": 1}}}}} +(1 row) + +rollback; diff --git a/test/expected/aggregate_directive.out b/test/expected/aggregate_directive.out new file mode 100644 index 00000000..6f3800d0 --- /dev/null +++ b/test/expected/aggregate_directive.out @@ -0,0 +1,60 @@ +begin; +-- Create a simple table without any directives +create table product( + id serial primary key, + name text not null, + price numeric not null, + stock int not null +); +insert into product(name, price, stock) +values + ('Widget', 9.99, 100), + ('Gadget', 19.99, 50), + ('Gizmo', 29.99, 25); +-- Try to query aggregate without enabling the directive - should fail +select graphql.resolve($$ +{ + productCollection { + aggregate { + count + } + } +} +$$); + resolve +--------------------------------------------------------------------------------------------- + {"data": null, "errors": [{"message": "enable the aggregate directive to use aggregates"}]} +(1 row) + +-- Enable aggregates +comment on table product is e'@graphql({"aggregate": 
{"enabled": true}})'; +-- Now aggregates should be available - should succeed +select graphql.resolve($$ +{ + productCollection { + aggregate { + count + sum { + price + stock + } + avg { + price + } + max { + price + name + } + min { + stock + } + } + } +} +$$); + resolve +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"data": {"productCollection": {"aggregate": {"avg": {"price": 19.99}, "max": {"name": "Widget", "price": 29.99}, "min": {"stock": 25}, "sum": {"price": 59.97, "stock": 175}, "count": 3}}}} +(1 row) + +rollback; diff --git a/test/sql/aggregate.sql b/test/sql/aggregate.sql new file mode 100644 index 00000000..41d07ce5 --- /dev/null +++ b/test/sql/aggregate.sql @@ -0,0 +1,441 @@ +begin; + + create table account( + id serial primary key, + email varchar(255) not null, + created_at timestamp not null + ); + + + create table blog( + id serial primary key, + owner_id integer not null references account(id) on delete cascade, + name varchar(255) not null, + description varchar(255), + created_at timestamp not null + ); + + + create type blog_post_status as enum ('PENDING', 'RELEASED'); + + + create table blog_post( + id uuid not null default gen_random_uuid() primary key, + blog_id integer not null references blog(id) on delete cascade, + title varchar(255) not null, + body varchar(10000), + tags TEXT[], + status blog_post_status not null, + created_at timestamp not null + ); + + + -- 5 Accounts + insert into public.account(email, created_at) + values + ('aardvark@x.com', '2025-04-27 12:00:00'), + ('bat@x.com', '2025-04-28 12:00:00'), + ('cat@x.com', '2025-04-29 12:00:00'), + ('dog@x.com', '2025-04-30 12:00:00'), + ('elephant@x.com', '2025-05-01 12:00:00'); + + insert into blog(owner_id, name, description, created_at) + values + ((select id from account where email ilike 'a%'), 'A: Blog 1', 'a desc1', 
'2025-04-22 12:00:00'), + ((select id from account where email ilike 'a%'), 'A: Blog 2', 'a desc2', '2025-04-23 12:00:00'), + ((select id from account where email ilike 'a%'), 'A: Blog 3', 'a desc3', '2025-04-24 12:00:00'), + ((select id from account where email ilike 'b%'), 'B: Blog 3', 'b desc1', '2025-04-25 12:00:00'); + + insert into blog_post (blog_id, title, body, tags, status, created_at) + values + ((SELECT id FROM blog WHERE name = 'A: Blog 1'), 'Post 1 in A Blog 1', 'Content for post 1 in A Blog 1', '{"tech", "update"}', 'RELEASED', '2025-04-02 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 1'), 'Post 2 in A Blog 1', 'Content for post 2 in A Blog 1', '{"announcement", "tech"}', 'PENDING', '2025-04-07 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 2'), 'Post 1 in A Blog 2', 'Content for post 1 in A Blog 2', '{"personal"}', 'RELEASED', '2025-04-12 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 2'), 'Post 2 in A Blog 2', 'Content for post 2 in A Blog 2', '{"update"}', 'RELEASED', '2025-04-17 12:00:00'), + ((SELECT id FROM blog WHERE name = 'A: Blog 3'), 'Post 1 in A Blog 3', 'Content for post 1 in A Blog 3', '{"travel", "adventure"}', 'PENDING', '2025-04-22 12:00:00'), + ((SELECT id FROM blog WHERE name = 'B: Blog 3'), 'Post 1 in B Blog 3', 'Content for post 1 in B Blog 3', '{"tech", "review"}', 'RELEASED', '2025-04-27 12:00:00'), + ((SELECT id FROM blog WHERE name = 'B: Blog 3'), 'Post 2 in B Blog 3', 'Content for post 2 in B Blog 3', '{"coding", "tutorial"}', 'PENDING', '2025-05-02 12:00:00'); + + comment on table account is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + comment on table blog is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + comment on table blog_post is e'@graphql({"totalCount": {"enabled": true}, "aggregate": {"enabled": true}})'; + + -- Test Case 1: Basic Count on accountCollection + select graphql.resolve($$ + query { + 
accountCollection { + aggregate { + count + } + } + } + $$); + + + -- Test Case 2: Filtered Count on accountCollection + select graphql.resolve($$ + query { + accountCollection(filter: { id: { gt: 3 } }) { + aggregate { + count + } + } + } + $$); + + + -- Test Case 3: Sum, Avg, Min, Max on blogCollection.id + select graphql.resolve($$ + query { + blogCollection { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + + + -- Test Case 4: Aggregates with Filter on blogCollection.id + select graphql.resolve($$ + query { + blogCollection(filter: { ownerId: { lt: 2 } }) { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + + + -- Test Case 5: Aggregates with Pagination on blogCollection (should ignore pagination for aggregates) + select graphql.resolve($$ + query { + blogCollection(first: 1) { + edges { + node { + id + name + } + } + aggregate { + count + sum { + id + } + } + } + } + $$); + + + -- Test Case 7: Aggregates with empty result set on accountCollection + select graphql.resolve($$ + query { + accountCollection(filter: { id: { gt: 1000 } }) { + aggregate { + count + sum { + id + } + avg { + id + } + min { + id + } + max { + id + } + } + } + } + $$); + + -- Test Case 8: Aggregates on table with null values (using blog.description) + -- Count where description is not null + select graphql.resolve($$ + query { + blogCollection(filter: { description: { is: NOT_NULL }}) { + aggregate { + count + } + } + } + $$); + -- Count where description is null + select graphql.resolve($$ + query { + blogCollection(filter: { description: { is: NULL }}) { + aggregate { + count + } + } + } + $$); + + -- Test Case 9: Basic Count on blogPostCollection + select graphql.resolve($$ + query { + blogPostCollection { + aggregate { + count + } + } + } + $$); + + -- Test Case 10: Min/Max on non-numeric fields (string, datetime) + select graphql.resolve($$ + query { + 
blogCollection { + aggregate { + min { + name + description + createdAt + } + max { + name + description + createdAt + } + } + } + } + $$); + + -- Test Case 11: Aggregation with relationships (nested queries) + select graphql.resolve($$ + query { + accountCollection { + edges { + node { + email + blogCollection { + aggregate { + count + sum { + id + } + } + } + } + } + } + } + $$); + + -- Test Case 12: Combination of aggregates in a complex query + select graphql.resolve($$ + query { + blogCollection { + edges { + node { + name + blogPostCollection { + aggregate { + count + min { + createdAt + } + max { + createdAt + } + } + } + } + } + aggregate { + count + min { + id + createdAt + } + max { + id + createdAt + } + sum { + id + } + avg { + id + } + } + } + } + $$); + + -- Test Case 13: Complex filters with aggregates using AND/OR/NOT + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + or: [ + {status: {eq: RELEASED}}, + {title: {startsWith: "Post"}} + ] + } + ) { + aggregate { + count + } + } + } + $$); + + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + and: [ + {status: {eq: PENDING}}, + {not: {blogId: {eq: 4}}} + ] + } + ) { + aggregate { + count + } + } + } + $$); + + -- Test Case 14: Array field aggregation (on tags array) + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + tags: {contains: "tech"} + } + ) { + aggregate { + count + } + } + } + $$); + + -- Test Case 15: UUID field aggregation + -- This test verifies that UUID fields are intentionally excluded from min/max aggregation. + -- UUIDs don't have a meaningful natural ordering for aggregation purposes, so they're explicitly + -- excluded from the list of types that can be aggregated with min/max. 
+ select graphql.resolve($$ + query { + blogPostCollection { + aggregate { + min { + id + } + max { + id + } + } + } + } + $$); + + -- Test Case 16: Edge case - Empty result set with aggregates + select graphql.resolve($$ + query { + blogPostCollection( + filter: { + title: {eq: "This title does not exist"} + } + ) { + aggregate { + count + min { + createdAt + } + max { + createdAt + } + } + } + } + $$); + + -- Test Case 17: Filtering on aggregate results (verify all posts with RELEASED status) + select graphql.resolve($$ + query { + blogPostCollection( + filter: {status: {eq: RELEASED}} + ) { + aggregate { + count + } + } + } + $$); + + -- Test Case 18: Aggregates on filtered relationships + select graphql.resolve($$ + query { + blogCollection { + edges { + node { + name + blogPostCollection( + filter: {status: {eq: RELEASED}} + ) { + aggregate { + count + } + } + } + } + } + } + $$); + + + -- Test Case 19: aliases test case + select graphql.resolve($$ + query { + blogCollection { + agg: aggregate { + cnt: count + total: sum { + identifier: id + } + average: avg { + identifier: id + } + minimum: min { + identifier: id + } + maximum: max { + identifier: id + } + } + } + } + $$); + +rollback; diff --git a/test/sql/aggregate_directive.sql b/test/sql/aggregate_directive.sql new file mode 100644 index 00000000..1cd87d42 --- /dev/null +++ b/test/sql/aggregate_directive.sql @@ -0,0 +1,56 @@ +begin; + +-- Create a simple table without any directives +create table product( + id serial primary key, + name text not null, + price numeric not null, + stock int not null +); + +insert into product(name, price, stock) +values + ('Widget', 9.99, 100), + ('Gadget', 19.99, 50), + ('Gizmo', 29.99, 25); + +-- Try to query aggregate without enabling the directive - should fail +select graphql.resolve($$ +{ + productCollection { + aggregate { + count + } + } +} +$$); + +-- Enable aggregates +comment on table product is e'@graphql({"aggregate": {"enabled": true}})'; + +-- Now aggregates 
should be available - should succeed +select graphql.resolve($$ +{ + productCollection { + aggregate { + count + sum { + price + stock + } + avg { + price + } + max { + price + name + } + min { + stock + } + } + } +} +$$); + +rollback;