Skip to content

refactor: modularize and re-use conversion APIs #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@

@SuppressWarnings("UnstableApiUsage")
@Value.Enclosing
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
public class CalciteToSubstraitVisitor extends RelNodeVisitor<Rel, RuntimeException> {

static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(SubstraitRelVisitor.class);
org.slf4j.LoggerFactory.getLogger(CalciteToSubstraitVisitor.class);
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

Expand All @@ -57,12 +57,12 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
protected final FeatureBoard featureBoard;
private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
this(typeFactory, extensions, FEATURES_DEFAULT);
}

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory,
SimpleExtension.ExtensionCollection extensions,
FeatureBoard features) {
Expand All @@ -80,7 +80,7 @@ public SubstraitRelVisitor(
this.featureBoard = features;
}

public SubstraitRelVisitor(
public CalciteToSubstraitVisitor(
RelDataTypeFactory typeFactory,
ScalarFunctionConverter scalarFunctionConverter,
AggregateFunctionConverter aggregateFunctionConverter,
Expand Down Expand Up @@ -386,8 +386,9 @@ public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollec

public static Plan.Root convert(
RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features);
CalciteToSubstraitVisitor visitor =
new CalciteToSubstraitVisitor(
relRoot.rel.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relRoot.rel);
Rel rel = visitor.apply(relRoot.project());

Expand All @@ -403,8 +404,8 @@ public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection e

public static Rel convert(
RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
CalciteToSubstraitVisitor visitor =
new CalciteToSubstraitVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relNode);
return visitor.apply(relNode);
}
Expand Down
125 changes: 5 additions & 120 deletions isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.substrait.isthmus;

import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import java.util.ArrayList;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import java.util.List;
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.config.CalciteConnectionProperty;
Expand All @@ -16,25 +16,13 @@
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
import org.apache.calcite.rel.metadata.ProxyingMetadataHandlerProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.impl.AbstractTable;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.ddl.SqlColumnDeclaration;
import org.apache.calcite.sql.ddl.SqlCreateTable;
import org.apache.calcite.sql.ddl.SqlKeyConstraint;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorCatalogReader;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.sql2rel.SqlToRelConverter;

class SqlConverterBase {
Expand Down Expand Up @@ -76,11 +64,11 @@ CalciteCatalogReader registerCreateTables(List<String> tables) throws SqlParseEx
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
CalciteCatalogReader catalogReader =
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
for (DefinedTable t : tList) {
List<SubstraitTable> tList =
SubstraitCreateStatementParser.processCreateStatements(tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
}
}
Expand All @@ -96,107 +84,4 @@ CalciteCatalogReader registerSchema(String name, Schema schema) {
}
return new CalciteCatalogReader(rootSchema, List.of(), factory, config);
}

/**
 * Parses the statements in {@code sql} and converts each {@code CREATE TABLE} statement into a
 * {@link DefinedTable}.
 *
 * <p>Key constraint declarations in the column list (such as a primary key declaration) are
 * accepted and ignored. Every other non-column entry is rejected.
 *
 * @param factory type factory used to build each table's struct row type
 * @param validator validator used to derive the type of each declared column
 * @param sql text holding the statements to parse
 * @return one {@link DefinedTable} per parsed statement, in statement order
 * @throws SqlParseException if parsing fails, a statement is not a {@code CREATE TABLE}, the
 *     table or a column name is not a simple name, the statement is a CTAS, the statement has no
 *     column list, or the column list holds an unexpected element
 */
protected List<DefinedTable> parseCreateTable(
    RelDataTypeFactory factory, SqlValidator validator, String sql) throws SqlParseException {
  SqlParser parser = SqlParser.create(sql, parserConfig);
  List<DefinedTable> definedTableList = new ArrayList<>();

  SqlNodeList nodeList = parser.parseStmtList();
  for (SqlNode parsed : nodeList) {
    if (!(parsed instanceof SqlCreateTable)) {
      throw fail("Not a valid CREATE TABLE statement.");
    }

    SqlCreateTable create = (SqlCreateTable) parsed;
    if (create.name.names.size() > 1) {
      throw fail("Only simple table names are allowed.", create.name.getParserPosition());
    }

    if (create.query != null) {
      throw fail("CTAS not supported.", create.name.getParserPosition());
    }

    // The column list is optional in the grammar; without this guard a statement such as
    // "CREATE TABLE t" would hit a NullPointerException in the loop below.
    if (create.columnList == null) {
      throw fail("Expected a column list.", create.name.getParserPosition());
    }

    List<String> names = new ArrayList<>();
    List<RelDataType> columnTypes = new ArrayList<>();

    for (SqlNode node : create.columnList) {
      if (!(node instanceof SqlColumnDeclaration)) {
        if (node instanceof SqlKeyConstraint) {
          // key constraints declarations, like primary key declaration, are valid and should
          // not result in parse exceptions. Ignore the constraint declaration.
          continue;
        }

        throw fail("Unexpected column list construction.", node.getParserPosition());
      }

      SqlColumnDeclaration col = (SqlColumnDeclaration) node;
      if (col.name.names.size() != 1) {
        throw fail("Expected simple column names.", col.name.getParserPosition());
      }

      names.add(col.name.names.get(0));
      columnTypes.add(col.dataType.deriveType(validator));
    }

    definedTableList.add(
        new DefinedTable(
            create.name.names.get(0), factory, factory.createStructType(columnTypes, names)));
  }

  return definedTableList;
}

/**
 * Creates a {@link SqlParseException} with the given message and parser position.
 *
 * <p>The third and fourth constructor arguments are passed as {@code null}, and a placeholder
 * {@link RuntimeException} is supplied as the cause.
 *
 * @param text message of the exception
 * @param pos parser position to attach to the exception
 * @return the newly created exception (it is returned, not thrown)
 */
protected static SqlParseException fail(String text, SqlParserPos pos) {
  final Throwable placeholderCause = new RuntimeException("fake lineage");
  return new SqlParseException(text, pos, null, null, placeholderCause);
}

/**
 * Creates a {@link SqlParseException} with the given message, positioned at
 * {@link SqlParserPos#ZERO}.
 *
 * @param text message of the exception
 * @return the newly created exception (it is returned, not thrown)
 */
protected static SqlParseException fail(String text) {
  final SqlParserPos noPosition = SqlParserPos.ZERO;
  return fail(text, noPosition);
}

/**
 * {@link SqlValidatorImpl} subclass that can only be instantiated through {@link #create}, which
 * always uses {@code SubstraitOperatorTable.INSTANCE} as the operator table.
 */
protected static final class Validator extends SqlValidatorImpl {

  /**
   * Creates a validator bound to {@code SubstraitOperatorTable.INSTANCE}.
   *
   * @param factory type factory handed to the underlying validator
   * @param validatorCatalog catalog reader handed to the underlying validator
   * @param config validator configuration
   * @return a new {@link Validator}
   */
  public static Validator create(
      RelDataTypeFactory factory,
      SqlValidatorCatalogReader validatorCatalog,
      SqlValidator.Config config) {
    final SqlOperatorTable operators = SubstraitOperatorTable.INSTANCE;
    return new Validator(operators, validatorCatalog, factory, config);
  }

  /** Forwards all arguments unchanged to the {@link SqlValidatorImpl} constructor. */
  private Validator(
      SqlOperatorTable operatorTable,
      SqlValidatorCatalogReader reader,
      RelDataTypeFactory types,
      Config validatorConfig) {
    super(operatorTable, reader, types, validatorConfig);
  }
}

/**
 * A table whose name and row type are fixed when it is constructed.
 *
 * <p>The row type handed to the constructor is returned for every {@link #getRowType} call.
 */
protected static final class DefinedTable extends AbstractTable {

  private final String name;
  private final RelDataTypeFactory factory;
  private final RelDataType type;

  /**
   * @param name table name reported by {@link #getName()}
   * @param factory type factory associated with {@code type}; stored but not consulted
   * @param type row type of the table
   */
  public DefinedTable(String name, RelDataTypeFactory factory, RelDataType type) {
    this.type = type;
    this.factory = factory;
    this.name = name;
  }

  /** Returns the table name given at construction. */
  public String getName() {
    return name;
  }

  /**
   * Returns the row type given at construction.
   *
   * @param typeFactory ignored; it is deliberately not compared with the factory supplied to the
   *     constructor, so a different factory does not cause a failure
   */
  @Override
  public RelDataType getRowType(RelDataTypeFactory typeFactory) {
    return type;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import io.substrait.extendedexpression.ImmutableExpressionReference;
import io.substrait.extendedexpression.ImmutableExtendedExpression;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import io.substrait.isthmus.sql.SubstraitSqlValidator;
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -142,11 +145,11 @@ private Result registerCreateTablesForExtendedExpression(List<String> tables)
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
CalciteCatalogReader catalogReader =
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
for (DefinedTable t : tList) {
List<SubstraitTable> tList =
SubstraitCreateStatementParser.processCreateStatements(tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) {
nameToTypeMap.merge( // to validate the sql expression tree
Expand All @@ -167,6 +170,7 @@ private Result registerCreateTablesForExtendedExpression(List<String> tables)
}
}
}
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap);
}

Expand Down
90 changes: 18 additions & 72 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
package io.substrait.isthmus;

import com.google.common.annotations.VisibleForTesting;
import io.substrait.extension.ExtensionCollector;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.proto.Plan;
import io.substrait.proto.PlanRel;
import io.substrait.relation.RelProtoConverter;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

/** Take a SQL statement and a set of table definitions and return a substrait plan. */
public class SqlToSubstrait extends SqlConverterBase {
Expand All @@ -32,84 +25,37 @@ public SqlToSubstrait(FeatureBoard features) {

public Plan execute(String sql, List<String> tables) throws SqlParseException {
CalciteCatalogReader catalogReader = registerCreateTables(tables);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

public Plan execute(String sql, String name, Schema schema) throws SqlParseException {
CalciteCatalogReader catalogReader = registerSchema(name, schema);
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException {
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

/**
 * Registers the given table definitions and converts {@code sql} to relational roots, using a
 * {@link Validator} created with {@code SqlValidator.Config.DEFAULT}.
 *
 * <p>Package-private so that tests can call it.
 */
List<RelRoot> sqlToRelNode(String sql, List<String> tables) throws SqlParseException {
  final Prepare.CatalogReader reader = registerCreateTables(tables);
  return sqlToRelNode(
      sql, Validator.create(factory, reader, SqlValidator.Config.DEFAULT), reader);
}

private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
private Plan executeInner(String sql, Prepare.CatalogReader catalogReader)
throws SqlParseException {
var plan = Plan.newBuilder();
ExtensionCollector functionCollector = new ExtensionCollector();
var relProtoConverter = new RelProtoConverter(functionCollector);

List<RelRoot> relRoots = SubstraitSqlToCalcite.convertSelects(sql, catalogReader);

// TODO: handle the case in which one SQL statement converts successfully while others fail
sqlToRelNode(sql, validator, catalogReader)
.forEach(
root -> {
plan.addRelations(
PlanRel.newBuilder()
.setRoot(
relProtoConverter.toProto(
SubstraitRelVisitor.convert(
root, EXTENSION_COLLECTION, featureBoard))));
});
Plan.Builder plan = Plan.newBuilder();
relRoots.forEach(
root -> {
plan.addRelations(
PlanRel.newBuilder()
.setRoot(
relProtoConverter.toProto(
CalciteToSubstraitVisitor.convert(
root, EXTENSION_COLLECTION, featureBoard))));
});
functionCollector.addExtensionsToPlan(plan);
return plan.build();
}

/**
 * Parses {@code sql} as a statement list and converts every statement to a {@link RelRoot} via
 * {@link #getBestExpRelRoot}, preserving statement order.
 *
 * @param sql text holding one or more statements
 * @param validator validator used by the SQL-to-rel converter
 * @param catalogReader catalog reader used by the SQL-to-rel converter
 * @return one {@link RelRoot} per parsed statement
 * @throws SqlParseException declared because the statement list is parsed here
 */
private List<RelRoot> sqlToRelNode(
    String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
    throws SqlParseException {
  var statements = SqlParser.create(sql, parserConfig).parseStmtList();
  // The converter is created only after parsing succeeded, and is shared by all statements.
  SqlToRelConverter sqlToRel = createSqlToRelConverter(validator, catalogReader);
  List<RelRoot> result = new java.util.ArrayList<>();
  for (SqlNode statement : statements) {
    result.add(getBestExpRelRoot(sqlToRel, statement));
  }
  return result;
}

/**
 * Creates a {@link SqlToRelConverter} from the given validator and catalog reader, together with
 * this instance's {@code relOptCluster} and {@code converterConfig} and the standard convertlet
 * table. The first constructor argument is passed as {@code null}.
 */
@VisibleForTesting
SqlToRelConverter createSqlToRelConverter(
    SqlValidator validator, Prepare.CatalogReader catalogReader) {
  return new SqlToRelConverter(
      null,
      validator,
      catalogReader,
      relOptCluster,
      StandardConvertletTable.INSTANCE,
      converterConfig);
}

/**
 * Converts a parsed statement to a {@link RelRoot} and replaces its relational tree with the
 * result of a {@link HepPlanner} run over a program built without adding any rules.
 *
 * @param converter converter used to turn the statement into a relational root
 * @param parsed the parsed statement
 * @return the converted root carrying the planner's output tree
 */
@VisibleForTesting
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
  final RelRoot converted = converter.convertQuery(parsed, true, true);
  final HepPlanner planner = new HepPlanner(HepProgram.builder().build());
  planner.setRoot(converted.rel);
  return converted.withRel(planner.findBestExp());
}
}
Loading
Loading