Skip to content

Reusage schemas fix #1252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ internal fun MutableMap<ColumnPath, Int>.putColumnsOrder(schema: DataFrameSchema
val columnPath = path + name
this[columnPath] = i
when (column) {
is ColumnSchema.Frame -> {
putColumnsOrder(column.schema, columnPath)
}

is ColumnSchema.Group -> {
putColumnsOrder(column.schema, columnPath)
}
is ColumnSchema.Frame -> putColumnsOrder(column.schema, columnPath)
is ColumnSchema.Group -> putColumnsOrder(column.schema, columnPath)
is ColumnSchema.Value -> Unit
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,54 @@ package org.jetbrains.kotlinx.dataframe.impl.schema
import org.jetbrains.kotlinx.dataframe.impl.renderType
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.CompareResult
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.Equals
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsDerived
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsSuper
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.None
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import org.jetbrains.kotlinx.dataframe.schema.plus
import kotlin.collections.forEach

public class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {

override fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean): CompareResult {
override fun compare(other: DataFrameSchema, comparisonMode: ComparisonMode): CompareResult {
require(other is DataFrameSchemaImpl)
if (this === other) return CompareResult.Equals
var result = CompareResult.Equals
columns.forEach {
val otherColumn = other.columns[it.key]
if (otherColumn == null) {
result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsDerived)
} else {
result = result.combine(it.value.compareStrictlyEqualNestedSchemas(otherColumn))
if (this === other) return Equals

var result: CompareResult = Equals

// check for each column in this schema if there is a column with the same name in the other schema
// - if so, those schemas for equality, taking comparisonMode into account
// - if not, consider the other schema derived from this (or unrelated if comparisonMode == STRICT)
this.columns.forEach { (thisColName, thisSchema) ->
val otherSchema = other.columns[thisColName]
result += when {
otherSchema != null -> {
val comparison = thisSchema.compare(
other = otherSchema,
comparisonMode = if (comparisonMode == STRICT_FOR_NESTED_SCHEMAS) STRICT else comparisonMode,
)
if (comparison != Equals && comparisonMode == STRICT) None else comparison
}

else -> if (comparisonMode == STRICT) None else IsDerived
}
if (result == CompareResult.None) return CompareResult.None
if (result == None) return None
}
other.columns.forEach {
val thisField = columns[it.key]
if (thisField == null) {
result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsSuper)
if (result == CompareResult.None) return CompareResult.None
}
// then check for each column in the other schema if there is a column with the same name in this schema
// if not, consider the other schema as super to this (or unrelated if comparisonMode == STRICT)
other.columns.forEach { (otherColName, _) ->
if (this.columns[otherColName] != null) return@forEach
result += if (comparisonMode == STRICT) None else IsSuper
if (result == None) return None
}
return result
}

override fun equals(other: Any?): Boolean = other is DataFrameSchema && compare(other).isEqual()
override fun equals(other: Any?): Boolean = other is DataFrameSchema && this.compare(other).isEqual()

override fun toString(): String = render()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.LENIENT
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS
import kotlin.reflect.KType
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.full.isSupertypeOf
import kotlin.reflect.typeOf

public abstract class ColumnSchema {
public sealed class ColumnSchema {

/** Either [Value] or [Group] or [Frame]. */
public abstract val kind: ColumnKind
Expand Down Expand Up @@ -55,10 +58,11 @@ public abstract class ColumnSchema {
override val nullable: Boolean = false
override val type: KType get() = typeOf<AnyRow>()

public fun compare(other: Group): CompareResult = schema.compare(other.schema)

internal fun compareStrictlyEqualNestedSchemas(other: Group): CompareResult =
schema.compare(other.schema, strictlyEqualNestedSchemas = true)
public fun compare(other: Group, comparisonMode: ComparisonMode = LENIENT): CompareResult =
schema.compare(
other = other.schema,
comparisonMode = comparisonMode,
)
}

public class Frame(
Expand All @@ -69,14 +73,11 @@ public abstract class ColumnSchema {
public override val kind: ColumnKind = ColumnKind.Frame
override val type: KType get() = typeOf<AnyFrame>()

public fun compare(other: Frame): CompareResult =
schema.compare(other.schema).combine(CompareResult.compareNullability(nullable, other.nullable))

internal fun compareStrictlyEqualNestedSchemas(other: Frame): CompareResult =
public fun compare(other: Frame, comparisonMode: ComparisonMode = LENIENT): CompareResult =
schema.compare(
other.schema,
strictlyEqualNestedSchemas = true,
).combine(CompareResult.compareNullability(nullable, other.nullable))
other = other.schema,
comparisonMode = comparisonMode,
) + CompareResult.compareNullability(thisIsNullable = nullable, otherIsNullable = other.nullable)
}

/** Checks equality just on kind, type, or schema. */
Expand All @@ -88,37 +89,27 @@ public abstract class ColumnSchema {
is Value -> type == (otherType as Value).type
is Group -> schema == (otherType as Group).schema
is Frame -> schema == (otherType as Frame).schema
else -> throw NotImplementedError()
}
}

public fun compare(other: ColumnSchema): CompareResult = compare(other, false)

internal fun compareStrictlyEqualNestedSchemas(other: ColumnSchema): CompareResult = compare(other, true)

private fun compare(other: ColumnSchema, strictlyEqualNestedSchemas: Boolean): CompareResult {
public fun compare(other: ColumnSchema, comparisonMode: ComparisonMode = LENIENT): CompareResult {
if (kind != other.kind) return CompareResult.None
if (this === other) return CompareResult.Equals
return when (this) {
is Value -> compare(other as Value)
is Group -> compare(other as Group, comparisonMode)
is Frame -> compare(other as Frame, comparisonMode)
}
}

is Group -> if (strictlyEqualNestedSchemas) {
compareStrictlyEqualNestedSchemas(
other as Group,
)
} else {
compare(other as Group)
}

is Frame -> if (strictlyEqualNestedSchemas) {
compareStrictlyEqualNestedSchemas(
other as Frame,
)
} else {
compare(other as Frame)
}

else -> throw NotImplementedError()
override fun hashCode(): Int {
var result = nullable.hashCode()
result = 31 * result + kind.hashCode()
result = 31 * result + when (this) {
is Value -> type.hashCode()
is Group -> schema.hashCode()
is Frame -> schema.hashCode()
}
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ public enum class CompareResult {
}
}
}

public operator fun CompareResult.plus(other: CompareResult): CompareResult = this.combine(other)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.jetbrains.kotlinx.dataframe.schema

public enum class ComparisonMode {
/**
* In this mode, all [CompareResults][CompareResult] can occur.
*
* If this schema has columns the other has not, the other is considered [CompareResult.IsDerived].
* If the other schema has columns this has not, this is considered [CompareResult.IsSuper].
*/
LENIENT,

/**
* Columns must all be present in the other schema with the same name and type.
* [CompareResult.IsDerived] and [CompareResult.IsSuper] will result in [CompareResult.None] in this mode.
*/
STRICT,

/** Works like [LENIENT] at the top-level, but turns to [STRICT] for nested schemas. */
STRICT_FOR_NESTED_SCHEMAS,
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package org.jetbrains.kotlinx.dataframe.schema

import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS

public interface DataFrameSchema {

public val columns: Map<String, ColumnSchema>

/**
* By default generated markers for leafs aren't used as supertypes: @DataSchema(isOpen = false)
* strictlyEqualNestedSchemas = true takes this into account for internal codegen logic
* [ComparisonMode.STRICT_FOR_NESTED_SCHEMAS] takes this into account for internal codegen logic
*/
public fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean = false): CompareResult
public fun compare(
other: DataFrameSchema,
comparisonMode: ComparisonMode = STRICT_FOR_NESTED_SCHEMAS,
): CompareResult
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@ import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.columnOf
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.generateCode
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.impl.codeGen.ReplCodeGenerator
import org.jetbrains.kotlinx.dataframe.io.readJsonStr
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.Equals
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsDerived
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.IsSuper
import org.jetbrains.kotlinx.dataframe.schema.CompareResult.None
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.LENIENT
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT
import org.jetbrains.kotlinx.dataframe.schema.ComparisonMode.STRICT_FOR_NESTED_SCHEMAS
import org.junit.Test

class MatchSchemeTests {
Expand Down Expand Up @@ -99,4 +109,91 @@ class MatchSchemeTests {
val res = df.generateCode(false, true)
println(res)
}

@Test
fun `simple data schema comparison`() {
val scheme1 = dataFrameOf(
"a" to columnOf(1, 2, 3, null),
"b" to columnOf(1.0, 2.0, 3.0, 4.0),
).schema()

val scheme2 = dataFrameOf(
"a" to columnOf(1, 2, 3, 4),
"b" to columnOf(1.0, 2.0, 3.0, 4.0),
).schema()

val scheme3 = dataFrameOf(
"c" to columnOf(1, 2, 3, 4),
).schema()

scheme1.compare(scheme1, LENIENT) shouldBe Equals
scheme2.compare(scheme2, LENIENT) shouldBe Equals
scheme1.compare(scheme2, LENIENT) shouldBe IsSuper
scheme2.compare(scheme1, LENIENT) shouldBe IsDerived
scheme1.compare(scheme3, LENIENT) shouldBe None

scheme1.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals
scheme2.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals
scheme1.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsSuper
scheme2.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsDerived
scheme1.compare(scheme3, STRICT_FOR_NESTED_SCHEMAS) shouldBe None

scheme1.compare(scheme1, STRICT) shouldBe Equals
scheme2.compare(scheme2, STRICT) shouldBe Equals
scheme1.compare(scheme2, STRICT) shouldBe None
scheme2.compare(scheme1, STRICT) shouldBe None
}

@Test
fun `nested data schema comparison`() {
val scheme1 = dataFrameOf(
"a" to columnOf(
"b" to columnOf(1.0, 2.0, 3.0, null),
),
).schema()

val scheme2 = dataFrameOf(
"a" to columnOf(
"b" to columnOf(1.0, 2.0, 3.0, 4.0),
),
).schema()

val scheme3 = dataFrameOf(
"c" to columnOf(1, 2, 3, 4),
).schema()

val scheme4 = dataFrameOf(
"a" to columnOf(
"b" to columnOf(1.0, 2.0, 3.0, null),
),
"c" to columnOf(1, 2, 3, 4),
).schema()

scheme1.compare(scheme1, LENIENT) shouldBe Equals
scheme2.compare(scheme2, LENIENT) shouldBe Equals
scheme1.compare(scheme2, LENIENT) shouldBe IsSuper
scheme2.compare(scheme1, LENIENT) shouldBe IsDerived
scheme1.compare(scheme3, LENIENT) shouldBe None

scheme1.compare(scheme4, LENIENT) shouldBe IsSuper
scheme4.compare(scheme1, LENIENT) shouldBe IsDerived

scheme1.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals
scheme2.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe Equals
scheme1.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe None
scheme2.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe None
scheme1.compare(scheme3, STRICT_FOR_NESTED_SCHEMAS) shouldBe None

scheme1.compare(scheme4, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsSuper
scheme4.compare(scheme1, STRICT_FOR_NESTED_SCHEMAS) shouldBe IsDerived
scheme2.compare(scheme4, STRICT_FOR_NESTED_SCHEMAS) shouldBe None
scheme4.compare(scheme2, STRICT_FOR_NESTED_SCHEMAS) shouldBe None

scheme1.compare(scheme1, STRICT) shouldBe Equals
scheme2.compare(scheme2, STRICT) shouldBe Equals
scheme1.compare(scheme2, STRICT) shouldBe None
scheme2.compare(scheme1, STRICT) shouldBe None
scheme1.compare(scheme3, STRICT) shouldBe None
scheme3.compare(scheme1, STRICT) shouldBe None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,24 @@ class CodeGenerationTests : DataFrameJupyterTest() {
df1.leaf.c
""".checkCompilation()
}

// Issue #1222
@Test
fun `do not reuse marker with non-matching sub-schema`() {
@Language("kt")
val _1 = """
val df1 = dataFrameOf("group" to columnOf("a" to columnOf(1, null, 3)))
val df2 = dataFrameOf("group" to columnOf("a" to columnOf(1, 2, 3)))
df1.group.a
df2.group.a
""".checkCompilation()

@Language("kt")
val _2 = """
val df1 = dataFrameOf("group" to columnOf("a" to columnOf(1, 2, 3)))
val df2 = dataFrameOf("group" to columnOf("a" to columnOf(1, null, 3)))
df1.group.a
df2.group.a
""".checkCompilation()
}
}