Skip to content

Supports regular fuzzing for all aggregate functions #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/sqlancer/datafusion/DataFusionErrors.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("regex parse error");
errors.add("Invalid string operation: List"); // select [1,2] like null;
errors.add("Unsupported CAST from List"); // not sure
errors.add("This feature is not implemented: Support for 'approx_distinct' for data type");
errors.add("MedianAccumulator not supported for median");
errors.add("Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal");
errors.add("digest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal ");

/*
* Known bugs
Expand All @@ -44,6 +48,9 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("Sort expressions cannot be empty for streaming merge."); // https://github.com/apache/datafusion/issues/11561
errors.add("compute_utf8_flag_op_scalar failed to cast literal value NULL for operation"); // https://github.com/apache/datafusion/issues/11623
errors.add("Schema error: No field named"); // https://github.com/apache/datafusion/issues/11635
errors.add("Min/Max accumulator not implemented for type Null."); // https://github.com/apache/datafusion/issues/11749
errors.add("APPROX_PERCENTILE_CONT_WITH_WEIGHT"); // TODO issue
errors.add("APPROX_MEDIAN"); // TODO issue

/*
* False positives
Expand All @@ -53,9 +60,10 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
// is generated in where
// clause
/*
* Not critical, report later
* Not critical, investigate in the future
*/
errors.add("does not match with the projection expression");
errors.add("invalid operator for nested");
errors.add("Arrow error: Cast error: Can't cast value");
}
}
11 changes: 10 additions & 1 deletion src/sqlancer/datafusion/DataFusionOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sqlancer.common.oracle.TestOracle;
import sqlancer.datafusion.DataFusionOptions.DataFusionOracleFactory;
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
import sqlancer.datafusion.test.DataFusionNoCrashAggregate;
import sqlancer.datafusion.test.DataFusionNoRECOracle;
import sqlancer.datafusion.test.DataFusionQueryPartitioningAggrTester;
import sqlancer.datafusion.test.DataFusionQueryPartitioningHavingTester;
Expand All @@ -24,7 +25,9 @@ public class DataFusionOptions implements DBMSSpecificOptions<DataFusionOracleFa

@Override
public List<DataFusionOracleFactory> getTestOracleFactory() {
return Arrays.asList(DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE
return Arrays.asList(
// DataFusionOracleFactory.NO_CRASH_AGGREGATE
DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE
/* DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE */
/* , DataFusionOracleFactory.QUERY_PARTITIONING_HAVING */);
}
Expand Down Expand Up @@ -53,6 +56,12 @@ public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalStat
public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalState) throws SQLException {
return new DataFusionQueryPartitioningAggrTester(globalState);
}
},
NO_CRASH_AGGREGATE {
@Override
public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalState) throws SQLException {
return new DataFusionNoCrashAggregate(globalState);
}
}
}

Expand Down
37 changes: 37 additions & 0 deletions src/sqlancer/datafusion/ast/DataFusionSelect.java
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,43 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
return randomSelect;
}

// Randomly generate groupby/aggregates, and update fetch columns
// e.g.
// select v1, sum(v2)
// ...
// group by v1
//
// This method assume `DataFusionSelect` is propoerly initialized with `getRandomSelect()`
public void setAggregates(DataFusionGlobalState state) {
// group by exprs (e.g. group by v1, abs(v2))
List<Node<DataFusionExpression>> groupByExprs = new ArrayList<>();
int nGroupBy = state.getRandomly().getInteger(0, 3);
if (Randomly.getBoolean()) {
// Generate expressions like (v1+1, v2 *2)
groupByExprs = this.exprGenGroupBy.generateExpressions(nGroupBy);
} else {
// Generate simple column references like v1, v2
groupByExprs = this.exprGenGroupBy.generateColumns(nGroupBy);
}

// Generate aggregates like SUM(v1), MAX(V2)
this.exprGenAggregate.supportAggregate = true;
List<Node<DataFusionExpression>> aggrExprs = this.exprGenAggregate
.generateExpressions(state.getRandomly().getInteger(0, 3));
this.exprGenAggregate.supportAggregate = false;

// If it's empty, then no group by expr
if (!groupByExprs.isEmpty()) {
this.setGroupByClause(groupByExprs);

List<Node<DataFusionExpression>> fetchCols = new ArrayList<>();
fetchCols.addAll(groupByExprs);
fetchCols.addAll(aggrExprs);
fetchCols = Randomly.nonEmptySubset(fetchCols);
this.setFetchColumns(fetchCols);
}
}

/*
* If set fetch columns with string It will override `fetchColumns` in base class when
* `DataFusionToStringVisitor.asString()` is called
Expand Down
42 changes: 40 additions & 2 deletions src/sqlancer/datafusion/gen/DataFusionBaseExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ public static DataFusionBaseExpr createCommonNumericAggrFuncSingleArg(String nam
new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)))));
}

public static DataFusionBaseExpr createCommonNumericAggrFuncTwoArg(String name) {
return new DataFusionBaseExpr(name, 2, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE),
Arrays.asList(
new ArgumentType.Fixed(
new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))),
new ArgumentType.SameAsFirstArgType()));
}

public static DataFusionBaseExpr createCommonNumericFuncTwoArgs(String name) {
return new DataFusionBaseExpr(name, 2, DataFusionBaseExprCategory.FUNC,
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE),
Expand Down Expand Up @@ -303,8 +312,37 @@ public enum DataFusionBaseExprType {

// Other Functions

// Aggregate Functions
AGGR_MIN, AGGR_MAX, AGGR_SUM, AGGR_AVG, AGGR_COUNT,
// Aggregate Functions (General)
AGGR_MIN, AGGR_MAX, AGGR_SUM, AGGR_AVG, AGGR_COUNT, BIT_AND, BIT_OR, BIT_XOR, BOOL_AND, BOOL_OR, MEAN, MEDIAN,
FIRST_VALUE, LAST_VALUE,
// Aggregate Functiosn (Statistical)
CORR, // corr(v1, v2)
COVAR, // covar(v1, v2)
COVAR_POP, // covar_pop(v1, v2)
COVAR_SAMP, // covar_samp(v1, v2)
STDDEV, // stddev(v)
STDDEV_POP, // stddev_pop(v)
STDDEV_SAMP, // stddev_samp(v)
VAR, // var(v)
VAR_POP, // var_pop(v)
VAR_SAMP, // var_samp(v)
REGR_AVGX, // regr_avgx(y, x)
REGR_AVGY, // regr_avgy(y, x)
REGR_COUNT, // regr_count(y, x)
REGR_INTERCEPT, // regr_intercept(y, x)
REGR_R2, // regr_r2(y, x)
REGR_SLOPE, // regr_slope(y, x)
REGR_SXX, // regr_sxx(x)
REGR_SYY, // regr_syy(y)
REGR_SXY, // regr_sxy(x, y)
// Aggregate Functions (Approximate)
APPROX_DISTINCT, // approx_distinct(expression)
APPROX_MEDIAN, // approx_median(expression)
APPROX_PERCENTILE_CONT, // approx_percentile_cont(expression, percentile)
APPROX_PERCENTILE_CONT2, // approx_percentile_cont(expression, percentile, centroids)
APPROX_PERCENTILE_CONT_WITH_WEIGHT // approx_percentile_cont_with_weight(expression, weight, percentile)

// Array Aggregate functions
}

/*
Expand Down
91 changes: 91 additions & 0 deletions src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,97 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) {
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN,
DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)))),
true);
case BIT_AND:
return new DataFusionBaseExpr("BIT_AND", 1, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BIGINT),
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT)))),
false);
case BIT_OR:
return new DataFusionBaseExpr("BIT_OR", 1, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BIGINT),
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT)))),
false);
case BIT_XOR:
return new DataFusionBaseExpr("BIT_XOR", 1, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BIGINT),
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT)))),
false);
case BOOL_AND:
return new DataFusionBaseExpr("BOOL_AND", 1, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BOOLEAN),
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN)))),
false);
case BOOL_OR:
return new DataFusionBaseExpr("BOOL_OR", 1, DataFusionBaseExprCategory.AGGREGATE,
Arrays.asList(DataFusionDataType.BOOLEAN),
Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN)))),
false);
case MEAN:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("MEAN");
case MEDIAN:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("MEDIAN");
case FIRST_VALUE:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("FIRST_VALUE");
case LAST_VALUE:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("LAST_VALUE");
case CORR:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("CORR");
case COVAR:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("COVAR");
case COVAR_POP:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("COVAR_POP");
case COVAR_SAMP:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("COVAR_SAMP");
case STDDEV:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("STDDEV");
case STDDEV_POP:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("STDDEV_POP");
case STDDEV_SAMP:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("STDDEV_SAMP");
case VAR:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("VAR");
case VAR_POP:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("VAR_POP");
case VAR_SAMP:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("VAR_SAMP");
case REGR_AVGX:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_AVGX");
case REGR_AVGY:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_AVGY");
case REGR_COUNT:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_COUNT");
case REGR_INTERCEPT:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_INTERCEPT");
case REGR_R2:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_R2");
case REGR_SLOPE:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_SLOPE");
case REGR_SXX:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_SXX");
case REGR_SYY:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_SYY");
case REGR_SXY:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("REGR_SXY");
case APPROX_DISTINCT:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("APPROX_DISTINCT");
case APPROX_MEDIAN:
return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("APPROX_MEDIAN");
case APPROX_PERCENTILE_CONT:
return DataFusionBaseExpr.createCommonNumericAggrFuncTwoArg("APPROX_PERCENTILE_CONT");
case APPROX_PERCENTILE_CONT2:
return new DataFusionBaseExpr("APPROX_PERCENTILE_CONT", 3, DataFusionBaseExprCategory.FUNC,
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE),
Arrays.asList(
new ArgumentType.Fixed(new ArrayList<>(
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))),
new ArgumentType.SameAsFirstArgType(), new ArgumentType.SameAsFirstArgType()));
case APPROX_PERCENTILE_CONT_WITH_WEIGHT:
return new DataFusionBaseExpr("APPROX_PERCENTILE_CONT_WITH_WEIGHT", 3, DataFusionBaseExprCategory.FUNC,
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE),
Arrays.asList(
new ArgumentType.Fixed(new ArrayList<>(
Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))),
new ArgumentType.SameAsFirstArgType(), new ArgumentType.SameAsFirstArgType()));
default:
dfAssert(false, "Unreachable. Unimplemented branch for type " + type);
}
Expand Down
16 changes: 16 additions & 0 deletions src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,22 @@ List<DataFusionColumn> filterColumns(DataFusionDataType type) {
}
}

// Duplicate column is possible (e.g. v1, v2, v1)
public List<Node<DataFusionExpression>> generateColumns(int nr) {
List<Node<DataFusionExpression>> cols = new ArrayList<>();
for (int i = 0; i < nr; i++) {
if (columns.isEmpty()) {
cols.add(generateColumn(getRandomType()));
} else {
DataFusionColumn col = Randomly.fromList(columns);
Node<DataFusionExpression> colExpr = new ColumnReferenceNode<>(col);
cols.add(colExpr);
}
}

return cols;
}

@Override
public List<Node<DataFusionExpression>> generateOrderBys() {
List<Node<DataFusionExpression>> expr = super.generateOrderBys();
Expand Down
2 changes: 1 addition & 1 deletion src/sqlancer/datafusion/gen/DataFusionTableGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) {
sb.append(tableName);
sb.append("(");

int colCount = Randomly.smallNumber() + 1 + (Randomly.getBoolean() ? 1 : 0);
int colCount = (int) Randomly.getNotCachedInteger(1, 8);
for (int i = 0; i < colCount; i++) {
sb.append("v").append(i).append(" ").append(DataFusionDataType.getRandomWithoutNull().toString());

Expand Down
4 changes: 2 additions & 2 deletions src/sqlancer/datafusion/server/datafusion_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ bytes = "1.4"
chrono = { version = "0.4.34", default-features = false }
dashmap = "5.5.0"
# This version is for SQLancer CI run (disabled temporary for multiple newly fixed bugs)
# datafusion = { version = "40.0.0" }
datafusion = { version = "41.0.0" }
# Use following line if you want to test against the latest main branch of DataFusion
datafusion = { git = "https://github.com/apache/datafusion.git", branch = "main" }
# datafusion = { git = "https://github.com/apache/datafusion.git", branch = "main" }
env_logger = "0.11"
futures = "0.3"
half = { version = "2.2.1", default-features = false }
Expand Down
79 changes: 79 additions & 0 deletions src/sqlancer/datafusion/test/DataFusionNoCrashAggregate.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package sqlancer.datafusion.test;

import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR;
import static sqlancer.datafusion.ast.DataFusionSelect.getRandomSelect;
import static sqlancer.datafusion.gen.DataFusionExpressionGenerator.generateHavingClause;

import java.sql.SQLException;

import sqlancer.ComparatorHelper;
import sqlancer.Randomly;
import sqlancer.common.ast.newast.Node;
import sqlancer.common.oracle.NoRECBase;
import sqlancer.common.oracle.TestOracle;
import sqlancer.datafusion.DataFusionErrors;
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
import sqlancer.datafusion.DataFusionToStringVisitor;
import sqlancer.datafusion.DataFusionUtil;
import sqlancer.datafusion.ast.DataFusionExpression;
import sqlancer.datafusion.ast.DataFusionSelect;
import sqlancer.datafusion.gen.DataFusionExpressionGenerator;

// Simply test no crash bug for generated queries.
// No extra oracle checks.
public class DataFusionNoCrashAggregate extends NoRECBase<DataFusionGlobalState>
implements TestOracle<DataFusionGlobalState> {

private final DataFusionGlobalState state;

public DataFusionNoCrashAggregate(DataFusionGlobalState globalState) {
super(globalState);
this.state = globalState;
DataFusionErrors.registerExpectedExecutionErrors(errors);
}

// Randomly generate a aggregate query.
// And make sure it won't crash DataFusion engine
@Override
public void check() throws SQLException {
DataFusionSelect randomSelect = getRandomSelect(state);
DataFusionExpressionGenerator gen = randomSelect.exprGenAll;

if (Randomly.getBoolean()) {
if (Randomly.getBoolean()) {
randomSelect.distinct = true;
}

if (Randomly.getBoolean()) {
randomSelect.setOrderByClauses(gen.generateOrderBys());
}
}

if (Randomly.getBoolean()) {
randomSelect.setWhereClause(gen.generatePredicate());
}

// generate {group_by_cols, aggrs}
if (!Randomly.getBooleanWithSmallProbability()) {
randomSelect.setAggregates(state);
}

if (Randomly.getBoolean()) {
Node<DataFusionExpression> havingPredicate = generateHavingClause(randomSelect.exprGenGroupBy,
randomSelect.exprGenAggregate);
randomSelect.setHavingClause(havingPredicate);
}

String qString = DataFusionToStringVisitor.asString(randomSelect);
try {
ComparatorHelper.getResultSetFirstColumnAsString(qString, errors, state);
} catch (AssertionError e) {
// Append detailed error message
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n";
state.dfLogger.appendToLog(ERROR, newMessage);

throw new AssertionError(newMessage);
}
}
}
Loading
Loading