diff --git a/src/sqlancer/common/ast/newast/NewOrderingTerm.java b/src/sqlancer/common/ast/newast/NewOrderingTerm.java index 5afe7182..dbbcc287 100644 --- a/src/sqlancer/common/ast/newast/NewOrderingTerm.java +++ b/src/sqlancer/common/ast/newast/NewOrderingTerm.java @@ -1,11 +1,14 @@ package sqlancer.common.ast.newast; +import java.util.Optional; + import sqlancer.Randomly; public class NewOrderingTerm implements Node { private final Node expr; private final Ordering ordering; + private final Optional orderingNullsOptional; public enum Ordering { ASC, DESC; @@ -15,9 +18,36 @@ public static Ordering getRandom() { } } + public enum OrderingNulls { + NULLS_FIRST, NULLS_LAST; + + public static OrderingNulls getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String toString() { + switch (this) { + case NULLS_FIRST: + return "NULLS FIRST"; + case NULLS_LAST: + return "NULLS LAST"; + default: + throw new AssertionError("Unreachable"); + } + } + } + public NewOrderingTerm(Node expr, Ordering ordering) { this.expr = expr; this.ordering = ordering; + this.orderingNullsOptional = Optional.empty(); + } + + public NewOrderingTerm(Node expr, Ordering ordering, OrderingNulls orderingNulls) { + this.expr = expr; + this.ordering = ordering; + this.orderingNullsOptional = Optional.of(orderingNulls); } public Node getExpr() { @@ -28,4 +58,7 @@ public Ordering getOrdering() { return ordering; } + public Optional getOrderingNullsOptional() { + return orderingNullsOptional; + } } diff --git a/src/sqlancer/common/ast/newast/NewToStringVisitor.java b/src/sqlancer/common/ast/newast/NewToStringVisitor.java index 3ad62544..a1c78c67 100644 --- a/src/sqlancer/common/ast/newast/NewToStringVisitor.java +++ b/src/sqlancer/common/ast/newast/NewToStringVisitor.java @@ -68,6 +68,10 @@ public void visit(NewOrderingTerm ordering) { visit(ordering.getExpr()); sb.append(" "); sb.append(ordering.getOrdering()); + if (ordering.getOrderingNullsOptional().isPresent()) { + sb.append(" "); + sb.append(ordering.getOrderingNullsOptional().get()); + } } public void visit(NewCaseOperatorNode op) { diff --git a/src/sqlancer/datafusion/DataFusionErrors.java b/src/sqlancer/datafusion/DataFusionErrors.java index d294b5cf..80854462 100644 --- a/src/sqlancer/datafusion/DataFusionErrors.java +++ b/src/sqlancer/datafusion/DataFusionErrors.java @@ -2,8 +2,6 @@ import static sqlancer.datafusion.DataFusionUtil.dfAssert; -import java.util.regex.Pattern; - import sqlancer.common.query.ExpectedErrors; public final class DataFusionErrors { @@ -32,17 +30,21 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) { errors.add("Divide by zero"); errors.add("Sort requires at least one column"); errors.add("The data type type Null has no natural order"); + errors.add("Regular expression did not compile"); + errors.add("Cannot cast value"); + errors.add("regex parse error"); + errors.add("Invalid string operation: List"); // select [1,2] like null; + errors.add("Unsupported CAST from List"); // not sure + /* * Known bugs */ - errors.add("to type Int64"); // https://github.com/apache/datafusion/issues/11249 + errors.add("to type Int"); // https://github.com/apache/datafusion/issues/11249 errors.add("bitwise"); // https://github.com/apache/datafusion/issues/11260 - errors.add(" Not all InterleaveExec children have a consistent hash partitioning."); // https://github.com/apache/datafusion/issues/11409 - Pattern pattern = Pattern.compile("ORDER BY.*LOG", Pattern.CASE_INSENSITIVE); - errors.addRegex(pattern); // https://github.com/apache/datafusion/issues/11549 - Pattern patternTriaFunc = Pattern.compile("ORDER BY.*\\b(ACOS|ACOSH|ASIN|ATANH)\\b", Pattern.CASE_INSENSITIVE); - errors.addRegex(patternTriaFunc); // https://github.com/apache/datafusion/issues/11552 errors.add("Sort expressions cannot be empty for streaming merge."); // https://github.com/apache/datafusion/issues/11561 + errors.add("compute_utf8_flag_op_scalar failed to cast literal value NULL for operation"); // https://github.com/apache/datafusion/issues/11623 + errors.add("Schema error: No field named"); // https://github.com/apache/datafusion/issues/11635 + /* * False positives */ @@ -50,5 +52,10 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) { errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr // is generated in where // clause + /* + * Not critical, report later + */ + errors.add("does not match with the projection expression"); + errors.add("invalid operator for nested"); } } diff --git a/src/sqlancer/datafusion/DataFusionOptions.java b/src/sqlancer/datafusion/DataFusionOptions.java index fcb0221a..a533843a 100644 --- a/src/sqlancer/datafusion/DataFusionOptions.java +++ b/src/sqlancer/datafusion/DataFusionOptions.java @@ -13,6 +13,8 @@ import sqlancer.datafusion.DataFusionOptions.DataFusionOracleFactory; import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; import sqlancer.datafusion.test.DataFusionNoRECOracle; +import sqlancer.datafusion.test.DataFusionQueryPartitioningAggrTester; +import sqlancer.datafusion.test.DataFusionQueryPartitioningHavingTester; import sqlancer.datafusion.test.DataFusionQueryPartitioningWhereTester; @Parameters(commandDescription = "DataFusion") @@ -22,7 +24,9 @@ public class DataFusionOptions implements DBMSSpecificOptions getTestOracleFactory() { - return Arrays.asList(DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE); + return Arrays.asList(DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE + /* DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE */ + /* , DataFusionOracleFactory.QUERY_PARTITIONING_HAVING */); } public enum DataFusionOracleFactory implements OracleFactory { @@ -37,6 +41,18 @@ public TestOracle create(DataFusionGlobalState globalStat public TestOracle create(DataFusionGlobalState globalState) throws SQLException { return new DataFusionQueryPartitioningWhereTester(globalState); } + }, + QUERY_PARTITIONING_HAVING { + @Override + public TestOracle create(DataFusionGlobalState globalState) throws SQLException { + return new DataFusionQueryPartitioningHavingTester(globalState); + } + }, + QUERY_PARTITIONING_AGGREGATE { + @Override + public TestOracle create(DataFusionGlobalState globalState) throws SQLException { + return new DataFusionQueryPartitioningAggrTester(globalState); + } } } diff --git a/src/sqlancer/datafusion/DataFusionSchema.java b/src/sqlancer/datafusion/DataFusionSchema.java index 4a5ed803..24c3a4eb 100644 --- a/src/sqlancer/datafusion/DataFusionSchema.java +++ b/src/sqlancer/datafusion/DataFusionSchema.java @@ -85,14 +85,14 @@ private static List getTableColumns(SQLConnection con, String /* * When adding a new type: 1. Update all methods inside this enum 2. Update all `DataFusionBaseExpr`'s signature, if - * it can support new type (in `DataFusionBaseExprFactory.java` + * it can support new type (in `DataFusionBaseExprFactory.java`) * * Types are 'SQL DataType' in DataFusion's documentation * https://datafusion.apache.org/user-guide/sql/data_types.html */ public enum DataFusionDataType { - BIGINT, DOUBLE, BOOLEAN, NULL; + STRING, BIGINT, DOUBLE, BOOLEAN, NULL; public static DataFusionDataType getRandomWithoutNull() { DataFusionDataType dt; @@ -102,6 +102,10 @@ public static DataFusionDataType getRandomWithoutNull() { return dt; } + public boolean isNumeric() { + return this == BIGINT || this == DOUBLE; + } + // How to parse type in DataFusion's catalog to `DataFusionDataType` // As displayed in: // create table t1(v1 int, v2 bigint); @@ -114,6 +118,8 @@ public static DataFusionDataType parseFromDataFusionCatalog(String typeString) { return DataFusionDataType.DOUBLE; case "Boolean": return DataFusionDataType.BOOLEAN; + case "Utf8": + return DataFusionDataType.STRING; default: dfAssert(false, "Unreachable. All branches should be eovered"); } @@ -129,7 +135,9 @@ public Node getRandomConstant(DataFusionGlobalState state) } switch (this) { case BIGINT: - return DataFusionConstant.createIntConstant(state.getRandomly().getInteger()); + long randInt = Randomly.getBoolean() ? state.getRandomly().getInteger() + : state.getRandomly().getInteger(-5, 5); + return DataFusionConstant.createIntConstant(randInt); case BOOLEAN: return new DataFusionConstant.DataFusionBooleanConstant(Randomly.getBoolean()); case DOUBLE: @@ -147,6 +155,8 @@ public Node getRandomConstant(DataFusionGlobalState state) return new DataFusionConstant.DataFusionDoubleConstant(state.getRandomly().getDouble()); case NULL: return DataFusionConstant.createNullConstant(); + case STRING: + return new DataFusionConstant.DataFusionStringConstant(state.getRandomly().getString()); default: dfAssert(false, "Unreachable. All branches should be eovered"); } diff --git a/src/sqlancer/datafusion/DataFusionToStringVisitor.java b/src/sqlancer/datafusion/DataFusionToStringVisitor.java index 1f0276d4..8b34c74e 100644 --- a/src/sqlancer/datafusion/DataFusionToStringVisitor.java +++ b/src/sqlancer/datafusion/DataFusionToStringVisitor.java @@ -95,6 +95,12 @@ private void visit(DataFusionConstant constant) { private void visit(DataFusionSelect select) { sb.append("SELECT "); + if (select.all && !select.distinct) { + sb.append("ALL "); + } + if (select.distinct) { + sb.append("DISTINCT "); + } if (select.fetchColumnsString.isPresent()) { sb.append(select.fetchColumnsString.get()); } else { diff --git a/src/sqlancer/datafusion/DataFusionUtil.java b/src/sqlancer/datafusion/DataFusionUtil.java index 8761bec9..602e5abf 100644 --- a/src/sqlancer/datafusion/DataFusionUtil.java +++ b/src/sqlancer/datafusion/DataFusionUtil.java @@ -1,5 +1,7 @@ package sqlancer.datafusion; +import static java.lang.System.exit; + import java.io.BufferedReader; import java.io.File; import java.io.FileReader; @@ -67,12 +69,12 @@ public static String displayTables(DataFusionGlobalState state, List fro // During development, you might want to manually let this function call exit(1) to fail fast public static void dfAssert(boolean condition, String message) { if (!condition) { - // // Development mode assertion failure - // String methodName = Thread.currentThread().getStackTrace()[2]// .getMethodName(); - // System.err.println("DataFusion assertion failed in function '" + methodName + "': " + message); - // exit(1); + // Development mode assertion failure + String methodName = Thread.currentThread().getStackTrace()[2].getMethodName(); + System.err.println("DataFusion assertion failed in function '" + methodName + "': " + message); + exit(1); - throw new AssertionError(message); + // throw new AssertionError(message); } } @@ -187,4 +189,25 @@ public enum DataFusionLogType { ERROR, DML, SELECT } } + + // Only used in TLP-Having + public static String cleanResultSetString(String value) { + if (value == null) { + return value; + } + + switch (value) { + case "-0.0": + return "0.0"; + case "-0": + return "0"; + default: + } + + if (value.getBytes().length > 7) { + return new String(value.getBytes(), 0, 7); + } + + return value; + } } diff --git a/src/sqlancer/datafusion/ast/DataFusionConstant.java b/src/sqlancer/datafusion/ast/DataFusionConstant.java index d123adb8..0b084db2 100644 --- a/src/sqlancer/datafusion/ast/DataFusionConstant.java +++ b/src/sqlancer/datafusion/ast/DataFusionConstant.java @@ -96,4 +96,42 @@ public String toString() { } + public static class DataFusionStringConstant extends DataFusionConstant { + private final String value; + + public static String cleanString(String input) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + // Check if the character is a high surrogate + if (Character.isHighSurrogate(c)) { + if (i + 1 < input.length() && Character.isLowSurrogate(input.charAt(i + 1))) { + // It's a valid surrogate pair, add both to the string + sb.append(c); + sb.append(input.charAt(i + 1)); + i++; // Skip the next character as it's part of the surrogate pair + } + } else if (!Character.isLowSurrogate(c) && !Character.isSurrogate(c)) { + // Add only if it's not a low surrogate or any standalone surrogate + sb.append(c); + } + } + return sb.toString(); + } + + public DataFusionStringConstant(String value) { + // cleanup invalid Utf8 + this.value = cleanString(value.replace("'", "''")); + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value + "'"; + } + + } } diff --git a/src/sqlancer/datafusion/ast/DataFusionExpression.java b/src/sqlancer/datafusion/ast/DataFusionExpression.java index 6da95cd9..db74cb2d 100644 --- a/src/sqlancer/datafusion/ast/DataFusionExpression.java +++ b/src/sqlancer/datafusion/ast/DataFusionExpression.java @@ -1,5 +1,4 @@ package sqlancer.datafusion.ast; public interface DataFusionExpression { - } diff --git a/src/sqlancer/datafusion/ast/DataFusionSelect.java b/src/sqlancer/datafusion/ast/DataFusionSelect.java index 817f17eb..29805bee 100644 --- a/src/sqlancer/datafusion/ast/DataFusionSelect.java +++ b/src/sqlancer/datafusion/ast/DataFusionSelect.java @@ -19,13 +19,27 @@ import sqlancer.datafusion.gen.DataFusionExpressionGenerator; public class DataFusionSelect extends SelectBase> implements Node { + public boolean all; // SELECT ALL + public boolean distinct; // SELECT DISTINCT public Optional fetchColumnsString = Optional.empty(); // When available, override `fetchColumns` in base // class's `Node` representation (for display) // `from` is used to represent from table list and join clause // `fromList` and `joinList` in base class should always be empty public DataFusionFrom from; - public DataFusionExpressionGenerator exprGen; + // e.g. let's say all colummns are {c1, c2, c3, c4, c5} + // First randomly pick a subset say {c2, c1, c3, c4} + // `exprGenAll` can generate random expr using above 4 columns + // + // Next, randomly take two non-overlapping subset from all columns used by `exprGenAll` + // exprGenGroupBy: {c1} (randomly generate group by exprs using c1 only) + // exprGenAggregate: {c3, c4} + // + // Finally, use all `Gen`s to generate different clauses in a query (`exprGenAll` in where clause, `exprGenGroupBy` + // in group by clause, etc.) + public DataFusionExpressionGenerator exprGenAll; + public DataFusionExpressionGenerator exprGenGroupBy; + public DataFusionExpressionGenerator exprGenAggregate; public enum JoinType { INNER, LEFT, RIGHT, FULL, CROSS, NATURAL @@ -145,6 +159,9 @@ public static DataFusionFrom generateFromClause(DataFusionGlobalState state, // - [expr_aggr_cols] SUM(t3.v1 + t2.v1) public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { DataFusionSelect randomSelect = new DataFusionSelect(); + if (Randomly.getBooleanWithRatherLowProbability()) { + randomSelect.all = true; + } /* Setup FROM clause */ DataFusionSchema schema = state.getSchema(); // schema of all tables @@ -156,14 +173,24 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) { } DataFusionFrom randomFrom = DataFusionFrom.generateFromClause(state, randomTables); + /* Setup expression generators (to generate different clauses) */ + List randomColumnsAll = DataFusionTable.getRandomColumns(randomTables); + // 0 <= splitPoint1 <= splitPoint2 < randomColumnsALl.size() + int splitPoint1 = state.getRandomly().getInteger(0, randomColumnsAll.size()); + int splitPoint2 = state.getRandomly().getInteger(splitPoint1, randomColumnsAll.size()); + + randomSelect.exprGenAll = new DataFusionExpressionGenerator(state).setColumns(randomColumnsAll); + randomSelect.exprGenGroupBy = new DataFusionExpressionGenerator(state) + .setColumns(randomColumnsAll.subList(0, splitPoint1)); + randomSelect.exprGenAggregate = new DataFusionExpressionGenerator(state) + .setColumns(randomColumnsAll.subList(splitPoint1, splitPoint2)); + /* Setup WHERE clause */ - List randomColumns = DataFusionTable.getRandomColumns(randomTables); - randomSelect.exprGen = new DataFusionExpressionGenerator(state).setColumns(randomColumns); - Node whereExpr = randomSelect.exprGen + Node whereExpr = randomSelect.exprGenAll .generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN); /* Constructing result */ - List> randomColumnNodes = randomColumns.stream() + List> randomColumnNodes = randomColumnsAll.stream() .map((c) -> new ColumnReferenceNode(c)) .collect(Collectors.toList()); diff --git a/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java b/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java index 0be57486..67534ca5 100644 --- a/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java +++ b/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java @@ -51,6 +51,26 @@ public static DataFusionBaseExpr createCommonNumericFuncSingleArg(String name) { new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); } + public static DataFusionBaseExpr createCommonStringOperatorTwoArgs(String name) { + return new DataFusionBaseExpr(name, 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.SameAsFirstArgType())); + } + + public static DataFusionBaseExpr createCommonStringFuncOneStringArg(String name, + List returnTypeList) { + return new DataFusionBaseExpr(name, 1, DataFusionBaseExprCategory.FUNC, returnTypeList, + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + } + + public static DataFusionBaseExpr createCommonStringFuncTwoStringArg(String name, + List returnTypeList) { + return new DataFusionBaseExpr(name, 2, DataFusionBaseExprCategory.FUNC, returnTypeList, + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.SameAsFirstArgType())); + } + public static DataFusionBaseExpr createCommonNumericAggrFuncSingleArg(String name) { return new DataFusionBaseExpr(name, 1, DataFusionBaseExprCategory.AGGREGATE, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), @@ -115,15 +135,21 @@ public enum DataFusionBaseExprType { IS_DISTINCT_FROM, // 0 IS DISTINCT FROM NULL IS_NOT_DISTINCT_FROM, // NULL IS NOT DISTINCT FROM NULL - /* - * // Regular expression match operators REGEX_MATCH, // 'datafusion' ~ '^datafusion(-cli)*' - * REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' ~* '^DATAFUSION(-cli)*' NOT_REGEX_MATCH, // 'datafusion' !~ - * '^DATAFUSION(-cli)*' NOT_REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' !~* '^DATAFUSION(-cli)+' - * - * // Like pattern match operators LIKE_MATCH, // 'datafusion' ~~ 'dat_f%n' CASE_INSENSITIVE_LIKE_MATCH, // - * 'datafusion' ~~* 'Dat_F%n' NOT_LIKE_MATCH, // 'datafusion' !~~ 'Dat_F%n' NOT_CASE_INSENSITIVE_LIKE_MATCH // - * 'datafusion' !~~* 'Dat%F_n' - */ + // Pattern matching expressions + LIKE, // 'foo' like 'fo' + NOT_LIKE, ILIKE, NOT_ILIKE, + + // Regular expression match operators + REGEX_MATCH, // 'datafusion' ~ '^datafusion(-cli)*' + REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' ~* '^DATAFUSION(-cli)*' + NOT_REGEX_MATCH, // 'datafusion' !~ '^DATAFUSION(-cli)*' + NOT_REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' !~* '^DATAFUSION(-cli)+' + + // Like pattern match operators + LIKE_MATCH, // 'datafusion' ~~ 'dat_f%n' + CASE_INSENSITIVE_LIKE_MATCH, // 'datafusion' ~~* 'Dat_F%n' + NOT_LIKE_MATCH, // 'datafusion' !~~ 'Dat_F%n' + NOT_CASE_INSENSITIVE_LIKE_MATCH, // 'datafusion' !~~* 'Dat%F_n' // Logical Operators AND, // true and true @@ -136,10 +162,10 @@ public enum DataFusionBaseExprType { BITWISE_SHIFT_RIGHT, // 5 >> 3 BITWISE_SHIFT_LEFT, // 5 << 3 - /* - * // Other operators STRING_CONCATENATION, // 'Hello, ' || 'DataFusion!' ARRAY_CONTAINS, // - * make_array(1,2,3) @> make_array(1,3) ARRAY_IS_CONTAINED_BY // make_array(1,3) <@ make_array(1,2,3) - */ + // Other operators + STRING_CONCATENATION, // 'Hello, ' || 'DataFusion!' + // ARRAY_CONTAINS, // make_array(1,2,3) @> make_array(1,3) + // ARRAY_IS_CONTAINED_BY, // make_array(1,3) <@ make_array(1,2,3) // Unary Prefix Operators NOT, // NOT true @@ -201,6 +227,71 @@ public enum DataFusionBaseExprType { FUNC_IFNULL, // ifnull(NULL, 'default value') // String Functions + // TODO(datafusion) lpad('foo', 1e100) takes forever, we can let server stop a query if its been running for too + // long + // String Functions - return numeric + FUNC_ASCII, // ascii('string') + FUNC_LENGTH, // length('string') + FUNC_CHAR_LENGTH, // char_length('string') + FUNC_CHARACTER_LENGTH, // character_length('string') + FUNC_BIT_LENGTH, // bit_length('string') + FUNC_CHR, // chr(code) + FUNC_INSTR, // instr('string', 'substring') + FUNC_STRPOS, // strpos('string', 'substring') + FUNC_LEVENSHTEIN, // levenshtein('string1', 'string2') + FUNC_FIND_IN_SET, // find_in_set('b', 'a,b,c,d') + // String Functions - return String + FUNC_INITCAP, // initcap('string') + FUNC_LOWER, // lower('string') + FUNC_UPPER, // upper('string') + FUNC_OCTET_LENGTH, // octet_length('string') + FUNC_BTRIM, // btrim(' string ') + FUNC_BTRIM2, // btrim('--string--', '-') + FUNC_TRIM, // trim('string') + FUNC_TRIM2, // trim('string', 'trim_chars') + FUNC_LTRIM, // ltrim(' string ') + FUNC_LTRIM2, // ltrim('--string-', '-') + FUNC_RTRIM, // rtrim('string ') + FUNC_RTRIM2, // rtrim('-string--', '-') + FUNC_LEFT, // left('string', n) + FUNC_RIGHT, // right('string', n) + + FUNC_CONCAT, // concat('string1', 'string2', ...) + FUNC_CONCAT_WS, // concat_ws('separator', 'string1', 'string2', ...) + + // FUNC_LPAD, // lpad('string', length) + // FUNC_LPAD2, // lpad('string', length, 'pad_string') + // FUNC_RPAD, // rpad('string', length) + // FUNC_RPAD2, // rpad('string', length, 'pad_string') + + // FUNC_REPEAT, // repeat('string', n) + FUNC_REPLACE, // replace('string', 'search', 'replacement') + FUNC_REVERSE, // reverse('string') + FUNC_SPLIT_PART, // split_part('foo-bar-baz', '-', 3) + + FUNC_SUBSTR, // substr('string', start_pos) + FUNC_SUBSTR2, // substr('string', start_pos[, length]) + FUNC_SUBSTRING, // substring('string', start_pos) + FUNC_SUBSTRING2, // substring('string', start_pos[, length]) + FUNC_TRANSLATE, // translate('hello-world', '-', '--') + + FUNC_TO_HEX, // to_hex(number) + // FUNC_UUID, // uuid() + FUNC_SUBSTR_INDEX, // substr_index('string', 'delimiter', count) + FUNC_SUBSTRING_INDEX, // substring_index('string', 'delimiter', count) + + // String Functions - Return boolean + FUNC_ENDS_WITH, // ends_with('string', 'suffix') + FUNC_STARTS_WITH, // starts_with('string', 'prefix') + + // FUNC_OVERLAY, // overlay('string' placing 'substr' from position [for count]) + // TODO(datafusion) generate valid flags for regexp functions + FUNC_REGEXP_LIKE, // regexp_like('aBc', '(b|d)') + FUNC_REGEXP_LIKE2, // regexp_like('aBc', '(b|d)', 'i') + FUNC_REGEXP_MATCH, // regexp_match('aBc', '(b|d)') + FUNC_REGEXP_MATCH2, // regexp_match('aBc', '(b|d)', 'i') + FUNC_REGEXP_REPLACE, // regexp_replace('aBc', '(b|d)') + FUNC_REGEXP_REPLACE2, // regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i') // Time and Date Functions diff --git a/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java b/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java index d3fe3997..0f31d3fc 100644 --- a/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java +++ b/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java @@ -1,9 +1,6 @@ package sqlancer.datafusion.gen; import static sqlancer.datafusion.DataFusionUtil.dfAssert; -import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg; -import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericFuncSingleArg; -import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericFuncTwoArgs; import java.util.ArrayList; import java.util.Arrays; @@ -27,13 +24,15 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { case IS_NULL: return new DataFusionBaseExpr("IS NULL", 1, DataFusionBaseExprCategory.UNARY_POSTFIX, Arrays.asList(DataFusionDataType.BOOLEAN), - Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); case IS_NOT_NULL: return new DataFusionBaseExpr("IS NOT NULL", 1, DataFusionBaseExprCategory.UNARY_POSTFIX, Arrays.asList(DataFusionDataType.BOOLEAN), - Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); case BITWISE_AND: return new DataFusionBaseExpr("&", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), @@ -114,65 +113,93 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { return new DataFusionBaseExpr("=", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case EQUAL2: return new DataFusionBaseExpr("==", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case NOT_EQUAL: return new DataFusionBaseExpr("!=", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case LESS_THAN: return new DataFusionBaseExpr("<", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case LESS_THAN_OR_EQUAL_TO: return new DataFusionBaseExpr("<=", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case GREATER_THAN: return new DataFusionBaseExpr(">", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case GREATER_THAN_OR_EQUAL_TO: return new DataFusionBaseExpr(">=", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case IS_DISTINCT_FROM: return new DataFusionBaseExpr("IS DISTINCT FROM", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); case IS_NOT_DISTINCT_FROM: return new DataFusionBaseExpr("IS NOT DISTINCT FROM", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, - DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), new ArgumentType.SameAsFirstArgType())); + // String related operators + case LIKE: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs(" LIKE "); + case NOT_LIKE: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs(" NOT LIKE "); + case ILIKE: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs(" ILIKE "); + case NOT_ILIKE: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs(" NOT ILIKE "); + case REGEX_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("~"); + case REGEX_CASE_INSENSITIVE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("~*"); + case NOT_REGEX_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("!~"); + case NOT_REGEX_CASE_INSENSITIVE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("!~*"); + case LIKE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("~~"); + case CASE_INSENSITIVE_LIKE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("~~*"); + case NOT_LIKE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("!~~"); + case NOT_CASE_INSENSITIVE_LIKE_MATCH: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("!~~*"); + case STRING_CONCATENATION: + return DataFusionBaseExpr.createCommonStringOperatorTwoArgs("||"); + // Logical Operators case AND: return new DataFusionBaseExpr("AND", 2, DataFusionBaseExprCategory.BINARY, Arrays.asList(DataFusionDataType.BOOLEAN), @@ -200,38 +227,39 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { new ArgumentType.Fixed(new ArrayList<>( Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))) // arg2 )); + // Scalar Functions case FUNC_ABS: - return createCommonNumericFuncSingleArg("ABS"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ABS"); case FUNC_ACOS: - return createCommonNumericFuncSingleArg("ACOS"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ACOS"); case FUNC_ACOSH: - return createCommonNumericFuncSingleArg("ACOSH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ACOSH"); case FUNC_ASIN: - return createCommonNumericFuncSingleArg("ASIN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ASIN"); case FUNC_ASINH: - return createCommonNumericFuncSingleArg("ASINH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ASINH"); case FUNC_ATAN: - return createCommonNumericFuncSingleArg("ATAN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ATAN"); case FUNC_ATANH: - return createCommonNumericFuncSingleArg("ATANH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ATANH"); case FUNC_ATAN2: - return createCommonNumericFuncTwoArgs("ATAN2"); + return DataFusionBaseExpr.createCommonNumericFuncTwoArgs("ATAN2"); case FUNC_CBRT: - return createCommonNumericFuncSingleArg("CBRT"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("CBRT"); case FUNC_CEIL: - return createCommonNumericFuncSingleArg("CEIL"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("CEIL"); case FUNC_COS: - return createCommonNumericFuncSingleArg("COS"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("COS"); case FUNC_COSH: - return createCommonNumericFuncSingleArg("COSH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("COSH"); case FUNC_DEGREES: - return createCommonNumericFuncSingleArg("DEGREES"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("DEGREES"); case FUNC_EXP: - return createCommonNumericFuncSingleArg("EXP"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("EXP"); case FUNC_FACTORIAL: - return createCommonNumericFuncSingleArg("FACTORIAL"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("FACTORIAL"); case FUNC_FLOOR: - return createCommonNumericFuncSingleArg("FLOOR"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("FLOOR"); case FUNC_GCD: return new DataFusionBaseExpr("GCD", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), @@ -241,34 +269,34 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { new ArgumentType.Fixed(new ArrayList<>( Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); case FUNC_ISNAN: - return createCommonNumericFuncSingleArg("ISNAN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ISNAN"); case FUNC_ISZERO: - return createCommonNumericFuncSingleArg("ISZERO"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ISZERO"); case FUNC_LCM: - return createCommonNumericFuncTwoArgs("LCM"); + return DataFusionBaseExpr.createCommonNumericFuncTwoArgs("LCM"); case FUNC_LN: - return createCommonNumericFuncSingleArg("LN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("LN"); case FUNC_LOG: - return createCommonNumericFuncSingleArg("LOG"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("LOG"); case FUNC_LOG_WITH_BASE: - return createCommonNumericFuncTwoArgs("LOG"); + return DataFusionBaseExpr.createCommonNumericFuncTwoArgs("LOG"); case FUNC_LOG10: - return createCommonNumericFuncSingleArg("LOG10"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("LOG10"); case FUNC_LOG2: - return createCommonNumericFuncSingleArg("LOG2"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("LOG2"); case FUNC_NANVL: - return createCommonNumericFuncTwoArgs("NANVL"); + return DataFusionBaseExpr.createCommonNumericFuncTwoArgs("NANVL"); case FUNC_PI: return new DataFusionBaseExpr("PI", 0, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList()); case FUNC_POW: - return createCommonNumericFuncSingleArg("POW"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("POW"); case FUNC_POWER: - return createCommonNumericFuncSingleArg("POWER"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("POWER"); case FUNC_RADIANS: - return createCommonNumericFuncSingleArg("RADIANS"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("RADIANS"); case FUNC_ROUND: - return createCommonNumericFuncSingleArg("ROUND"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("ROUND"); case FUNC_ROUND_WITH_DECIMAL: return new DataFusionBaseExpr("ROUND", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), @@ -277,19 +305,19 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); case FUNC_SIGNUM: - return createCommonNumericFuncSingleArg("SIGNUM"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("SIGNUM"); case FUNC_SIN: - return createCommonNumericFuncSingleArg("SIN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("SIN"); case FUNC_SINH: - return createCommonNumericFuncSingleArg("SINH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("SINH"); case FUNC_SQRT: - return createCommonNumericFuncSingleArg("SQRT"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("SQRT"); case FUNC_TAN: - return createCommonNumericFuncSingleArg("TAN"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("TAN"); case FUNC_TANH: - return createCommonNumericFuncSingleArg("TANH"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("TANH"); case FUNC_TRUNC: - return createCommonNumericFuncSingleArg("TRUNC"); + return DataFusionBaseExpr.createCommonNumericFuncSingleArg("TRUNC"); case FUNC_TRUNC_WITH_DECIMAL: return new DataFusionBaseExpr("TRUNC", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), @@ -300,50 +328,281 @@ public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { case FUNC_COALESCE: return new DataFusionBaseExpr("COALESCE", -1, // overide by variadic DataFusionBaseExprCategory.FUNC, - Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList(), true); + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN, + DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN, DataFusionDataType.STRING))))); case FUNC_NULLIF: return new DataFusionBaseExpr("NULLIF", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BOOLEAN, DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); case FUNC_NVL: return new DataFusionBaseExpr("NVL", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BOOLEAN, DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); case FUNC_NVL2: return new DataFusionBaseExpr("NVL2", 3, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BOOLEAN, DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BOOLEAN, DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); case FUNC_IFNULL: return new DataFusionBaseExpr("IFNULL", 2, DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList( - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), - new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, - DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, + DataFusionDataType.BOOLEAN, DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.STRING, DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case FUNC_ASCII: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("ASCII", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_LENGTH: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("LENGTH", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_CHAR_LENGTH: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("CHAR_LENGTH", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_CHARACTER_LENGTH: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("CHARACTER_LENGTH", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_BIT_LENGTH: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("BIT_LENGTH", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_CHR: + return new DataFusionBaseExpr("CHR", 1, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_INSTR: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("INSTR", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_STRPOS: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("STRPOS", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_LEVENSHTEIN: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("LEVENSHTEIN", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_FIND_IN_SET: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("FIND_IN_SET", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_INITCAP: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("INITCAP", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_LOWER: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("LOWER", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_UPPER: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("UPPER", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_OCTET_LENGTH: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("OCTET_LENGTH", + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)); + case FUNC_BTRIM: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("BTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_BTRIM2: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("BTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_TRIM: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("TRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_TRIM2: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("TRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_LTRIM: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("LTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_LTRIM2: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("LTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_RTRIM: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("RTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_RTRIM2: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("RTRIM", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_LEFT: + return new DataFusionBaseExpr("LEFT", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_RIGHT: + return new DataFusionBaseExpr("RIGHT", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_CONCAT: + return new DataFusionBaseExpr("CONCAT", -1, // overide by variadic + DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN, DataFusionDataType.STRING))))); + case FUNC_CONCAT_WS: + return new DataFusionBaseExpr("CONCAT_WS", -1, // overide by variadic + DataFusionBaseExprCategory.FUNC, Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN, DataFusionDataType.STRING))))); + // case FUNC_LPAD: + // return new DataFusionBaseExpr("LPAD", 2, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList( + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + // case FUNC_LPAD2: + // return new DataFusionBaseExpr("LPAD", 3, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList( + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + // case FUNC_RPAD: + // return new DataFusionBaseExpr("RPAD", 2, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList( + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + // case FUNC_RPAD2: + // return new DataFusionBaseExpr("RPAD", 3, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList( + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + // case FUNC_REPEAT: + // return new DataFusionBaseExpr("REPEAT", 2, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList( + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + // new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_REPLACE: + return new DataFusionBaseExpr("REPLACE", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + case FUNC_REVERSE: + return DataFusionBaseExpr.createCommonStringFuncOneStringArg("REVERSE", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_SPLIT_PART: + return new DataFusionBaseExpr("SPLIT_PART", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SUBSTR: + return new DataFusionBaseExpr("SUBSTR", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SUBSTR2: + return new DataFusionBaseExpr("SUBSTR", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SUBSTRING: + return new DataFusionBaseExpr("SUBSTRING", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SUBSTRING2: + return new DataFusionBaseExpr("SUBSTRING", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_TRANSLATE: + return new DataFusionBaseExpr("TRANSLATE", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + case FUNC_TO_HEX: + return new DataFusionBaseExpr("TO_HEX", 1, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + // case FUNC_UUID: + // return new DataFusionBaseExpr("UUID", 0, DataFusionBaseExprCategory.FUNC, + // Arrays.asList(DataFusionDataType.STRING), + // Arrays.asList()); + case FUNC_SUBSTR_INDEX: + return new DataFusionBaseExpr("SUBSTR_INDEX", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SUBSTRING_INDEX: + return new DataFusionBaseExpr("SUBSTRING_INDEX", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_ENDS_WITH: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("ENDS_WITH", + Arrays.asList(DataFusionDataType.BOOLEAN)); + case FUNC_STARTS_WITH: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("STARTS_WITH", + Arrays.asList(DataFusionDataType.BOOLEAN)); + case FUNC_REGEXP_LIKE: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("REGEXP_LIKE", + Arrays.asList(DataFusionDataType.BOOLEAN)); + case FUNC_REGEXP_LIKE2: + return new DataFusionBaseExpr("REGEXP_LIKE", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + case FUNC_REGEXP_MATCH: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("REGEXP_MATCH", + Arrays.asList(DataFusionDataType.STRING)); // TODO(datafusion) + // return + // type + // change + // to + // List + // after + // List + // is + // supported + case FUNC_REGEXP_MATCH2: + return new DataFusionBaseExpr("REGEXP_MATCH", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); + case FUNC_REGEXP_REPLACE: + return DataFusionBaseExpr.createCommonStringFuncTwoStringArg("REGEXP_REPLACE", + Arrays.asList(DataFusionDataType.STRING)); + case FUNC_REGEXP_REPLACE2: + return new DataFusionBaseExpr("REGEXP_REPLACE", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.STRING), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.STRING))))); case AGGR_MIN: - return createCommonNumericAggrFuncSingleArg("MIN"); + return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("MIN"); case AGGR_MAX: - return createCommonNumericAggrFuncSingleArg("MAX"); + return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("MAX"); case AGGR_AVG: - return createCommonNumericAggrFuncSingleArg("AVG"); + return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("AVG"); case AGGR_SUM: - return createCommonNumericAggrFuncSingleArg("SUM"); + return DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg("SUM"); case AGGR_COUNT: return new DataFusionBaseExpr("COUNT", -1, DataFusionBaseExprCategory.AGGREGATE, Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), diff --git a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java index ca0e0d2e..549a9f79 100644 --- a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java +++ b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java @@ -25,15 +25,18 @@ import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; import sqlancer.datafusion.ast.DataFusionExpression; import sqlancer.datafusion.gen.DataFusionBaseExpr.ArgumentType; +import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprCategory; import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType; public final class DataFusionExpressionGenerator extends TypedExpressionGenerator, DataFusionColumn, DataFusionDataType> { private final DataFusionGlobalState globalState; + public boolean supportAggregate; // control if generate aggr exprs, related logic is in `generateExpression()` public DataFusionExpressionGenerator(DataFusionGlobalState globalState) { this.globalState = globalState; + supportAggregate = false; } @Override @@ -51,6 +54,24 @@ protected boolean canGenerateColumnOfType(DataFusionDataType type) { return true; } + // If target expr type is numeric, when `supportAggregate`, make it more likely to generate aggregate functions + private boolean filterBaseExpr(DataFusionBaseExpr expr, DataFusionDataType type) { + // keep only aggregates + if (supportAggregate && type.isNumeric() && Randomly.getBoolean()) { + return expr.exprType == DataFusionBaseExpr.DataFusionBaseExprCategory.AGGREGATE; + } + + // keep all avaialble expressions (aggr + non-aggr) + if (supportAggregate || Randomly.getBooleanWithRatherLowProbability()) { + return true; + } + + // keep all non-aggregate exprs + return expr.exprType != DataFusionBaseExprCategory.AGGREGATE; + } + + // By default all possible non-aggregate expressions + // To generate aggregate functions: set this.supportAggregate to `true`, generate exprs, and reset. @Override protected Node generateExpression(DataFusionDataType type, int depth) { if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { @@ -61,12 +82,8 @@ protected Node generateExpression(DataFusionDataType type, return generateLeafNode(expectedType); } - // nested aggregate is not allowed, so occasionally apply it - Boolean includeAggr = Randomly.getBooleanWithSmallProbability(); List possibleBaseExprs = getExprsWithReturnType(Optional.of(type)).stream() - // Conditinally apply filter if `includeAggr` set to false - .filter(expr -> includeAggr || expr.exprType != DataFusionBaseExpr.DataFusionBaseExprCategory.AGGREGATE) - .collect(Collectors.toList()); + .filter(expr -> filterBaseExpr(expr, type)).collect(Collectors.toList()); if (possibleBaseExprs.isEmpty()) { dfAssert(type == DataFusionDataType.NULL, "should able to generate expression with type " + type); @@ -136,13 +153,25 @@ protected Node generateExpression(DataFusionDataType type, public Node generateFunctionExpression(DataFusionDataType type, int depth, DataFusionBaseExpr exprType) { - if (exprType.isVariadic || Randomly.getBooleanWithSmallProbability()) { - // TODO(datafusion) maybe add possible types. e.g. some function have signature variadic(INT/DOUBLE), then - // only randomly pick from INT and DOUBLE - int nArgs = Randomly.smallNumber(); // 0, 2, 4, ... smaller one is more likely + if (Randomly.getBooleanWithSmallProbability()) { + int nArgs = (int) Randomly.getNotCachedInteger(0, 5); return new NewFunctionNode(generateExpressions(nArgs), exprType); } + if (exprType.isVariadic) { + int nArgs = (int) Randomly.getNotCachedInteger(0, 5); + dfAssert(exprType.argTypes.get(0) instanceof ArgumentType.Fixed, + "variadic function must specify possible argument types"); + List possibleTypes = ((ArgumentType.Fixed) exprType.argTypes.get(0)).fixedType; + + List> argExprs = new ArrayList<>(); + for (int i = 0; i < nArgs; i++) { + argExprs.add(generateExpression(Randomly.fromList(possibleTypes))); + } + + return new NewFunctionNode(argExprs, exprType); + } + List funcArgTypeList = new ArrayList<>(); // types of current expression's input arguments int i = 0; for (ArgumentType argumentType : exprType.argTypes) { @@ -183,9 +212,18 @@ List filterColumns(DataFusionDataType type) { public List> generateOrderBys() { List> expr = super.generateOrderBys(); List> newExpr = new ArrayList<>(expr.size()); + for (Node curExpr : expr) { if (Randomly.getBoolean()) { - curExpr = new NewOrderingTerm<>(curExpr, NewOrderingTerm.Ordering.getRandom()); + if (Randomly.getBoolean()) { + // e.g. [curExpr] ASC + curExpr = new NewOrderingTerm<>(curExpr, NewOrderingTerm.Ordering.getRandom()); + } else { + // e.g. [curExpr] ASC NULLS LAST + curExpr = new NewOrderingTerm<>(curExpr, NewOrderingTerm.Ordering.getRandom(), + NewOrderingTerm.OrderingNulls.getRandom()); + } + } newExpr.add(curExpr); } @@ -224,6 +262,29 @@ public Node isNull(Node expr) { return new NewUnaryPostfixOperatorNode<>(expr, createExpr(DataFusionBaseExprType.IS_NULL)); } + // TODO(datafusion) refactor: make single generate aware of group by and aggr columns, and it can directly generate + // having clause + // Try best to generate a valid having clause + // + // Suppose query "... group by a, b ..." + // and all available columns are "a, b, c, d" + // then a valid having clause can have expr of {a, b}, and expr of aggregation of {c, d} + // e.g. "having a=b and avg(c) > avg(d)" + // + // `groupbyGen` can generate expression only with group by cols + // `aggrGen` can generate expression only with aggr cols + public static Node generateHavingClause(DataFusionExpressionGenerator groupbyGen, + DataFusionExpressionGenerator aggrGen) { + if (Randomly.getBoolean()) { + return groupbyGen.generatePredicate(); + } else { + aggrGen.supportAggregate = true; + Node expr = aggrGen.generatePredicate(); + aggrGen.supportAggregate = false; + return expr; + } + } + public static class DataFusionCastOperation extends NewUnaryPostfixOperatorNode { public DataFusionCastOperation(Node expr, DataFusionDataType type) { diff --git a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml index 59f52f74..dd677369 100644 --- a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml +++ b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml @@ -18,8 +18,8 @@ async-trait = "0.1.73" bytes = "1.4" chrono = { version = "0.4.34", default-features = false } dashmap = "5.5.0" -# This version is for SQLancer CI run -#datafusion = { version = "40.0.0" } +# This version is for SQLancer CI run (disabled temporary for multiple newly fixed bugs) +# datafusion = { version = "40.0.0" } # Use following line if you want to test against the latest main branch of DataFusion datafusion = { git = "https://github.com/apache/datafusion.git", branch = "main" } env_logger = "0.11" diff --git a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java index 11733a73..c4a7defb 100644 --- a/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java +++ b/src/sqlancer/datafusion/test/DataFusionNoRECOracle.java @@ -50,7 +50,7 @@ public void check() throws SQLException { q1.setWhereClause(randomSelect.getWhereClause()); // Q2: SELECT count(case when [expr3] then 1 else null end) FROM [expr2] DataFusionSelect q2 = new DataFusionSelect(); - String selectExpr = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)", + String selectExpr = String.format("COUNT(CASE WHEN %s THEN 1 ELSE NULL END)", DataFusionToStringVisitor.asString(randomSelect.getWhereClause())); q2.setFetchColumnsString(selectExpr); q2.from = randomSelect.from; diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningAggrTester.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningAggrTester.java new file mode 100644 index 00000000..134600a2 --- /dev/null +++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningAggrTester.java @@ -0,0 +1,210 @@ +package sqlancer.datafusion.test; + +import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Node; +import sqlancer.datafusion.DataFusionErrors; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema; +import sqlancer.datafusion.DataFusionToStringVisitor; +import sqlancer.datafusion.DataFusionUtil; +import sqlancer.datafusion.ast.DataFusionExpression; +import sqlancer.datafusion.ast.DataFusionSelect; + +public class DataFusionQueryPartitioningAggrTester extends DataFusionQueryPartitioningBase { + public DataFusionQueryPartitioningAggrTester(DataFusionGlobalState state) { + super(state); + DataFusionErrors.registerExpectedExecutionErrors(errors); + } + + /* + * Query Partitioning - Aggregate + * + * q: SELECT min([expr1]) FROM [expr2] + * + * qp1: SELECT min([expr1]) FROM [expr2] WHERE [expr3] + * + * qp2: SELECT min([expr1]) FROM [expr2] WHERE NOT [expr3] + * + * qp3: SELECT min([expr1]) FROM [expr2] WHERE [expr3] IS NULL + * + * Oracle check: q's result equals to min(qp1, qp2, qp3) + */ + @Override + public void check() throws SQLException { + // generate a random 'SELECT [expr1] FROM [expr2] WHERE [expr3] + super.check(); + + DFAggrOp aggrOp = Randomly.fromOptions(DFAggrOp.values()); + checkAggregate(aggrOp); + } + + void checkAggregate(DFAggrOp aggrOP) throws SQLException { + DataFusionSelect randomSelect = select; + + String qString = ""; + String qp1String = ""; + String qp2String = ""; + String qp3String = ""; + + randomSelect.setWhereClause(null); + Node fetchExpr = randomSelect.exprGenAll + .generateExpression(DataFusionSchema.DataFusionDataType.getRandomWithoutNull()); // e.g. col1 + col2 + String fetchString = aggrOP.name() + "(" + DataFusionToStringVisitor.asString(fetchExpr) + ")"; + // after: MIN(col1 + col2) + + randomSelect.setFetchColumnsString(fetchString); + qString = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(predicate); + qp1String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(negatedPredicate); + qp2String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(isNullPredicate); + qp3String = DataFusionToStringVisitor.asString(randomSelect); + + // q - min(q1, q2, q3) + String diffQuery = aggrOP.formatDiffQuery(Arrays.asList(qString, qp1String, qp2String, qp3String)); + + List diffQueryResultSet = null; + try { + diffQueryResultSet = ComparatorHelper.getResultSetFirstColumnAsString(diffQuery, errors, state); + } catch (AssertionError e) { + // Append more error message + String replay = DataFusionUtil.getReplay(state.getDatabaseName()); + String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n"; + state.dfLogger.appendToLog(ERROR, newMessage); + + throw new AssertionError(newMessage); + } + + String diffResultString = diffQueryResultSet != null ? diffQueryResultSet.get(0) : "Query Failed"; + // inf - inf + if (diffResultString == null || diffResultString.equals("NaN") + || diffResultString.toLowerCase().contains("inf")) { + return; + } + double diff = -1; + try { + diff = Double.parseDouble(diffResultString); + } catch (Exception e) { + } + + // TODO(datafusion) remove 1e100 condition when overflow is solved + // https://github.com/apache/datafusion/issues/3520 + if (Math.abs(diff) > 1e-3 && Math.abs(diff) < 1e100) { + StringBuilder errorMessage = new StringBuilder().append("TLP-Aggregate oracle violated:\n") + .append(aggrOP.errorReportDescription()).append(diffResultString).append("\n").append("Q: ") + .append(qString).append("\n").append("Q1: ").append(qp1String).append("\n").append("Q2: ") + .append(qp2String).append("\n").append("Q3: ").append(qp3String).append("\n").append(diffQuery) + .append("\n").append("=======================================\n").append("Reproducer: \n"); + + String replay = DataFusionUtil.getReplay(state.getDatabaseName()); + + String errorLog = errorMessage.toString() + replay + "\n"; + String indentedErrorLog = errorLog.replaceAll("(?m)^", " "); + state.dfLogger.appendToLog(ERROR, errorLog); + + throw new AssertionError("\n\n" + indentedErrorLog); + } + } + + private interface DataFusionTLPAggregate { + // e.g. q - min(q1, q2, q3), in the form of single SQL query + // Oracle will check diff query equals to 0 + + // Accepts a list of strings with expected order: q, q1, q2, q3 + // to make linter happy :) + String formatDiffQuery(List queries); + + String errorReportDescription(); + } + + private enum DFAggrOp implements DataFusionTLPAggregate { + MIN { + @Override + public String formatDiffQuery(List queries) { + String q = queries.get(0); + String q1 = queries.get(1); + String q2 = queries.get(2); + String q3 = queries.get(3); + + return "SELECT " + "(" + q + ") - " + "(" + " SELECT MIN(value) " + " FROM (" + " SELECT (" + + q1 + ") AS value " + " UNION ALL " + " SELECT (" + q2 + ") " + + " UNION ALL " + " SELECT (" + q3 + ") " + " ) AS sub" + + ") AS result_difference;"; + } + + @Override + public String errorReportDescription() { + return "Q's result is not equalt to MIN(Q1, Q2, Q3): RS(Q) - MIN(RS(Q1), RS(Q2), RS(Q3)) is :"; + } + }, + MAX { + @Override + public String formatDiffQuery(List queries) { + String q = queries.get(0); + String q1 = queries.get(1); + String q2 = queries.get(2); + String q3 = queries.get(3); + + return "SELECT " + "(" + q + ") - " + "(" + " SELECT MAX(value) " + " FROM (" + " SELECT (" + + q1 + ") AS value " + " UNION ALL " + " SELECT (" + q2 + ") " + + " UNION ALL " + " SELECT (" + q3 + ") " + " ) AS sub" + + ") AS result_difference;"; + } + + @Override + public String errorReportDescription() { + return "Q's result is not equalt to MAX(Q1, Q2, Q3): RS(Q) - MAX(RS(Q1), RS(Q2), RS(Q3)) is :"; + } + }, + COUNT { + @Override + public String formatDiffQuery(List queries) { + String q = queries.get(0); + String q1 = queries.get(1); + String q2 = queries.get(2); + String q3 = queries.get(3); + + return "SELECT " + "(" + q + ") - " + "(" + " SELECT SUM(value) " + " FROM (" + " SELECT (" + + q1 + ") AS value " + " UNION ALL " + " SELECT (" + q2 + ") " + + " UNION ALL " + " SELECT (" + q3 + ") " + " ) AS sub" + + ") AS result_difference;"; + } + + @Override + public String errorReportDescription() { + return "Q's result is not equalt to SUM(Q1, Q2, Q3): RS(Q) - SUM(RS(Q1), RS(Q2), RS(Q3)) is :"; + } + }, + SUM { + @Override + public String formatDiffQuery(List queries) { + String q = queries.get(0); + String q1 = queries.get(1); + String q2 = queries.get(2); + String q3 = queries.get(3); + + return "SELECT " + "(" + q + ") - " + "(" + " SELECT SUM(value) " + " FROM (" + " SELECT (" + + q1 + ") AS value " + " UNION ALL " + " SELECT (" + q2 + ") " + + " UNION ALL " + " SELECT (" + q3 + ") " + " ) AS sub" + + ") AS result_difference;"; + } + + @Override + public String errorReportDescription() { + return "Q's result is not equalt to SUM(Q1, Q2, Q3): RS(Q) - SUM(RS(Q1), RS(Q2), RS(Q3)) is :"; + } + }; + } + +} diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningBase.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningBase.java index b304a413..b5d328d2 100644 --- a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningBase.java +++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningBase.java @@ -15,6 +15,9 @@ public class DataFusionQueryPartitioningBase extends TernaryLogicPartitioningOracleBase, DataFusionGlobalState> implements TestOracle { DataFusionGlobalState state; + // Generate expression given available columns + // This includes all columns to generate WHERE + // see DataFusionSelect's comment for other expression generators DataFusionExpressionGenerator gen; DataFusionSelect select; @@ -26,7 +29,7 @@ public DataFusionQueryPartitioningBase(DataFusionGlobalState state) { @Override public void check() throws SQLException { select = DataFusionSelect.getRandomSelect(state); - gen = select.exprGen; + gen = select.exprGenAll; initializeTernaryPredicateVariants(); } diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningHavingTester.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningHavingTester.java new file mode 100644 index 00000000..90ac9aaf --- /dev/null +++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningHavingTester.java @@ -0,0 +1,129 @@ +package sqlancer.datafusion.test; + +import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR; +import static sqlancer.datafusion.gen.DataFusionExpressionGenerator.generateHavingClause; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Node; +import sqlancer.datafusion.DataFusionErrors; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionToStringVisitor; +import sqlancer.datafusion.DataFusionUtil; +import sqlancer.datafusion.ast.DataFusionExpression; +import sqlancer.datafusion.ast.DataFusionSelect; +import sqlancer.datafusion.gen.DataFusionExpressionGenerator; + +public class DataFusionQueryPartitioningHavingTester extends DataFusionQueryPartitioningBase { + public DataFusionQueryPartitioningHavingTester(DataFusionGlobalState state) { + super(state); + DataFusionErrors.registerExpectedExecutionErrors(errors); + } + + /* + * Query Partitioning - Where + * + * q: SELECT [expr1] FROM [expr2] + * + * qp1: SELECT [expr1] FROM [expr2] HAVING [expr3] + * + * qp2: SELECT [expr1] FROM [expr2] HAVING NOT [expr3] + * + * qp3: SELECT [expr1] FROM [expr2] HAVING [expr3] IS NULL + * + * Oracle check: q's result equals to union(qp1, qp2, qp3) + */ + @Override + public void check() throws SQLException { + // generate a random 'SELECT [expr1] FROM [expr2] WHERE [expr3] + super.check(); + DataFusionSelect randomSelect = select; + + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + randomSelect.distinct = true; + } + + if (Randomly.getBoolean()) { + randomSelect.setOrderByClauses(gen.generateOrderBys()); + } + } + + if (Randomly.getBoolean()) { + randomSelect.setWhereClause(gen.generatePredicate()); + } + + // generate {group_by_cols, aggrs} + if (!Randomly.getBooleanWithSmallProbability()) { + List> groupByExprs = randomSelect.exprGenGroupBy + .generateExpressions(state.getRandomly().getInteger(0, 4)); + + randomSelect.exprGenAggregate.supportAggregate = true; + List> aggrExprs = randomSelect.exprGenAggregate + .generateExpressions(state.getRandomly().getInteger(0, 4)); + randomSelect.exprGenAggregate.supportAggregate = false; + + if (!groupByExprs.isEmpty()) { + randomSelect.setGroupByClause(groupByExprs); + + List> fetchCols = new ArrayList<>(); + fetchCols.addAll(groupByExprs); + fetchCols.addAll(aggrExprs); + fetchCols = Randomly.nonEmptySubset(fetchCols); + randomSelect.setFetchColumns(fetchCols); + } + } + + // DataFusionExpressionGenerator havingGen = randomSelect.exprGenAggregate; + DataFusionExpressionGenerator groupByGen = randomSelect.exprGenGroupBy; + Node havingPredicate = generateHavingClause(randomSelect.exprGenGroupBy, + randomSelect.exprGenAggregate); + Node negateHavingPredicate = groupByGen.negatePredicate(havingPredicate); + Node isNullHavingPredicate = groupByGen.isNull(havingPredicate); + + String qString = ""; + String qp1String = ""; + String qp2String = ""; + String qp3String = ""; + randomSelect.setHavingClause(null); + qString = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setHavingClause(havingPredicate); + qp1String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setHavingClause(negateHavingPredicate); + qp2String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setHavingClause(isNullHavingPredicate); + qp3String = DataFusionToStringVisitor.asString(randomSelect); + + List qResultSet = new ArrayList<>(); + List qpResultSet = new ArrayList<>(); + try { + /* + * Run all queires + */ + qResultSet = ComparatorHelper.getResultSetFirstColumnAsString(qString, errors, state); + List combinedString = new ArrayList<>(); + qpResultSet = ComparatorHelper.getCombinedResultSet(qp1String, qp2String, qp3String, combinedString, true, + state, errors); + /* + * Query Partitioning-Where check + */ + ComparatorHelper.assumeResultSetsAreEqual(qResultSet, qpResultSet, qString, combinedString, state, + DataFusionUtil::cleanResultSetString); + } catch (AssertionError e) { + // Append more error message + String replay = DataFusionUtil.getReplay(state.getDatabaseName()); + String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n" + "Query Result: " + + qResultSet + "\nPartitioned Query Result: " + qpResultSet + "\n"; + state.dfLogger.appendToLog(ERROR, newMessage); + + throw new AssertionError(newMessage); + } + } +} diff --git a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java index 6cc37cdf..3031d231 100644 --- a/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java +++ b/src/sqlancer/datafusion/test/DataFusionQueryPartitioningWhereTester.java @@ -1,6 +1,7 @@ package sqlancer.datafusion.test; import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR; +import static sqlancer.datafusion.gen.DataFusionBaseExprFactory.createExpr; import java.sql.SQLException; import java.util.ArrayList; @@ -8,11 +9,15 @@ import sqlancer.ComparatorHelper; import sqlancer.Randomly; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.common.ast.newast.Node; import sqlancer.datafusion.DataFusionErrors; import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; import sqlancer.datafusion.DataFusionToStringVisitor; import sqlancer.datafusion.DataFusionUtil; +import sqlancer.datafusion.ast.DataFusionExpression; import sqlancer.datafusion.ast.DataFusionSelect; +import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType; public class DataFusionQueryPartitioningWhereTester extends DataFusionQueryPartitioningBase { public DataFusionQueryPartitioningWhereTester(DataFusionGlobalState state) { @@ -38,23 +43,71 @@ public void check() throws SQLException { // generate a random 'SELECT [expr1] FROM [expr2] WHERE [expr3] super.check(); DataFusionSelect randomSelect = select; - randomSelect.setWhereClause(null); - // set 'order by' - boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - if (orderBy) { - select.setOrderByClauses(gen.generateOrderBys()); + if (Randomly.getBoolean()) { + randomSelect.distinct = true; } - // Construct q - String qString = DataFusionToStringVisitor.asString(randomSelect); - // Construct qp1, qp2, qp3 - randomSelect.setWhereClause(predicate); - String qp1String = DataFusionToStringVisitor.asString(randomSelect); - randomSelect.setWhereClause(negatedPredicate); - String qp2String = DataFusionToStringVisitor.asString(randomSelect); - randomSelect.setWhereClause(isNullPredicate); - String qp3String = DataFusionToStringVisitor.asString(randomSelect); + if (Randomly.getBoolean()) { + randomSelect.setOrderByClauses(gen.generateOrderBys()); + } + + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + randomSelect.setGroupByClause( + randomSelect.exprGenGroupBy.generateExpressions(state.getRandomly().getInteger(1, 3))); + } + + if (Randomly.getBoolean()) { + randomSelect.setHavingClause(randomSelect.exprGenGroupBy.generatePredicate()); + } + } + + String qString = ""; + String qp1String = ""; + String qp2String = ""; + String qp3String = ""; + if (Randomly.getBoolean()) { + randomSelect.setWhereClause(null); + qString = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(predicate); + qp1String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(negatedPredicate); + qp2String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(isNullPredicate); + qp3String = DataFusionToStringVisitor.asString(randomSelect); + } else { + // Extended TLP-WHERE + // + // select * from t1 where pExist + // --------------------------------------------- + // select * from t1 where pExist AND p + // select * from t1 where pExist AND (NOT p) + // select * from t1 where pExist AND (p IS NULL) + Node pExist = gen.generatePredicate(); + Node p1 = new NewBinaryOperatorNode<>(pExist, predicate, + createExpr(DataFusionBaseExprType.AND)); + Node p2 = new NewBinaryOperatorNode<>(pExist, negatedPredicate, + createExpr(DataFusionBaseExprType.AND)); + Node p3 = new NewBinaryOperatorNode<>(pExist, isNullPredicate, + createExpr(DataFusionBaseExprType.AND)); + + randomSelect.setWhereClause(pExist); + + qString = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(p1); + qp1String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(p2); + qp2String = DataFusionToStringVisitor.asString(randomSelect); + + randomSelect.setWhereClause(p3); + qp3String = DataFusionToStringVisitor.asString(randomSelect); + } try { /*