Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 18 additions & 64 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
package io.substrait.isthmus;

import com.google.common.annotations.VisibleForTesting;
import io.substrait.isthmus.sql.SubstraitSqlValidator;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.plan.ImmutablePlan.Builder;
import io.substrait.plan.Plan.Version;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.proto.Plan;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

/** Take a SQL statement and a set of table definitions and return a substrait plan. */
public class SqlToSubstrait extends SqlConverterBase {
Expand All @@ -32,69 +20,35 @@ public SqlToSubstrait(FeatureBoard features) {
}

public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException {
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
return executeInner(sql, validator, catalogReader);
return executeInner(sql, catalogReader);
}

List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
throws SqlParseException {
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
return sqlToRelNode(sql, validator, catalogReader);
}

private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
private Plan executeInner(String sql, Prepare.CatalogReader catalogReader)
throws SqlParseException {
Builder builder = io.substrait.plan.Plan.builder();
builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());

// TODO: consider case in which one sql passes conversion while others don't
sqlToRelNode(sql, validator, catalogReader).stream()
SubstraitSqlToCalcite.convertSelects(sql, catalogReader).stream()
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
.forEach(root -> builder.addRoots(root));

PlanProtoConverter planToProto = new PlanProtoConverter();

return planToProto.toProto(builder.build());
}

private List<RelRoot> sqlToRelNode(
String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
SqlNodeList parsedList = parser.parseStmtList();
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
List<RelRoot> roots =
parsedList.stream()
.map(parsed -> getBestExpRelRoot(converter, parsed))
.collect(java.util.stream.Collectors.toList());
return roots;
}

@VisibleForTesting
SqlToRelConverter createSqlToRelConverter(
SqlValidator validator, Prepare.CatalogReader catalogReader) {
SqlToRelConverter converter =
new SqlToRelConverter(
null,
validator,
catalogReader,
relOptCluster,
StandardConvertletTable.INSTANCE,
converterConfig);
return converter;
}

@VisibleForTesting
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
RelRoot root = converter.convertQuery(parsed, true, true);
{
// RelBuilder seems to implicitly use the rule below,
// need to add to avoid discrepancies in assertFullRoundTrip
HepProgram program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build();
HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(root.rel);
root = root.withRel(hepPlanner.findBestExp());
}
return root;
}
// @VisibleForTesting
// static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
// RelRoot root = converter.convertQuery(parsed, true, true);
// {
// // RelBuilder seems to implicitly use the rule below,
// // need to add to avoid discrepancies in assertFullRoundTrip
// HepProgram program =
// HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build();
// HepPlanner hepPlanner = new HepPlanner(program);
// hepPlanner.setRoot(root.rel);
// root = root.withRel(hepPlanner.findBestExp());
// }
// return root;
// }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.substrait.isthmus.sql;

import java.util.List;
import org.apache.calcite.avatica.util.Casing;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlConformanceEnum;

/**
 * Utility class that turns SQL text containing SELECT statements into Calcite {@link SqlNode}s.
 *
 * <p>Only parsing happens here; validation and conversion to relational trees are left to the
 * caller.
 */
public class SubstraitSelectStatementParser {

  /** Parser settings shared by every parse performed through this class. */
  private static final SqlParser.Config PARSER_CONFIG =
      SqlParser.config()
          // use LENIENT conformance to allow for parsing a wide variety of dialects
          .withConformance(SqlConformanceEnum.LENIENT)
          // TODO: switch to Casing.UNCHANGED
          .withUnquotedCasing(Casing.TO_UPPER);

  /**
   * Parses a string holding one or more SQL statements.
   *
   * <p>Note that this method itself performs no check that the parsed statements are actually
   * SELECTs; it returns whatever the parser produces for the statement list.
   *
   * @param selectStatements the SQL text to parse
   * @return the parsed statements, as returned by the parser's statement-list parse
   * @throws SqlParseException if the parser cannot parse the given text
   */
  public static List<SqlNode> parseSelectStatements(String selectStatements)
      throws SqlParseException {
    return SqlParser.create(selectStatements, PARSER_CONFIG).parseStmtList();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package io.substrait.isthmus.sql;

import io.substrait.isthmus.SubstraitTypeSystem;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

public class SubstraitSqlToCalcite {

public static RelRoot convertSelect(String selectStatement, Prepare.CatalogReader catalogReader)
throws SqlParseException {
return convertSelect(selectStatement, catalogReader, createRelOptCluster());
}

public static RelRoot convertSelect(
String selectStatement, Prepare.CatalogReader catalogReader, RelOptCluster cluster)
throws SqlParseException {
List<SqlNode> sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatement);
if (sqlNodes.size() != 1) {
throw new IllegalArgumentException(
String.format("Expected one SELECT statement, found: %d", sqlNodes.size()));
}
List<RelRoot> relRoots = convert(sqlNodes, catalogReader, cluster);
// as there was only 1 select statement, there should only be 1 root
return relRoots.get(0);
}

public static List<RelRoot> convertSelects(
String selectStatements, Prepare.CatalogReader catalogReader) throws SqlParseException {
return convertSelects(selectStatements, catalogReader, createRelOptCluster());
}

public static List<RelRoot> convertSelects(
String selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster)
throws SqlParseException {
List<SqlNode> sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatements);
return convert(sqlNodes, catalogReader, cluster);
}

static List<RelRoot> convert(
List<SqlNode> selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster) {
RelOptTable.ViewExpander viewExpander = null;
SqlToRelConverter converter =
new SqlToRelConverter(
viewExpander,
new SubstraitSqlValidator(catalogReader),
catalogReader,
cluster,
StandardConvertletTable.INSTANCE,
SqlToRelConverter.CONFIG);
// apply validation
boolean needsValidation = true;
// query is the root of the tree
boolean top = true;
return selectStatements.stream()
.map(
sqlNode -> removeUnnecessaryProjects(converter.convertQuery(sqlNode, needsValidation, top)))
.collect(Collectors.toList());
}

static RelOptCluster createRelOptCluster() {
RexBuilder rexBuilder =
new RexBuilder(new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM));
HepProgram program = HepProgram.builder().build();
RelOptPlanner emptyPlanner = new HepPlanner(program);
return RelOptCluster.create(emptyPlanner, rexBuilder);
}

static RelRoot removeUnnecessaryProjects(RelRoot root) {
return root.withRel(removeUnnecessaryProjects(root.rel));
}

static RelNode removeUnnecessaryProjects(RelNode root) {
HepProgram program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build();
HepPlanner planner = new HepPlanner(program);
planner.setRoot(root);
return planner.findBestExp();
}
}
21 changes: 9 additions & 12 deletions isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -11,10 +12,7 @@

public class ApplyJoinPlanTest extends PlanTestBase {

private static RelRoot getCalcitePlan(String sql) throws SqlParseException {
SqlToSubstrait s = new SqlToSubstrait();
return s.sqlToRelNode(sql, TPCDS_CATALOG).get(0);
}
static SqlToSubstrait s = new SqlToSubstrait();

private static void validateOuterRef(
Map<RexFieldAccess, Integer> fieldAccessDepthMap, String refName, String colName, int depth) {
Expand Down Expand Up @@ -53,16 +51,15 @@ public void lateralJoinQuery() throws SqlParseException {
*/

// validate outer reference map
RelRoot root = getCalcitePlan(sql);
RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG);
Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(1, fieldAccessDepthMap.size());
validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1);

// TODO validate end to end conversion
SqlToSubstrait sE2E = new SqlToSubstrait();
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> sE2E.execute(sql, TPCDS_CATALOG),
() -> s.execute(sql, TPCDS_CATALOG),
"Lateral join is not supported");
}

Expand All @@ -74,7 +71,7 @@ public void outerApplyQuery() throws SqlParseException {
+ "FROM store_sales OUTER APPLY\n"
+ " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)";

RelRoot root = getCalcitePlan(sql);
RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG);

Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(1, fieldAccessDepthMap.size());
Expand All @@ -83,7 +80,7 @@ public void outerApplyQuery() throws SqlParseException {
// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG),
() -> s.execute(sql, TPCDS_CATALOG),
"APPLY is not supported");
}

Expand Down Expand Up @@ -112,7 +109,7 @@ public void nestedApplyJoinQuery() throws SqlParseException {
LogicalFilter(condition=[AND(=($4, $cor0.I_ITEM_SK), =($4, $cor2.SS_ITEM_SK))])
LogicalTableScan(table=[[tpcds, PROMOTION]])
*/
RelRoot root = getCalcitePlan(sql);
RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG);

Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(3, fieldAccessDepthMap.size());
Expand All @@ -123,7 +120,7 @@ public void nestedApplyJoinQuery() throws SqlParseException {
// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG),
() -> s.execute(sql, TPCDS_CATALOG),
"APPLY is not supported");
}

Expand All @@ -138,7 +135,7 @@ public void crossApplyQuery() throws SqlParseException {
// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG),
() -> s.execute(sql, TPCDS_CATALOG),
"APPLY is not supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.plan.Plan;
import io.substrait.relation.NamedScan;
import java.util.List;
Expand All @@ -25,10 +26,8 @@ void preserveNamesFromSql() throws Exception {
String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b";
List<String> expectedNames = List.of("a", "B");

List<org.apache.calcite.rel.RelRoot> calciteRelRoots = s.sqlToRelNode(query, catalogReader);
assertEquals(1, calciteRelRoots.size());

org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0);
org.apache.calcite.rel.RelRoot calciteRelRoot1 =
SubstraitSqlToCalcite.convertSelect(query, catalogReader);
assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames());

io.substrait.plan.Plan.Root substraitRelRoot =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
Expand All @@ -24,11 +23,8 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE
// verify that the query works generally
assertFullRoundTrip(query);

SqlToSubstrait sqlConverter = new SqlToSubstrait();
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(query, TPCH_CATALOG);
assertEquals(1, relRoots.size());
RelRoot planRoot = relRoots.get(0);
RelNode originalPlan = planRoot.rel;
RelRoot relRoot = SubstraitSqlToCalcite.convertSelect(query, TPCH_CATALOG);
RelNode originalPlan = relRoot.rel;

// Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule.
// This will introduce a SqlSumEmptyIsZeroAggFunction to the plan.
Expand All @@ -46,6 +42,6 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE
assertDoesNotThrow(
() ->
// Conversion of the new plan should succeed
SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION));
SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION));
}
}
Loading
Loading