Skip to content

refactor: modularize and re-use conversion APIs #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@

@SuppressWarnings("UnstableApiUsage")
@Value.Enclosing
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
public class CalciteToSubstraitVisitor extends RelNodeVisitor<Rel, RuntimeException> {

static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(SubstraitRelVisitor.class);
org.slf4j.LoggerFactory.getLogger(CalciteToSubstraitVisitor.class);
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

Expand All @@ -57,12 +57,12 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
protected final FeatureBoard featureBoard;
private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
this(typeFactory, extensions, FEATURES_DEFAULT);
}

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory,
SimpleExtension.ExtensionCollection extensions,
FeatureBoard features) {
Expand All @@ -80,7 +80,7 @@ public SubstraitRelVisitor(
this.featureBoard = features;
}

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory,
ScalarFunctionConverter scalarFunctionConverter,
AggregateFunctionConverter aggregateFunctionConverter,
Expand Down Expand Up @@ -386,8 +386,9 @@ public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollec

public static Plan.Root convert(
RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features);
CalciteToSubstraitVisitor visitor =
new CalciteToSubstraitVisitor(
relRoot.rel.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relRoot.rel);
Rel rel = visitor.apply(relRoot.project());

Expand All @@ -403,8 +404,8 @@ public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection e

public static Rel convert(
RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
CalciteToSubstraitVisitor visitor =
new CalciteToSubstraitVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relNode);
return visitor.apply(relNode);
}
Expand Down
95 changes: 3 additions & 92 deletions isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package io.substrait.isthmus;

import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import io.substrait.isthmus.calcite.SubstraitTable;
import java.util.ArrayList;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import java.util.List;
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.config.CalciteConnectionProperty;
Expand All @@ -17,24 +16,13 @@
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
import org.apache.calcite.rel.metadata.ProxyingMetadataHandlerProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.ddl.SqlColumnDeclaration;
import org.apache.calcite.sql.ddl.SqlCreateTable;
import org.apache.calcite.sql.ddl.SqlKeyConstraint;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorCatalogReader;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.sql2rel.SqlToRelConverter;

class SqlConverterBase {
Expand Down Expand Up @@ -76,10 +64,10 @@ CalciteCatalogReader registerCreateTables(List<String> tables) throws SqlParseEx
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
CalciteCatalogReader catalogReader =
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<SubstraitTable> tList = parseCreateTable(factory, validator, tableDef);
List<SubstraitTable> tList =
SubstraitCreateStatementParser.processCreateStatements(tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
}
Expand All @@ -96,81 +84,4 @@ CalciteCatalogReader registerSchema(String name, Schema schema) {
}
return new CalciteCatalogReader(rootSchema, List.of(), factory, config);
}

protected List<SubstraitTable> parseCreateTable(
RelDataTypeFactory factory, SqlValidator validator, String sql) throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
List<SubstraitTable> tableList = new ArrayList<>();

SqlNodeList nodeList = parser.parseStmtList();
for (SqlNode parsed : nodeList) {
if (!(parsed instanceof SqlCreateTable)) {
throw fail("Not a valid CREATE TABLE statement.");
}

SqlCreateTable create = (SqlCreateTable) parsed;
if (create.name.names.size() > 1) {
throw fail("Only simple table names are allowed.", create.name.getParserPosition());
}

if (create.query != null) {
throw fail("CTAS not supported.", create.name.getParserPosition());
}

List<String> names = new ArrayList<>();
List<RelDataType> columnTypes = new ArrayList<>();

for (SqlNode node : create.columnList) {
if (!(node instanceof SqlColumnDeclaration)) {
if (node instanceof SqlKeyConstraint) {
// key constraints declarations, like primary key declaration, are valid and should not
// result in parse exceptions. Ignore the constraint declaration.
continue;
}

throw fail("Unexpected column list construction.", node.getParserPosition());
}

SqlColumnDeclaration col = (SqlColumnDeclaration) node;
if (col.name.names.size() != 1) {
throw fail("Expected simple column names.", col.name.getParserPosition());
}

names.add(col.name.names.get(0));
columnTypes.add(col.dataType.deriveType(validator));
}

tableList.add(
new SubstraitTable(
create.name.names.get(0), factory.createStructType(columnTypes, names)));
}

return tableList;
}

protected static SqlParseException fail(String text, SqlParserPos pos) {
return new SqlParseException(text, pos, null, null, new RuntimeException("fake lineage"));
}

protected static SqlParseException fail(String text) {
return fail(text, SqlParserPos.ZERO);
}

protected static final class Validator extends SqlValidatorImpl {

private Validator(
SqlOperatorTable opTab,
SqlValidatorCatalogReader catalogReader,
RelDataTypeFactory typeFactory,
Config config) {
super(opTab, catalogReader, typeFactory, config);
}

public static Validator create(
RelDataTypeFactory factory,
SqlValidatorCatalogReader validatorCatalog,
SqlValidator.Config config) {
return new Validator(SubstraitOperatorTable.INSTANCE, validatorCatalog, factory, config);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import io.substrait.isthmus.sql.SubstraitSqlValidator;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
Expand Down Expand Up @@ -140,10 +142,10 @@ private Result registerCreateTablesForExtendedExpression(List<String> tables)
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
CalciteCatalogReader catalogReader =
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<SubstraitTable> tList = parseCreateTable(factory, validator, tableDef);
List<SubstraitTable> tList =
SubstraitCreateStatementParser.processCreateStatements(tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) {
Expand All @@ -165,6 +167,7 @@ private Result registerCreateTablesForExtendedExpression(List<String> tables)
}
}
}
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap);
}

Expand Down
72 changes: 7 additions & 65 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
package io.substrait.isthmus;

import com.google.common.annotations.VisibleForTesting;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.plan.Plan.Version;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.proto.Plan;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

/** Take a SQL statement and a set of table definitions and return a substrait plan. */
public class SqlToSubstrait extends SqlConverterBase {
Expand All @@ -31,79 +23,29 @@ public SqlToSubstrait(FeatureBoard features) {

public Plan execute(String sql, List<String> tables) throws SqlParseException {
CalciteCatalogReader catalogReader = registerCreateTables(tables);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

public Plan execute(String sql, String name, Schema schema) throws SqlParseException {
CalciteCatalogReader catalogReader = registerSchema(name, schema);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException {
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

// Package protected for testing
List<RelRoot> sqlToRelNode(String sql, List<String> tables) throws SqlParseException {
Prepare.CatalogReader catalogReader = registerCreateTables(tables);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return sqlToRelNode(sql, validator, catalogReader);
}

private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
private Plan executeInner(String sql, Prepare.CatalogReader catalogReader)
throws SqlParseException {
var builder = io.substrait.plan.Plan.builder();
builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());

// TODO: consider case in which one sql passes conversion while others don't
sqlToRelNode(sql, validator, catalogReader).stream()
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
SubstraitSqlToCalcite.convertSelects(sql, catalogReader).stream()
.map(root -> CalciteToSubstraitVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
.forEach(root -> builder.addRoots(root));

PlanProtoConverter planToProto = new PlanProtoConverter();

return planToProto.toProto(builder.build());
}

private List<RelRoot> sqlToRelNode(
String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
var parsedList = parser.parseStmtList();
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
List<RelRoot> roots =
parsedList.stream()
.map(parsed -> getBestExpRelRoot(converter, parsed))
.collect(java.util.stream.Collectors.toList());
return roots;
}

@VisibleForTesting
SqlToRelConverter createSqlToRelConverter(
SqlValidator validator, Prepare.CatalogReader catalogReader) {
SqlToRelConverter converter =
new SqlToRelConverter(
null,
validator,
catalogReader,
relOptCluster,
StandardConvertletTable.INSTANCE,
converterConfig);
return converter;
}

@VisibleForTesting
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
RelRoot root = converter.convertQuery(parsed, true, true);
{
var program = HepProgram.builder().build();
HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(root.rel);
root = root.withRel(hepPlanner.findBestExp());
}
return root;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
/**
* Converts between Substrait {@link Rel}s and Calcite {@link RelNode}s.
*
* <p>Can be extended to customize the {@link RelBuilder} and {@link SubstraitRelNodeConverter} used
* <p>Can be extended to customize the {@link RelBuilder} and {@link SubstraitToCalciteVisitor} used
* in the conversion.
*/
public class SubstraitToCalcite {
Expand Down Expand Up @@ -68,28 +68,28 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) {
}

/**
* Creates a {@link SubstraitRelNodeConverter} from the {@link RelBuilder}
* Creates a {@link SubstraitToCalciteVisitor} from the {@link RelBuilder}
*
* <p>Override this method to customize the {@link SubstraitRelNodeConverter}.
* <p>Override this method to customize the {@link SubstraitToCalciteVisitor}.
*/
protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) {
return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder);
protected SubstraitToCalciteVisitor createSubstraitRelNodeConverter(RelBuilder relBuilder) {
return new SubstraitToCalciteVisitor(extensions, typeFactory, relBuilder);
}

/**
* Converts a Substrait {@link Rel} to a Calcite {@link RelNode}
*
* <p>Generates a {@link CalciteSchema} based on the contents of the {@link Rel}, which will be
* used to construct a {@link RelBuilder} with the required schema information to build {@link
* RelNode}s, and a then a {@link SubstraitRelNodeConverter} to perform the actual conversion.
* RelNode}s, and a then a {@link SubstraitToCalciteVisitor} to perform the actual conversion.
*
* @param rel {@link Rel} to convert
* @return {@link RelNode}
*/
public RelNode convert(Rel rel) {
CalciteSchema rootSchema = toSchema(rel);
RelBuilder relBuilder = createRelBuilder(rootSchema);
SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder);
SubstraitToCalciteVisitor converter = createSubstraitRelNodeConverter(relBuilder);
return rel.accept(converter);
}

Expand Down
Loading
Loading