Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package io.substrait.isthmus;

import io.substrait.isthmus.calcite.SubstraitSchema;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Rel;
import io.substrait.relation.RelCopyOnWriteVisitor;
import io.substrait.type.NamedStruct;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;

/** For use in generating {@link CalciteSchema}s from Substrait {@link Rel}s */
public class SchemaCollector {

private static final boolean CASE_SENSITIVE = false;

private final RelDataTypeFactory typeFactory;
private final TypeConverter typeConverter;

public SchemaCollector(RelDataTypeFactory typeFactory, TypeConverter typeConverter) {
this.typeFactory = typeFactory;
this.typeConverter = typeConverter;
}

public CalciteSchema toSchema(Rel rel) {
// Create the root schema under which all tables and schemas will be nested.
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false);

for (Map.Entry<List<String>, NamedStruct> entry : TableGatherer.gatherTables(rel).entrySet()) {
List<String> names = entry.getKey();
NamedStruct namedStruct = entry.getValue();

// The last name in names is the table name. All others are schema names.
String tableName = names.get(names.size() - 1);

// Traverse all schemas, creating them if they are not present
CalciteSchema schema = rootSchema;
for (String schemaName : names.subList(0, names.size() - 1)) {
CalciteSchema subSchema = schema.getSubSchema(schemaName, CASE_SENSITIVE);
if (subSchema != null) {
schema = subSchema;
} else {
SubstraitSchema newSubSchema = new SubstraitSchema();
schema = schema.add(schemaName, newSubSchema);
}
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing LookupCalciteSchema code didn't handle nested schemas. I've added support for it here.


// Create the table if it is not present
CalciteSchema.TableEntry table = schema.getTable(tableName, CASE_SENSITIVE);
if (table == null) {
RelDataType rowType =
typeConverter.toCalcite(typeFactory, namedStruct.struct(), namedStruct.names());
schema.add(tableName, new SubstraitTable(tableName, rowType));
}
}

return rootSchema;
}

static class TableGatherer extends RelCopyOnWriteVisitor<RuntimeException> {
Map<List<String>, NamedStruct> tableMap;

private TableGatherer() {
super();
this.tableMap = new HashMap<>();
}

/**
* Gathers all tables defined in {@link NamedScan}s under the given {@link Rel}
*
* @param rootRel under which to search for {@link NamedScan}s
* @return a map of qualified table names to their associated Substrait schemas
*/
public static Map<List<String>, NamedStruct> gatherTables(Rel rootRel) {
var visitor = new TableGatherer();
rootRel.accept(visitor);
return visitor.tableMap;
}

@Override
public Optional<Rel> visit(NamedScan namedScan) {
super.visit(namedScan);

List<String> tableName = namedScan.getNames();
if (tableMap.containsKey(tableName)) {
NamedStruct existingSchema = tableMap.get(tableName);
if (!existingSchema.equals(namedScan.getInitialSchema())) {
throw new IllegalArgumentException(
String.format(
"NamedScan for %s is present multiple times with different schemas", tableName));
Copy link
Member Author

@vbarua vbarua Apr 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's helpful to guard against this because it can result in some very weird errors. The same table can appear multiple times in a plan (i.e. through joins), and if an external producer applies field trimming and those tables are trimmed to have different fields, unpleasantness can ensue.

}
}
tableMap.put(tableName, namedScan.getInitialSchema());

return Optional.empty();
}
}
}
44 changes: 9 additions & 35 deletions isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
import io.substrait.isthmus.calcite.SubstraitTable;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.config.CalciteConnectionConfig;
Expand All @@ -20,7 +21,6 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.impl.AbstractTable;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperatorTable;
Expand Down Expand Up @@ -79,8 +79,8 @@ CalciteCatalogReader registerCreateTables(List<String> tables) throws SqlParseEx
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
for (DefinedTable t : tList) {
List<SubstraitTable> tList = parseCreateTable(factory, validator, tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
}
}
Expand All @@ -97,10 +97,10 @@ CalciteCatalogReader registerSchema(String name, Schema schema) {
return new CalciteCatalogReader(rootSchema, List.of(), factory, config);
}

protected List<DefinedTable> parseCreateTable(
protected List<SubstraitTable> parseCreateTable(
RelDataTypeFactory factory, SqlValidator validator, String sql) throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
List<DefinedTable> definedTableList = new ArrayList<>();
List<SubstraitTable> tableList = new ArrayList<>();

SqlNodeList nodeList = parser.parseStmtList();
for (SqlNode parsed : nodeList) {
Expand Down Expand Up @@ -140,12 +140,12 @@ protected List<DefinedTable> parseCreateTable(
columnTypes.add(col.dataType.deriveType(validator));
}

definedTableList.add(
new DefinedTable(
create.name.names.get(0), factory, factory.createStructType(columnTypes, names)));
tableList.add(
new SubstraitTable(
create.name.names.get(0), factory.createStructType(columnTypes, names)));
}

return definedTableList;
return tableList;
}

protected static SqlParseException fail(String text, SqlParserPos pos) {
Expand Down Expand Up @@ -173,30 +173,4 @@ public static Validator create(
return new Validator(SubstraitOperatorTable.INSTANCE, validatorCatalog, factory, config);
}
}

/** A fully defined pre-specified table. */
protected static final class DefinedTable extends AbstractTable {

private final String name;
private final RelDataTypeFactory factory;
private final RelDataType type;

public DefinedTable(String name, RelDataTypeFactory factory, RelDataType type) {
this.name = name;
this.factory = factory;
this.type = type;
}

@Override
public RelDataType getRowType(RelDataTypeFactory typeFactory) {
// if (factory != typeFactory) {
// throw new IllegalStateException("Different type factory than previously used.");
// }
return type;
}

public String getName() {
return name;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.extendedexpression.ImmutableExpressionReference;
import io.substrait.extendedexpression.ImmutableExtendedExpression;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.proto.ExtendedExpression;
Expand Down Expand Up @@ -145,8 +146,8 @@ private Result registerCreateTablesForExtendedExpression(List<String> tables)
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
if (tables != null) {
for (String tableDef : tables) {
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
for (DefinedTable t : tList) {
List<SubstraitTable> tList = parseCreateTable(factory, validator, tableDef);
for (SubstraitTable t : tList) {
rootSchema.add(t.getName(), t);
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) {
nameToTypeMap.merge( // to validate the sql expression tree
Expand Down
18 changes: 2 additions & 16 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.LookupCalciteSchema;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.Table;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.RelBuilder;
Expand Down Expand Up @@ -57,19 +54,8 @@ public SubstraitToCalcite(
* <p>Override this method to customize schema extraction.
*/
protected CalciteSchema toSchema(Rel rel) {
Map<List<String>, NamedStruct> tableMap = NamedStructGatherer.gatherTables(rel);
Function<List<String>, Table> lookup =
id -> {
NamedStruct table = tableMap.get(id);
if (table == null) {
return null;
}
return new SqlConverterBase.DefinedTable(
id.get(id.size() - 1),
typeFactory,
typeConverter.toCalcite(typeFactory, table.struct(), table.names()));
};
return LookupCalciteSchema.createRootSchema(lookup);
SchemaCollector schemaCollector = new SchemaCollector(typeFactory, typeConverter);
return schemaCollector.toSchema(rel);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.substrait.isthmus.calcite;

import java.util.HashMap;
import java.util.Map;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;

/** Basic {@link AbstractSchema} implementation */
public class SubstraitSchema extends AbstractSchema {

/** Map of table names to their associated tables */
protected final Map<String, Table> tableMap;

/** Map of schema names to their associated schemas */
protected final Map<String, Schema> schemaMap;

public SubstraitSchema() {
this.tableMap = new HashMap<>();
this.schemaMap = new HashMap<>();
}

public SubstraitSchema(Map<String, Table> tableMap) {
this.tableMap = tableMap;
this.schemaMap = new HashMap<>();
}

@Override
public Map<String, Table> getTableMap() {
return tableMap;
}

@Override
protected Map<String, Schema> getSubSchemaMap() {
return schemaMap;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.substrait.isthmus.calcite;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.impl.AbstractTable;

/** Basic {@link AbstractTable} implementation */
public class SubstraitTable extends AbstractTable {

private final RelDataType rowType;
private final String tableName;

public SubstraitTable(String tableName, RelDataType rowType) {
this.tableName = tableName;
this.rowType = rowType;
}

public String getName() {
return tableName;
}

@Override
public RelDataType getRowType(RelDataTypeFactory typeFactory) {
return rowType;
}
}

This file was deleted.

Loading
Loading