diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/schema.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/schema.kt index 6b80262891..dd0d952ed0 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/schema.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/schema.kt @@ -22,13 +22,9 @@ internal fun MutableMap.putColumnsOrder(schema: DataFrameSchema val columnPath = path + name this[columnPath] = i when (column) { - is ColumnSchema.Frame -> { - putColumnsOrder(column.schema, columnPath) - } - - is ColumnSchema.Group -> { - putColumnsOrder(column.schema, columnPath) - } + is ColumnSchema.Frame -> putColumnsOrder(column.schema, columnPath) + is ColumnSchema.Group -> putColumnsOrder(column.schema, columnPath) + is ColumnSchema.Value -> Unit } } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt index 7abc80a6c3..4954f049c2 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt @@ -3,34 +3,54 @@ package org.jetbrains.kotlinx.dataframe.impl.schema import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import org.jetbrains.kotlinx.dataframe.schema.CompareResult +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.Equals +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsDerived +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsSuper +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.None +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS import 
org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema +import org.jetbrains.kotlinx.dataframe.schema.plus +import kotlin.collections.forEach public class DataFrameSchemaImpl(override val columns: Map) : DataFrameSchema { - override fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean): CompareResult { + override fun compare(other: DataFrameSchema, comparisonMode: ComparisonMode): CompareResult { require(other is DataFrameSchemaImpl) - if (this === other) return CompareResult.Equals - var result = CompareResult.Equals - columns.forEach { - val otherColumn = other.columns[it.key] - if (otherColumn == null) { - result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsDerived) - } else { - result = result.combine(it.value.compareStrictlyEqualNestedSchemas(otherColumn)) + if (this === other) return Equals + + var result: CompareResult = Equals + + // check for each column in this schema if there is a column with the same name in the other schema + // - if so, compare those schemas for equality, taking comparisonMode into account + // - if not, consider this schema derived from the other (or unrelated if comparisonMode == STRICT) + this.columns.forEach { (thisColName, thisSchema) -> + val otherSchema = other.columns[thisColName] + result += when { + otherSchema != null -> { + val comparison = thisSchema.compare( + other = otherSchema, + comparisonMode = if (comparisonMode == STRICT_FOR_NESTED_SCHEMAS) STRICT else comparisonMode, + ) + if (comparison != Equals && comparisonMode == STRICT) None else comparison + } + + else -> if (comparisonMode == STRICT) None else IsDerived } - if (result == CompareResult.None) return CompareResult.None + if (result == None) return None } - other.columns.forEach { - val thisField = columns[it.key] - if (thisField == null) { - result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsSuper) - if (result == CompareResult.None) return CompareResult.None 
- } + // then check for each column in the other schema if there is a column with the same name in this schema + // if not, consider this schema as super to the other (or unrelated if comparisonMode == STRICT) + other.columns.forEach { (otherColName, _) -> + if (this.columns[otherColName] != null) return@forEach + result += if (comparisonMode == STRICT) None else IsSuper + if (result == None) return None } return result } - override fun equals(other: Any?): Boolean = other is DataFrameSchema && compare(other).isEqual() + override fun equals(other: Any?): Boolean = other is DataFrameSchema && this.compare(other).isEqual() override fun toString(): String = render() diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ColumnSchema.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ColumnSchema.kt index 08497d63d3..f400737fd3 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ColumnSchema.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ColumnSchema.kt @@ -5,12 +5,15 @@ import org.jetbrains.kotlinx.dataframe.AnyRow import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.columns.ColumnKind +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.LENIENT +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS import kotlin.reflect.KType import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.full.isSupertypeOf import kotlin.reflect.typeOf -public abstract class ColumnSchema { +public sealed class ColumnSchema { /** Either [Value] or [Group] or [Frame]. 
*/ public abstract val kind: ColumnKind @@ -55,10 +58,11 @@ public abstract class ColumnSchema { override val nullable: Boolean = false override val type: KType get() = typeOf() - public fun compare(other: Group): CompareResult = schema.compare(other.schema) - - internal fun compareStrictlyEqualNestedSchemas(other: Group): CompareResult = - schema.compare(other.schema, strictlyEqualNestedSchemas = true) + public fun compare(other: Group, comparisonMode: ComparisonMode = LENIENT): CompareResult = + schema.compare( + other = other.schema, + comparisonMode = comparisonMode, + ) } public class Frame( @@ -69,14 +73,11 @@ public abstract class ColumnSchema { public override val kind: ColumnKind = ColumnKind.Frame override val type: KType get() = typeOf() - public fun compare(other: Frame): CompareResult = - schema.compare(other.schema).combine(CompareResult.compareNullability(nullable, other.nullable)) - - internal fun compareStrictlyEqualNestedSchemas(other: Frame): CompareResult = + public fun compare(other: Frame, comparisonMode: ComparisonMode = LENIENT): CompareResult = schema.compare( - other.schema, - strictlyEqualNestedSchemas = true, - ).combine(CompareResult.compareNullability(nullable, other.nullable)) + other = other.schema, + comparisonMode = comparisonMode, + ) + CompareResult.compareNullability(thisIsNullable = nullable, otherIsNullable = other.nullable) } /** Checks equality just on kind, type, or schema. 
*/ @@ -88,37 +89,27 @@ public abstract class ColumnSchema { is Value -> type == (otherType as Value).type is Group -> schema == (otherType as Group).schema is Frame -> schema == (otherType as Frame).schema - else -> throw NotImplementedError() } } - public fun compare(other: ColumnSchema): CompareResult = compare(other, false) - - internal fun compareStrictlyEqualNestedSchemas(other: ColumnSchema): CompareResult = compare(other, true) - - private fun compare(other: ColumnSchema, strictlyEqualNestedSchemas: Boolean): CompareResult { + public fun compare(other: ColumnSchema, comparisonMode: ComparisonMode = LENIENT): CompareResult { if (kind != other.kind) return CompareResult.None if (this === other) return CompareResult.Equals return when (this) { is Value -> compare(other as Value) + is Group -> compare(other as Group, comparisonMode) + is Frame -> compare(other as Frame, comparisonMode) + } + } - is Group -> if (strictlyEqualNestedSchemas) { - compareStrictlyEqualNestedSchemas( - other as Group, - ) - } else { - compare(other as Group) - } - - is Frame -> if (strictlyEqualNestedSchemas) { - compareStrictlyEqualNestedSchemas( - other as Frame, - ) - } else { - compare(other as Frame) - } - - else -> throw NotImplementedError() + override fun hashCode(): Int { + var result = nullable.hashCode() + result = 31 * result + kind.hashCode() + result = 31 * result + when (this) { + is Value -> type.hashCode() + is Group -> schema.hashCode() + is Frame -> schema.hashCode() } + return result } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/CompareResult.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/CompareResult.kt index 6f9a63d592..f238618058 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/CompareResult.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/CompareResult.kt @@ -28,3 +28,5 @@ public enum class CompareResult { } } } + +public operator fun CompareResult.plus(other: 
CompareResult): CompareResult = this.combine(other) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ComparisonMode.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ComparisonMode.kt new file mode 100644 index 0000000000..69fcef445b --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ComparisonMode.kt @@ -0,0 +1,20 @@ +package org.jetbrains.kotlinx.dataframe.schema + +public enum class ComparisonMode { + /** + * In this mode, all [CompareResults][CompareResult] can occur. + * + * If this schema has columns the other has not, this is considered [CompareResult.IsDerived]. + * If the other schema has columns this has not, this is considered [CompareResult.IsSuper]. + */ + LENIENT, + + /** + * Columns must all be present in the other schema with the same name and type. + * [CompareResult.IsDerived] and [CompareResult.IsSuper] will result in [CompareResult.None] in this mode. + */ + STRICT, + + /** Works like [LENIENT] at the top-level, but turns to [STRICT] for nested schemas. 
*/ + STRICT_FOR_NESTED_SCHEMAS, +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema.kt index 4c706dbff8..5d8763f294 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema.kt @@ -1,12 +1,17 @@ package org.jetbrains.kotlinx.dataframe.schema +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS + public interface DataFrameSchema { public val columns: Map /** * By default generated markers for leafs aren't used as supertypes: @DataSchema(isOpen = false) - * strictlyEqualNestedSchemas = true takes this into account for internal codegen logic + * [ComparisonMode.STRICT_FOR_NESTED_SCHEMAS] takes this into account for internal codegen logic */ - public fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean = false): CompareResult + public fun compare( + other: DataFrameSchema, + comparisonMode: ComparisonMode = STRICT_FOR_NESTED_SCHEMAS, + ): CompareResult } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/MatchSchemeTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/MatchSchemeTests.kt index b50fb98fcc..598c0aa057 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/MatchSchemeTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/MatchSchemeTests.kt @@ -6,9 +6,19 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.add import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.columnOf +import org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.generateCode +import org.jetbrains.kotlinx.dataframe.api.schema import 
org.jetbrains.kotlinx.dataframe.impl.codeGen.ReplCodeGenerator import org.jetbrains.kotlinx.dataframe.io.readJsonStr +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.Equals +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsDerived +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsSuper +import org.jetbrains.kotlinx.dataframe.schema.CompareResult.None +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.LENIENT +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT +import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS import org.junit.Test class MatchSchemeTests { @@ -99,4 +109,91 @@ class MatchSchemeTests { val res = df.generateCode(false, true) println(res) } + + @Test + fun `simple data schema comparison`() { + val scheme1 = dataFrameOf( + "a" to columnOf(1, 2, 3, null), + "b" to columnOf(1.0, 2.0, 3.0, 4.0), + ).schema() + + val scheme2 = dataFrameOf( + "a" to columnOf(1, 2, 3, 4), + "b" to columnOf(1.0, 2.0, 3.0, 4.0), + ).schema() + + val scheme3 = dataFrameOf( + "c" to columnOf(1, 2, 3, 4), + ).schema() + + scheme1.compare(scheme1, LENIENT) shouldBe Equals + scheme2.compare(scheme2, LENIENT) shouldBe Equals + scheme1.compare(scheme2, LENIENT) shouldBe IsSuper + scheme2.compare(scheme1, LENIENT) shouldBe IsDerived + scheme1.compare(scheme3, LENIENT) shouldBe None + + scheme1.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals + scheme2.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals + scheme1.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsSuper + scheme2.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsDerived + scheme1.compare(scheme3, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + + scheme1.compare(scheme1, STRICT) shouldBe Equals + scheme2.compare(scheme2, STRICT) shouldBe Equals + scheme1.compare(scheme2, STRICT) shouldBe None + scheme2.compare(scheme1, STRICT) shouldBe None + } + + @Test + fun `nested data schema comparison`() { + 
val scheme1 = dataFrameOf( + "a" to columnOf( + "b" to columnOf(1.0, 2.0, 3.0, null), + ), + ).schema() + + val scheme2 = dataFrameOf( + "a" to columnOf( + "b" to columnOf(1.0, 2.0, 3.0, 4.0), + ), + ).schema() + + val scheme3 = dataFrameOf( + "c" to columnOf(1, 2, 3, 4), + ).schema() + + val scheme4 = dataFrameOf( + "a" to columnOf( + "b" to columnOf(1.0, 2.0, 3.0, null), + ), + "c" to columnOf(1, 2, 3, 4), + ).schema() + + scheme1.compare(scheme1, LENIENT) shouldBe Equals + scheme2.compare(scheme2, LENIENT) shouldBe Equals + scheme1.compare(scheme2, LENIENT) shouldBe IsSuper + scheme2.compare(scheme1, LENIENT) shouldBe IsDerived + scheme1.compare(scheme3, LENIENT) shouldBe None + + scheme1.compare(scheme4, LENIENT) shouldBe IsSuper + scheme4.compare(scheme1, LENIENT) shouldBe IsDerived + + scheme1.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals + scheme2.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals + scheme1.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + scheme2.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + scheme1.compare(scheme3, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + + scheme1.compare(scheme4, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsSuper + scheme4.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsDerived + scheme2.compare(scheme4, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + scheme4.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe None + + scheme1.compare(scheme1, STRICT) shouldBe Equals + scheme2.compare(scheme2, STRICT) shouldBe Equals + scheme1.compare(scheme2, STRICT) shouldBe None + scheme2.compare(scheme1, STRICT) shouldBe None + scheme1.compare(scheme3, STRICT) shouldBe None + scheme3.compare(scheme1, STRICT) shouldBe None + } } diff --git a/dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt b/dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt index cdd2c1e318..49c261effa 100644 --- 
a/dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt +++ b/dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt @@ -70,4 +70,24 @@ class CodeGenerationTests : DataFrameJupyterTest() { df1.leaf.c """.checkCompilation() } + + // Issue #1222 + @Test + fun `do not reuse marker with non-matching sub-schema`() { + @Language("kt") + val _1 = """ + val df1 = dataFrameOf("group" to columnOf("a" to columnOf(1, null, 3))) + val df2 = dataFrameOf("group" to columnOf("a" to columnOf(1, 2, 3))) + df1.group.a + df2.group.a + """.checkCompilation() + + @Language("kt") + val _2 = """ + val df1 = dataFrameOf("group" to columnOf("a" to columnOf(1, 2, 3))) + val df2 = dataFrameOf("group" to columnOf("a" to columnOf(1, null, 3))) + df1.group.a + df2.group.a + """.checkCompilation() + } }