diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/CalciteToSubstraitVisitor.java similarity index 96% rename from isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java rename to isthmus/src/main/java/io/substrait/isthmus/CalciteToSubstraitVisitor.java index 400e99722..24ee17ee1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/CalciteToSubstraitVisitor.java @@ -44,10 +44,10 @@ @SuppressWarnings("UnstableApiUsage") @Value.Enclosing -public class SubstraitRelVisitor extends RelNodeVisitor { +public class CalciteToSubstraitVisitor extends RelNodeVisitor { static final org.slf4j.Logger logger = - org.slf4j.LoggerFactory.getLogger(SubstraitRelVisitor.class); + org.slf4j.LoggerFactory.getLogger(CalciteToSubstraitVisitor.class); private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true); @@ -57,12 +57,12 @@ public class SubstraitRelVisitor extends RelNodeVisitor { protected final FeatureBoard featureBoard; private Map fieldAccessDepthMap; - public SubstraitRelVisitor( + public CalciteToSubstraitVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { this(typeFactory, extensions, FEATURES_DEFAULT); } - public SubstraitRelVisitor( + public CalciteToSubstraitVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { @@ -80,7 +80,7 @@ public SubstraitRelVisitor( this.featureBoard = features; } - public SubstraitRelVisitor( + public CalciteToSubstraitVisitor( RelDataTypeFactory typeFactory, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, @@ -386,8 +386,9 @@ public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollec public static Plan.Root convert( RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - SubstraitRelVisitor visitor = - new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features); + CalciteToSubstraitVisitor visitor = + new CalciteToSubstraitVisitor( + relRoot.rel.getCluster().getTypeFactory(), extensions, features); visitor.popFieldAccessDepthMap(relRoot.rel); Rel rel = visitor.apply(relRoot.project()); @@ -403,8 +404,8 @@ public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection e public static Rel convert( RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - SubstraitRelVisitor visitor = - new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features); + CalciteToSubstraitVisitor visitor = + new CalciteToSubstraitVisitor(relNode.getCluster().getTypeFactory(), extensions, features); visitor.popFieldAccessDepthMap(relNode); return visitor.apply(relNode); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index ac8802ed8..badb52515 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,8 +1,8 @@ package io.substrait.isthmus; import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.calcite.SubstraitOperatorTable; -import java.util.ArrayList; +import io.substrait.isthmus.calcite.SubstraitTable; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import java.util.List; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -16,25 +16,13 @@ import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; import org.apache.calcite.rel.metadata.ProxyingMetadataHandlerProvider; import org.apache.calcite.rel.metadata.RelMetadataQuery; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.impl.AbstractTable; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.calcite.sql.ddl.SqlColumnDeclaration; -import org.apache.calcite.sql.ddl.SqlCreateTable; -import org.apache.calcite.sql.ddl.SqlKeyConstraint; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; import org.apache.calcite.sql.validate.SqlConformanceEnum; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorCatalogReader; -import org.apache.calcite.sql.validate.SqlValidatorImpl; import org.apache.calcite.sql2rel.SqlToRelConverter; class SqlConverterBase { @@ -76,11 +64,11 @@ CalciteCatalogReader registerCreateTables(List tables) throws SqlParseEx CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = new CalciteCatalogReader(rootSchema, List.of(), factory, config); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); if (tables != null) { for (String tableDef : tables) { - List tList = parseCreateTable(factory, validator, tableDef); - for (DefinedTable t : tList) { + List tList = + SubstraitCreateStatementParser.processCreateStatements(tableDef); + for (SubstraitTable t : tList) { rootSchema.add(t.getName(), t); } } @@ -96,107 +84,4 @@ CalciteCatalogReader registerSchema(String name, Schema schema) { } return new CalciteCatalogReader(rootSchema, List.of(), factory, config); } - - protected List parseCreateTable( - RelDataTypeFactory factory, SqlValidator validator, String sql) throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - List definedTableList = new ArrayList<>(); - - SqlNodeList nodeList = parser.parseStmtList(); - for (SqlNode parsed : nodeList) { - if (!(parsed instanceof SqlCreateTable)) { - throw fail("Not a valid CREATE TABLE statement."); - } - - SqlCreateTable create = (SqlCreateTable) parsed; - if (create.name.names.size() > 1) { - throw fail("Only simple table names are allowed.", create.name.getParserPosition()); - } - - if (create.query != null) { - throw fail("CTAS not supported.", create.name.getParserPosition()); - } - - List names = new ArrayList<>(); - List columnTypes = new ArrayList<>(); - - for (SqlNode node : create.columnList) { - if (!(node instanceof SqlColumnDeclaration)) { - if (node instanceof SqlKeyConstraint) { - // key constraints declarations, like primary key declaration, are valid and should not - // result in parse exceptions. Ignore the constraint declaration. - continue; - } - - throw fail("Unexpected column list construction.", node.getParserPosition()); - } - - SqlColumnDeclaration col = (SqlColumnDeclaration) node; - if (col.name.names.size() != 1) { - throw fail("Expected simple column names.", col.name.getParserPosition()); - } - - names.add(col.name.names.get(0)); - columnTypes.add(col.dataType.deriveType(validator)); - } - - definedTableList.add( - new DefinedTable( - create.name.names.get(0), factory, factory.createStructType(columnTypes, names))); - } - - return definedTableList; - } - - protected static SqlParseException fail(String text, SqlParserPos pos) { - return new SqlParseException(text, pos, null, null, new RuntimeException("fake lineage")); - } - - protected static SqlParseException fail(String text) { - return fail(text, SqlParserPos.ZERO); - } - - protected static final class Validator extends SqlValidatorImpl { - - private Validator( - SqlOperatorTable opTab, - SqlValidatorCatalogReader catalogReader, - RelDataTypeFactory typeFactory, - Config config) { - super(opTab, catalogReader, typeFactory, config); - } - - public static Validator create( - RelDataTypeFactory factory, - SqlValidatorCatalogReader validatorCatalog, - SqlValidator.Config config) { - return new Validator(SubstraitOperatorTable.INSTANCE, validatorCatalog, factory, config); - } - } - - /** A fully defined pre-specified table. */ - protected static final class DefinedTable extends AbstractTable { - - private final String name; - private final RelDataTypeFactory factory; - private final RelDataType type; - - public DefinedTable(String name, RelDataTypeFactory factory, RelDataType type) { - this.name = name; - this.factory = factory; - this.type = type; - } - - @Override - public RelDataType getRowType(RelDataTypeFactory typeFactory) { - // if (factory != typeFactory) { - // throw new IllegalStateException("Different type factory than previously used."); - // } - return type; - } - - public String getName() { - return name; - } - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 5932a0d35..83d79d9b9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -5,8 +5,11 @@ import io.substrait.extendedexpression.ImmutableExpressionReference; import io.substrait.extendedexpression.ImmutableExtendedExpression; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlValidator; import io.substrait.proto.ExtendedExpression; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -142,11 +145,11 @@ private Result registerCreateTablesForExtendedExpression(List tables) CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = new CalciteCatalogReader(rootSchema, List.of(), factory, config); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); if (tables != null) { for (String tableDef : tables) { - List tList = parseCreateTable(factory, validator, tableDef); - for (DefinedTable t : tList) { + List tList = + SubstraitCreateStatementParser.processCreateStatements(tableDef); + for (SubstraitTable t : tList) { rootSchema.add(t.getName(), t); for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { nameToTypeMap.merge( // to validate the sql expression tree @@ -167,6 +170,7 @@ private Result registerCreateTablesForExtendedExpression(List tables) } } } + SqlValidator validator = new SubstraitSqlValidator(catalogReader); return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 6fb64c6f7..63b1a5ecc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,23 +1,16 @@ package io.substrait.isthmus; -import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; import io.substrait.relation.RelProtoConverter; import java.util.List; -import org.apache.calcite.plan.hep.HepPlanner; -import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.schema.Schema; -import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql2rel.SqlToRelConverter; -import org.apache.calcite.sql2rel.StandardConvertletTable; /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { @@ -32,84 +25,37 @@ public SqlToSubstrait(FeatureBoard features) { public Plan execute(String sql, List tables) throws SqlParseException { CalciteCatalogReader catalogReader = registerCreateTables(tables); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); - return executeInner(sql, validator, catalogReader); + return executeInner(sql, catalogReader); } public Plan execute(String sql, String name, Schema schema) throws SqlParseException { CalciteCatalogReader catalogReader = registerSchema(name, schema); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); - return executeInner(sql, validator, catalogReader); + return executeInner(sql, catalogReader); } public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); - return executeInner(sql, validator, catalogReader); + return executeInner(sql, catalogReader); } - // Package protected for testing - List sqlToRelNode(String sql, List tables) throws SqlParseException { - Prepare.CatalogReader catalogReader = registerCreateTables(tables); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); - return sqlToRelNode(sql, validator, catalogReader); - } - - private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) + private Plan executeInner(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { - var plan = Plan.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); var relProtoConverter = new RelProtoConverter(functionCollector); + + List relRoots = SubstraitSqlToCalcite.convertSelects(sql, catalogReader); + // TODO: consider case in which one sql passes conversion while others don't - sqlToRelNode(sql, validator, catalogReader) - .forEach( - root -> { - plan.addRelations( - PlanRel.newBuilder() - .setRoot( - relProtoConverter.toProto( - SubstraitRelVisitor.convert( - root, EXTENSION_COLLECTION, featureBoard)))); - }); + Plan.Builder plan = Plan.newBuilder(); + relRoots.forEach( + root -> { + plan.addRelations( + PlanRel.newBuilder() + .setRoot( + relProtoConverter.toProto( + CalciteToSubstraitVisitor.convert( + root, EXTENSION_COLLECTION, featureBoard)))); + }); functionCollector.addExtensionsToPlan(plan); return plan.build(); } - - private List sqlToRelNode( - String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) - throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - var parsedList = parser.parseStmtList(); - SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - List roots = - parsedList.stream() - .map(parsed -> getBestExpRelRoot(converter, parsed)) - .collect(java.util.stream.Collectors.toList()); - return roots; - } - - @VisibleForTesting - SqlToRelConverter createSqlToRelConverter( - SqlValidator validator, Prepare.CatalogReader catalogReader) { - SqlToRelConverter converter = - new SqlToRelConverter( - null, - validator, - catalogReader, - relOptCluster, - StandardConvertletTable.INSTANCE, - converterConfig); - return converter; - } - - @VisibleForTesting - static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) { - RelRoot root = converter.convertQuery(parsed, true, true); - { - var program = HepProgram.builder().build(); - HepPlanner hepPlanner = new HepPlanner(program); - hepPlanner.setRoot(root.rel); - root = root.withRel(hepPlanner.findBestExp()); - } - return root; - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index a96185a22..d35e217d0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.plan.Plan; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -28,7 +29,7 @@ /** * Converts between Substrait {@link Rel}s and Calcite {@link RelNode}s. * - *

Can be extended to customize the {@link RelBuilder} and {@link SubstraitRelNodeConverter} used + *

Can be extended to customize the {@link RelBuilder} and {@link SubstraitToCalciteVisitor} used * in the conversion. */ public class SubstraitToCalcite { @@ -64,9 +65,8 @@ protected CalciteSchema toSchema(Rel rel) { if (table == null) { return null; } - return new SqlConverterBase.DefinedTable( + return new SubstraitTable( id.get(id.size() - 1), - typeFactory, typeConverter.toCalcite(typeFactory, table.struct(), table.names())); }; return LookupCalciteSchema.createRootSchema(lookup); @@ -82,12 +82,12 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) { } /** - * Creates a {@link SubstraitRelNodeConverter} from the {@link RelBuilder} + * Creates a {@link SubstraitToCalciteVisitor} from the {@link RelBuilder} * - *

Override this method to customize the {@link SubstraitRelNodeConverter}. + *

Override this method to customize the {@link SubstraitToCalciteVisitor}. */ - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder); + protected SubstraitToCalciteVisitor createSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitToCalciteVisitor(extensions, typeFactory, relBuilder); } /** @@ -95,7 +95,7 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r * *

Generates a {@link CalciteSchema} based on the contents of the {@link Rel}, which will be * used to construct a {@link RelBuilder} with the required schema information to build {@link - * RelNode}s, and a then a {@link SubstraitRelNodeConverter} to perform the actual conversion. + * RelNode}s, and a then a {@link SubstraitToCalciteVisitor} to perform the actual conversion. * * @param rel {@link Rel} to convert * @return {@link RelNode} @@ -103,7 +103,7 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r public RelNode convert(Rel rel) { CalciteSchema rootSchema = toSchema(rel); RelBuilder relBuilder = createRelBuilder(rootSchema); - SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); + SubstraitToCalciteVisitor converter = createSubstraitRelNodeConverter(relBuilder); return rel.accept(converter); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalciteVisitor.java similarity index 98% rename from isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java rename to isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalciteVisitor.java index 45b42f814..6e09e7a8a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalciteVisitor.java @@ -56,7 +56,7 @@ * RelVisitor to convert Substrait Rel plan to Calcite RelNode plan. Unsupported Rel node will call * visitFallback and throw UnsupportedOperationException. */ -public class SubstraitRelNodeConverter extends AbstractRelVisitor { +public class SubstraitToCalciteVisitor extends AbstractRelVisitor { protected final RelDataTypeFactory typeFactory; @@ -68,7 +68,7 @@ public class SubstraitRelNodeConverter extends AbstractRelVisitor tables) throws SqlParseException { CalciteCatalogReader catalogReader = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalogReader, parserConfig); + return SubstraitToCalciteVisitor.convert(relRoot, relOptCluster, catalogReader, parserConfig); } public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig); - } - - // DEFAULT_SQL_DIALECT uses Calcite's EMPTY_CONTEXT with setting: - // identifierQuoteString : null, identifierEscapeQuoteString : null - // quotedCasing : UNCHANGED, unquotedCasing : TO_UPPER - // caseSensitive: true - // supportsApproxCountDistinct is true - private static final SqlDialect DEFAULT_SQL_DIALECT = - new SqlDialect(SqlDialect.EMPTY_CONTEXT) { - @Override - public boolean supportsApproxCountDistinct() { - return true; - } - }; - - public static String toSql(RelNode root) { - return toSql(root, DEFAULT_SQL_DIALECT); - } - - public static String toSql(RelNode root, SqlDialect dialect) { - return toSql( - root, - dialect, - c -> - c.withAlwaysUseParentheses(false) - .withSelectListItemsOnSeparateLines(false) - .withUpdateSetListNewline(false) - .withIndentation(0)); - } - - private static String toSql( - RelNode root, SqlDialect dialect, UnaryOperator transform) { - final RelToSqlConverter converter = new RelToSqlConverter(dialect); - final SqlNode sqlNode = converter.visitRoot(root).asStatement(); - return sqlNode.toSqlString(c -> transform.apply(c.withDialect(dialect))).getSql(); + return SubstraitToCalciteVisitor.convert(relRoot, relOptCluster, catalog, parserConfig); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java new file mode 100644 index 000000000..34d75dc8a --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java @@ -0,0 +1,23 @@ +package io.substrait.isthmus.calcite; + +import java.util.Map; +import org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.AbstractSchema; + +/** + * Basic {@link AbstractSchema} implementation for associating table names to {@link Table} objects + */ +public class SubstraitSchema extends AbstractSchema { + + /** Maps of table names to their associated tables */ + protected final Map tableMap; + + public SubstraitSchema(Map tableMap) { + this.tableMap = tableMap; + } + + @Override + public Map getTableMap() { + return tableMap; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java new file mode 100644 index 000000000..f642c73d8 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java @@ -0,0 +1,26 @@ +package io.substrait.isthmus.calcite; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.impl.AbstractTable; + +/** Basic {@link AbstractTable} implementation */ +public class SubstraitTable extends AbstractTable { + + private final RelDataType rowType; + private final String tableName; + + public SubstraitTable(String tableName, RelDataType rowType) { + this.tableName = tableName; + this.rowType = rowType; + } + + public String getName() { + return tableName; + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return rowType; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index a62f7c0e7..2c766ed59 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -7,7 +7,7 @@ import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.AggregateFunctions; -import io.substrait.isthmus.SubstraitRelVisitor; +import io.substrait.isthmus.CalciteToSubstraitVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; import java.util.Collections; @@ -59,7 +59,7 @@ protected AggregateFunctionInvocation generateBinding( List sorts = agg.getCollation() != null ? agg.getCollation().getFieldCollations().stream() - .map(r -> SubstraitRelVisitor.toSortField(r, call.inputType)) + .map(r -> CalciteToSubstraitVisitor.toSortField(r, call.inputType)) .collect(java.util.stream.Collectors.toList()) : Collections.emptyList(); Expression.AggregationInvocation invocation = diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index f7dd76f6e..27cdffafd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -16,7 +16,7 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.WindowBound; import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.SubstraitRelNodeConverter; +import io.substrait.isthmus.SubstraitToCalciteVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.type.StringTypeVisitor; import io.substrait.type.Type; @@ -60,7 +60,7 @@ public class ExpressionRexConverter extends AbstractExpressionVisitor { org.slf4j.LoggerFactory.getLogger(RexExpressionConverter.class); private final List callConverters; - private final SubstraitRelVisitor relVisitor; + private final CalciteToSubstraitVisitor relVisitor; private final TypeConverter typeConverter; private WindowFunctionConverter windowFunctionConverter; - public RexExpressionConverter(SubstraitRelVisitor relVisitor, CallConverter... callConverters) { + public RexExpressionConverter( + CalciteToSubstraitVisitor relVisitor, CallConverter... callConverters) { this(relVisitor, Arrays.asList(callConverters), null, TypeConverter.DEFAULT); } @@ -49,7 +50,7 @@ public RexExpressionConverter(CallConverter... callConverters) { } public RexExpressionConverter( - SubstraitRelVisitor relVisitor, + CalciteToSubstraitVisitor relVisitor, List callConverters, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java new file mode 100644 index 000000000..0eb1cea89 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java @@ -0,0 +1,133 @@ +package io.substrait.isthmus.sql; + +import io.substrait.isthmus.SubstraitTypeSystem; +import io.substrait.isthmus.calcite.SubstraitTable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.ddl.SqlColumnDeclaration; +import org.apache.calcite.sql.ddl.SqlCreateTable; +import org.apache.calcite.sql.ddl.SqlKeyConstraint; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; +import org.apache.calcite.sql.validate.SqlConformanceEnum; +import org.apache.calcite.sql.validate.SqlValidator; + +/** Utility class for parsing CREATE statements into a {@link CalciteSchema} */ +public class SubstraitCreateStatementParser { + + private static final RelDataTypeFactory TYPE_FACTORY = + new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM); + + private static final CalciteConnectionConfig CONNECTION_CONFIG = + CalciteConnectionConfig.DEFAULT.set( + CalciteConnectionProperty.CASE_SENSITIVE, Boolean.FALSE.toString()); + + private static final SqlParser.Config PARSER_CONFIG = + SqlParser.config() + // To process CREATE statements we must use the SqlDdlParserImpl, as the default + // parser does not handle them + .withParserFactory(SqlDdlParserImpl.FACTORY) + .withUnquotedCasing(Casing.TO_UPPER) + .withConformance(SqlConformanceEnum.LENIENT); + + private static final CalciteCatalogReader EMPTY_CATALOG = + new CalciteCatalogReader( + CalciteSchema.createRootSchema(false), List.of(), TYPE_FACTORY, CONNECTION_CONFIG); + + // A validator is needed to convert the types in column declarations to Calcite types + private static final SqlValidator VALIDATOR = + new SubstraitSqlValidator( + // as we are validating CREATE statements, an empty catalog suffices + EMPTY_CATALOG); + + /** + * @param createStatements a SQL string containing only CREATE statements + * @return a list of {@link SubstraitTable}s generated from the CREATE statements + * @throws SqlParseException + */ + public static List processCreateStatements(String createStatements) + throws SqlParseException { + SqlParser parser = SqlParser.create(createStatements, PARSER_CONFIG); + List tableList = new ArrayList<>(); + + SqlNodeList sqlNode = parser.parseStmtList(); + for (SqlNode parsed : sqlNode) { + if (!(parsed instanceof SqlCreateTable create)) { + throw fail("Not a valid CREATE TABLE statement."); + } + + if (create.name.names.size() > 1) { + throw fail("Only simple table names are allowed.", create.name.getParserPosition()); + } + + if (create.query != null) { + throw fail("CTAS not supported.", create.name.getParserPosition()); + } + + List names = new ArrayList<>(); + List columnTypes = new ArrayList<>(); + + for (SqlNode node : create.columnList) { + if (!(node instanceof SqlColumnDeclaration col)) { + if (node instanceof SqlKeyConstraint) { + // key constraints declarations, like primary key declaration, are valid and should not + // result in parse exceptions. Ignore the constraint declaration. + continue; + } + + throw fail("Unexpected column list construction.", node.getParserPosition()); + } + + if (col.name.names.size() != 1) { + throw fail("Expected simple column names.", col.name.getParserPosition()); + } + + names.add(col.name.names.get(0)); + columnTypes.add(col.dataType.deriveType(VALIDATOR)); + } + + tableList.add( + new SubstraitTable( + create.name.names.get(0), TYPE_FACTORY.createStructType(columnTypes, names))); + } + + return tableList; + } + + /** + * @param createStatements a SQL string containing only CREATE statements + * @return a {@link CalciteCatalogReader} generated from the CREATE statements + * @throws SqlParseException + */ + public static CalciteCatalogReader processCreateStatementsToCatalog(String createStatements) + throws SqlParseException { + List tables = processCreateStatements(createStatements); + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + for (SubstraitTable table : tables) { + rootSchema.add(table.getName(), table); + } + List defaultSchema = Collections.emptyList(); + return new CalciteCatalogReader(rootSchema, defaultSchema, TYPE_FACTORY, CONNECTION_CONFIG); + } + + private static SqlParseException fail(String text, SqlParserPos pos) { + return new SqlParseException(text, pos, null, null, new RuntimeException("fake lineage")); + } + + private static SqlParseException fail(String text) { + return fail(text, SqlParserPos.ZERO); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java new file mode 100644 index 000000000..812c4170c --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java @@ -0,0 +1,26 @@ +package io.substrait.isthmus.sql; + +import java.util.List; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlConformanceEnum; + +/** Utility class for parsing SELECT statements to {@link org.apache.calcite.rel.RelRoot}s */ +public class SubstraitSelectStatementParser { + + private static final SqlParser.Config PARSER_CONFIG = + SqlParser.config() + // TODO: switch to Casing.UNCHANGED + .withUnquotedCasing(Casing.TO_UPPER) + // use LENIENT conformance to allow for parsing a wide variety of dialects + .withConformance(SqlConformanceEnum.LENIENT); + + /** Parse one or more SELECT statements */ + public static List parseSelectStatements(String selectStatements) + throws SqlParseException { + SqlParser parser = SqlParser.create(selectStatements, PARSER_CONFIG); + return parser.parseStmtList(); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java new file mode 100644 index 000000000..dd70db0cb --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java @@ -0,0 +1,39 @@ +package io.substrait.isthmus.sql; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.rel2sql.RelToSqlConverter; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.util.SqlString; + +/** + * {@link SqlDialect} used by Isthmus for parsing + * + *

Intended primarily for internal testing + */ +public class SubstraitSqlDialect extends SqlDialect { + + public static SqlDialect.Context DEFAULT_CONTEXT = SqlDialect.EMPTY_CONTEXT; + + public static SqlDialect DEFAULT = new SubstraitSqlDialect(DEFAULT_CONTEXT); + + public static SqlString toSql(RelNode relNode) { + RelToSqlConverter relToSql = new RelToSqlConverter(DEFAULT); + SqlNode sqlNode = relToSql.visitRoot(relNode).asStatement(); + return sqlNode.toSqlString( + c -> + c.withAlwaysUseParentheses(false) + .withSelectListItemsOnSeparateLines(false) + .withUpdateSetListNewline(false) + .withIndentation(0)); + } + + public SubstraitSqlDialect(Context context) { + super(context); + } + + @Override + public boolean supportsApproxCountDistinct() { + return true; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java new file mode 100644 index 000000000..280ff3d86 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -0,0 +1,79 @@ +package io.substrait.isthmus.sql; + +import io.substrait.isthmus.SubstraitTypeSystem; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; + +public class SubstraitSqlToCalcite { + + public static RelRoot convertSelect(String selectStatement, Prepare.CatalogReader catalogReader) + throws SqlParseException { + return convertSelect(selectStatement, catalogReader, createRelOptCluster()); + } + + public static RelRoot convertSelect( + String selectStatement, Prepare.CatalogReader catalogReader, RelOptCluster cluster) + throws SqlParseException { + List sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatement); + if (sqlNodes.size() != 1) { + throw new IllegalArgumentException( + String.format("Expected one SELECT statement, found: %d", sqlNodes.size())); + } + List relRoots = convert(sqlNodes, catalogReader, cluster); + // as there was only 1 select statement, there should only be 1 root + return relRoots.get(0); + } + + public static List convertSelects( + String selectStatements, Prepare.CatalogReader catalogReader) throws SqlParseException { + return convertSelects(selectStatements, catalogReader, createRelOptCluster()); + } + + public static List convertSelects( + String selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster) + throws SqlParseException { + List sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatements); + return convert(sqlNodes, catalogReader, cluster); + } + + static List convert( + List selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster) { + RelOptTable.ViewExpander viewExpander = null; + SqlToRelConverter converter = + new SqlToRelConverter( + viewExpander, + new SubstraitSqlValidator(catalogReader), + catalogReader, + cluster, + StandardConvertletTable.INSTANCE, + SqlToRelConverter.CONFIG); + // apply validation + boolean needsValidation = true; + // query is the root of the tree + boolean top = true; + return selectStatements.stream() + .map(sqlNode -> converter.convertQuery(sqlNode, needsValidation, top)) + .collect(Collectors.toList()); + } + + static RelOptCluster createRelOptCluster() { + RexBuilder rexBuilder = + new RexBuilder(new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM)); + HepProgram program = HepProgram.builder().build(); + RelOptPlanner emptyPlanner = new HepPlanner(program); + return RelOptCluster.create(emptyPlanner, rexBuilder); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java new file mode 100644 index 000000000..eddcb1d0f --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java @@ -0,0 +1,15 @@ +package io.substrait.isthmus.sql; + +import io.substrait.isthmus.calcite.SubstraitOperatorTable; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorImpl; + +public class SubstraitSqlValidator extends SqlValidatorImpl { + + static SqlValidator.Config CONFIG = Config.DEFAULT; + + public SubstraitSqlValidator(Prepare.CatalogReader catalogReader) { + super(SubstraitOperatorTable.INSTANCE, catalogReader, catalogReader.getTypeFactory(), CONFIG); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java index ae449105a..6dfd34153 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java @@ -1,28 +1,35 @@ package io.substrait.isthmus; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import java.util.List; import java.util.Map; import org.apache.calcite.adapter.tpcds.TpcdsSchema; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql2rel.SqlToRelConverter; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -public class ApplyJoinPlanTest { - - private static RelRoot getCalcitePlan(SqlToSubstrait s, TpcdsSchema schema, String sql) - throws SqlParseException { - CalciteCatalogReader catalogReader = s.registerSchema("tpcds", schema); - SqlConverterBase.Validator validator = - SqlConverterBase.Validator.create( - catalogReader.getTypeFactory(), catalogReader, SqlValidator.Config.DEFAULT); - SqlToRelConverter converter = s.createSqlToRelConverter(validator, catalogReader); - SqlParser parser = SqlParser.create(sql, s.parserConfig); - return s.getBestExpRelRoot(converter, parser.parseQuery()); +public class ApplyJoinPlanTest extends PlanTestBase { + static CalciteCatalogReader TPCDS_CATALOG; + + static { + TpcdsSchema tpcdsSchema = new TpcdsSchema(1.0); + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + rootSchema.add("tpcds", tpcdsSchema); + + TPCDS_CATALOG = + new CalciteCatalogReader( + rootSchema, + List.of("tpcds"), + new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM), + CalciteConnectionConfig.DEFAULT.set( + CalciteConnectionProperty.CASE_SENSITIVE, Boolean.FALSE.toString())); } private static void validateOuterRef( @@ -64,7 +71,7 @@ public void lateralJoinQuery() throws SqlParseException { */ // validate outer reference map - RelRoot root = getCalcitePlan(new SqlToSubstrait(), schema, sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(1, fieldAccessDepthMap.size()); validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1); @@ -79,26 +86,23 @@ public void lateralJoinQuery() throws SqlParseException { @Test public void outerApplyQuery() throws SqlParseException { - TpcdsSchema schema = new TpcdsSchema(1.0); String sql; sql = """ SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk FROM store_sales OUTER APPLY (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; - - FeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); - SqlToSubstrait s = new SqlToSubstrait(featureBoard); - RelRoot root = getCalcitePlan(s, schema, sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(1, fieldAccessDepthMap.size()); validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1); // TODO validate end to end conversion + SqlToSubstrait s = new SqlToSubstrait(); Assertions.assertThrows( UnsupportedOperationException.class, - () -> s.execute(sql, "tpcds", schema), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } @@ -129,9 +133,7 @@ public void nestedApplyJoinQuery() throws SqlParseException { LogicalFilter(condition=[AND(=($4, $cor0.I_ITEM_SK), =($4, $cor2.SS_ITEM_SK))]) LogicalTableScan(table=[[tpcds, PROMOTION]]) */ - FeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); - SqlToSubstrait s = new SqlToSubstrait(featureBoard); - RelRoot root = getCalcitePlan(s, schema, sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(3, fieldAccessDepthMap.size()); @@ -140,15 +142,15 @@ public void nestedApplyJoinQuery() throws SqlParseException { validateOuterRef(fieldAccessDepthMap, "$cor0", "I_ITEM_SK", 1); // TODO validate end to end conversion + SqlToSubstrait s = new SqlToSubstrait(); Assertions.assertThrows( UnsupportedOperationException.class, - () -> s.execute(sql, "tpcds", schema), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } @Test - public void crossApplyQuery() throws SqlParseException { - TpcdsSchema schema = new TpcdsSchema(1.0); + public void crossApplyQuery() { String sql; sql = """ @@ -156,13 +158,12 @@ public void crossApplyQuery() throws SqlParseException { FROM store_sales CROSS APPLY (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; - FeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); - SqlToSubstrait s = new SqlToSubstrait(featureBoard); + SqlToSubstrait s = new SqlToSubstrait(); // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> s.execute(sql, "tpcds", schema), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java index 823ef6a48..2ea75bac2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java @@ -1,14 +1,12 @@ package io.substrait.isthmus; -import java.util.List; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; public class ArithmeticFunctionTest extends PlanTestBase { - static List CREATES = - List.of( - "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"); + static String CREATES = + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"; @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 4772468ad..8d99c95cd 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -260,8 +260,8 @@ public CustomSubstraitToCalcite( } @Override - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter( + protected SubstraitToCalciteVisitor createSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitToCalciteVisitor( typeFactory, relBuilder, scalarFunctionConverter, @@ -275,8 +275,8 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r new CustomSubstraitToCalcite(extensionCollection, typeFactory, typeConverter); // Create a SubstraitRelVisitor that uses the custom Function Converters - final SubstraitRelVisitor calciteToSubstrait = - new SubstraitRelVisitor( + final CalciteToSubstraitVisitor calciteToSubstrait = + new CalciteToSubstraitVisitor( typeFactory, scalarFunctionConverter, aggregateFunctionConverter, diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index f8d6e9897..8147bc242 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -3,16 +3,21 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; import io.substrait.relation.NamedScan; import java.util.List; +import org.apache.calcite.prepare.CalciteCatalogReader; import org.junit.jupiter.api.Test; public class NameRoundtripTest extends PlanTestBase { @Test void preserveNamesFromSql() throws Exception { - List creates = List.of("CREATE TABLE foo(a BIGINT, b BIGINT)"); + String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)"; + CalciteCatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatement); SqlToSubstrait s = new SqlToSubstrait(); var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); @@ -22,14 +27,12 @@ void preserveNamesFromSql() throws Exception { """; List expectedNames = List.of("a", "B"); - List calciteRelRoots = s.sqlToRelNode(query, creates); - assertEquals(1, calciteRelRoots.size()); - - org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0); + org.apache.calcite.rel.RelRoot calciteRelRoot1 = + SubstraitSqlToCalcite.convertSelect(query, catalogReader); assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); io.substrait.plan.Plan.Root substraitRelRoot = - SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION); + CalciteToSubstraitVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION); assertEquals(expectedNames, substraitRelRoot.getNames()); org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index 1c3d923e2..4f7c45fd8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import com.google.protobuf.TextFormat; +import io.substrait.isthmus.calcite.SubstraitSchema; import io.substrait.plan.ProtoPlanConverter; import io.substrait.proto.Expression; import io.substrait.proto.Plan; @@ -13,7 +14,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.Table; -import org.apache.calcite.schema.impl.AbstractSchema; import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.type.SqlTypeName; @@ -56,14 +56,7 @@ RelDataType map(RelDataType key, RelDataType value) { private void test(Table table, String query, String expectedExpressionText) throws SqlParseException, IOException { - final Schema schema = - new AbstractSchema() { - @Override - protected Map getTableMap() { - return Map.of("my_table", table); - } - }; - + final Schema schema = new SubstraitSchema(Map.of("my_table", table)); final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(); Plan plan = sqlToSubstrait.execute(query, "nested", schema); Expression obtainedExpression = diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 8545bf21a..4d6f41c06 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -2,10 +2,9 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import java.io.IOException; -import java.util.List; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; @@ -24,11 +23,8 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE // verify that the query works generally assertFullRoundTrip(query); - SqlToSubstrait sqlConverter = new SqlToSubstrait(); - List relRoots = sqlConverter.sqlToRelNode(query, tpchSchemaCreateStatements()); - assertEquals(1, relRoots.size()); - RelRoot planRoot = relRoots.get(0); - RelNode originalPlan = planRoot.rel; + RelRoot relRoot = SubstraitSqlToCalcite.convertSelect(query, TPCH_CATALOG); + RelNode originalPlan = relRoot.rel; // Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule. // This will introduce a SqlSumEmptyIsZeroAggFunction to the plan. @@ -46,6 +42,7 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE assertDoesNotThrow( () -> // Conversion of the new plan should succeed - SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION)); + CalciteToSubstraitVisitor.convert( + RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION)); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 7c7770172..0aa1e55c0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -11,6 +11,8 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; @@ -22,6 +24,8 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.type.RelDataType; @@ -45,11 +49,16 @@ public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } - public static List tpchSchemaCreateStatements() throws IOException { - String[] values = asString("tpch/schema.sql").split(";"); - return Arrays.stream(values) - .filter(t -> !t.trim().isBlank()) - .collect(java.util.stream.Collectors.toList()); + protected static CalciteCatalogReader TPCH_CATALOG; + + static { + try { + String tpchCreateStatements = asString("tpch/schema.sql"); + TPCH_CATALOG = + SubstraitCreateStatementParser.processCreateStatementsToCatalog(tpchCreateStatements); + } catch (IOException | SqlParseException e) { + throw new RuntimeException(e); + } } protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlParseException { @@ -58,19 +67,20 @@ protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlPars protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s) throws IOException, SqlParseException { - return assertProtoPlanRoundrip(query, s, tpchSchemaCreateStatements()); + return assertProtoPlanRoundrip(query, s, TPCH_CATALOG); } - protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List creates) + protected Plan assertProtoPlanRoundrip( + String query, SqlToSubstrait s, Prepare.CatalogReader catalogReader) throws SqlParseException { - io.substrait.proto.Plan protoPlan1 = s.execute(query, creates); + io.substrait.proto.Plan protoPlan1 = s.execute(query, catalogReader); Plan plan = new ProtoPlanConverter(EXTENSION_COLLECTION).from(protoPlan1); io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan); assertEquals(protoPlan1, protoPlan2); - var rootRels = s.sqlToRelNode(query, creates); + var rootRels = SubstraitSqlToCalcite.convertSelects(query, catalogReader); assertEquals(rootRels.size(), plan.getRoots().size()); for (int i = 0; i < rootRels.size(); i++) { - Plan.Root rootRel = SubstraitRelVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION); + Plan.Root rootRel = CalciteToSubstraitVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION); assertEquals( rootRel.getInput().getRecordType(), plan.getRoots().get(i).getInput().getRecordType()); } @@ -85,11 +95,19 @@ protected void assertPlanRoundtrip(Plan plan) { } protected RelRoot assertSqlSubstraitRelRoundTrip(String query) throws Exception { - return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements()); + return assertSqlSubstraitRelRoundTrip(query, TPCH_CATALOG); } - protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List creates) + protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List createStatements) throws Exception { + CalciteCatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog( + String.join(";", createStatements)); + return assertSqlSubstraitRelRoundTrip(query, catalogReader); + } + + protected RelRoot assertSqlSubstraitRelRoundTrip( + String query, Prepare.CatalogReader catalogReader) throws Exception { // sql <--> substrait round trip test. // Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same. // Return list of sql -> Substrait rel -> Calcite rel. @@ -99,18 +117,16 @@ protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List crea SqlToSubstrait s = new SqlToSubstrait(); // 1. SQL -> Calcite RelRoot - List relRoots = s.sqlToRelNode(query, creates); - assertEquals(1, relRoots.size()); - RelRoot relRoot1 = relRoots.get(0); + RelRoot relRoot1 = SubstraitSqlToCalcite.convertSelect(query, catalogReader); // 2. Calcite RelRoot -> Substrait Rel - Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, EXTENSION_COLLECTION); + Plan.Root pojo1 = CalciteToSubstraitVisitor.convert(relRoot1, EXTENSION_COLLECTION); // 3. Substrait Rel -> Calcite RelNode RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Rel - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, EXTENSION_COLLECTION); + Plan.Root pojo2 = CalciteToSubstraitVisitor.convert(relRoot2, EXTENSION_COLLECTION); Assertions.assertEquals(pojo1, pojo2); return relRoot2; @@ -118,7 +134,15 @@ protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List crea @Beta protected void assertFullRoundTrip(String query) throws IOException, SqlParseException { - assertFullRoundTrip(query, tpchSchemaCreateStatements()); + assertFullRoundTrip(query, TPCH_CATALOG); + } + + @Beta + protected void assertFullRoundTrip(String query, String createStatements) + throws IOException, SqlParseException { + CalciteCatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatements); + assertFullRoundTrip(query, catalogReader); } /** @@ -134,18 +158,15 @@ protected void assertFullRoundTrip(String query) throws IOException, SqlParseExc *

  • Substrait POJO 2 == Substrait POJO 3 * */ - protected void assertFullRoundTrip(String sqlQuery, List createStatements) + protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlToSubstrait sqlConverter = new SqlToSubstrait(); ExtensionCollector extensionCollector = new ExtensionCollector(); // SQL -> Calcite 1 - List relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements); - assertEquals(1, relRoots.size()); - RelRoot calcite1 = relRoots.get(0); + RelRoot calcite1 = SubstraitSqlToCalcite.convertSelect(sqlQuery, catalogReader); // Calcite 1 -> Substrait POJO 1 - Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); + Plan.Root pojo1 = CalciteToSubstraitVisitor.convert(calcite1, EXTENSION_COLLECTION); // Substrait POJO 1 -> Substrait Proto io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1); @@ -163,7 +184,7 @@ protected void assertFullRoundTrip(String sqlQuery, List createStatement assertNotNull(calcite2); // Calcite 2 -> Substrait POJO 3 - Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite2, EXTENSION_COLLECTION); + Plan.Root pojo3 = CalciteToSubstraitVisitor.convert(calcite2, EXTENSION_COLLECTION); // Verify that POJOs are the same assertEquals(pojo1, pojo3); @@ -195,7 +216,8 @@ protected void assertFullRoundTrip(Rel pojo1) { RelNode calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION); + io.substrait.relation.Rel pojo3 = + CalciteToSubstraitVisitor.convert(calcite, EXTENSION_COLLECTION); // Verify that POJOs are the same assertEquals(pojo1, pojo3); @@ -226,7 +248,8 @@ protected void assertFullRoundTrip(Plan.Root pojo1) { RelRoot calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION); + io.substrait.plan.Plan.Root pojo3 = + CalciteToSubstraitVisitor.convert(calcite, EXTENSION_COLLECTION); // Verify that POJOs are the same assertEquals(pojo1, pojo3); diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index fe3eac106..6d222c9cf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -7,6 +7,7 @@ import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.sql.SubstraitSqlDialect; import io.substrait.plan.Plan; import io.substrait.plan.ProtoPlanConverter; import io.substrait.relation.Aggregate; @@ -139,7 +140,7 @@ public void approximateCountDistinct() throws IOException, SqlParseException { .filter(t -> !t.trim().isBlank()) .collect(java.util.stream.Collectors.toList()); RelNode relnodeRoot = new SubstraitToSql().substraitRelToCalciteRel(pojoRel, creates); - String newSql = SubstraitToSql.toSql(relnodeRoot); + String newSql = SubstraitSqlDialect.toSql(relnodeRoot).getSql(); assertTrue(newSql.toUpperCase().contains("APPROX_COUNT_DISTINCT")); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index c71f1539f..08ca00d1f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -68,10 +68,10 @@ void roundtrip(Rel pojo1) { // Substrait POJO 2 -> Calcite var calcite = - pojo2.accept(new CustomSubstraitRelNodeConverter(extensions, typeFactory, builder)); + pojo2.accept(new CustomSubstraitToCalciteVisitor(extensions, typeFactory, builder)); // Calcite -> Substrait POJO 3 - var pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); + var pojo3 = (new CustomCalciteToSubstraitVisitor(typeFactory, extensions)).apply(calcite); assertEquals(pojo1, pojo3); } @@ -180,12 +180,12 @@ protected Extension.MultiRelDetail detailFromExtensionMultiRel(Any any) { } /** - * Extends the standard {@link SubstraitRelNodeConverter} to handle Extension relations containing + * Extends the standard {@link SubstraitToCalciteVisitor} to handle Extension relations containing * {@link ColumnAppendDetail} */ - static class CustomSubstraitRelNodeConverter extends SubstraitRelNodeConverter { + static class CustomSubstraitToCalciteVisitor extends SubstraitToCalciteVisitor { - public CustomSubstraitRelNodeConverter( + public CustomSubstraitToCalciteVisitor( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, RelBuilder relBuilder) { @@ -232,10 +232,12 @@ public RelNode visit(ExtensionMulti extensionMulti) throws RuntimeException { } } - /** Extends the standard {@link SubstraitRelVisitor} to handle the {@link ColumnAppenderRel} */ - static class CustomSubstraitRelVisitor extends SubstraitRelVisitor { + /** + * Extends the standard {@link CalciteToSubstraitVisitor} to handle the {@link ColumnAppenderRel} + */ + static class CustomCalciteToSubstraitVisitor extends CalciteToSubstraitVisitor { - public CustomSubstraitRelVisitor( + public CustomCalciteToSubstraitVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { super(typeFactory, extensions); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java index 090fc98a8..bc055a5ee 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java @@ -37,8 +37,8 @@ public class SubstraitExpressionConverterTest extends PlanTestBase { final Rel commonTable = b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); - final SubstraitRelNodeConverter relNodeConverter = - new SubstraitRelNodeConverter(extensions, typeFactory, builder); + final SubstraitToCalciteVisitor relNodeConverter = + new SubstraitToCalciteVisitor(extensions, typeFactory, builder); public SubstraitExpressionConverterTest() { converter = relNodeConverter.expressionRexConverter; diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteVisitorTest.java similarity index 99% rename from isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java rename to isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteVisitorTest.java index 3351c8b7d..db5386503 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteVisitorTest.java @@ -15,7 +15,7 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -public class SubstraitRelNodeConverterTest extends PlanTestBase { +public class SubstraitToCalciteVisitorTest extends PlanTestBase { static final TypeCreator R = TypeCreator.of(false); static final TypeCreator N = TypeCreator.of(true);