Skip to content

fix singleton issue (backproperty initialization in parent class) #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion app/src/main/kotlin/com/stslex/compiler_app/MainActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class MainActivity : ComponentActivity() {
sendToastOfUserChanges(user)
setName(user.name)
setSecondName(user.secondName)
printUsernameWithSingletonDistinct(user.name)
}

@DistinctUntilChangeFun(false)
@DistinctUntilChangeFun(true)
private fun setName(name: String) {
logger.log(Level.INFO, "setName: $name")
findViewById<TextView>(R.id.usernameFieldTextView).text = name
Expand All @@ -65,3 +66,11 @@ class MainActivity : ComponentActivity() {
findViewById<TextView>(R.id.secondNameFieldTextView).text = name
}
}

@DistinctUntilChangeFun(
logging = true,
singletonAllow = true
)
private fun printUsernameWithSingletonDistinct(name: String) {
println("printUsernameWithSingletonDistinct: $name")
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ package io.github.stslex.compiler_plugin

import io.github.stslex.compiler_plugin.utils.RuntimeLogger

internal object DistinctChangeCache {
internal class DistinctChangeCache(
private val config: DistinctChangeConfig
) {

private val cache = mutableMapOf<String, Pair<List<Any?>, Any?>>()
private val logger = RuntimeLogger.tag("DistinctChangeLogger")

@JvmStatic
@Suppress("UNCHECKED_CAST")
fun <R> invoke(
internal operator fun <R> invoke(
key: String,
args: List<Any?>,
body: () -> R,
config: DistinctChangeConfig
): R {
val entry = cache[key]

Expand All @@ -22,6 +22,9 @@ internal object DistinctChangeCache {
}

if (entry != null && entry.first == args) {
if (config.logging) {
logger.i("$key not change")
}
return entry.second as R
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package io.github.stslex.compiler_plugin

/**
* @param logging enable logs for Kotlin Compiler Runtime work (useful for debug - don't use in production)
* @param singletonAllow if enable - generates distinction for function without classes (so it's singleton)
* */
@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.BINARY)
public annotation class DistinctUntilChangeFun(
val logging: Boolean = LOGGING_DEFAULT
val logging: Boolean = LOGGING_DEFAULT,
val singletonAllow: Boolean = false
) {

public companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@ import io.github.stslex.compiler_plugin.utils.CompileLogger.Companion.toCompiler
import io.github.stslex.compiler_plugin.utils.buildLambdaForBody
import io.github.stslex.compiler_plugin.utils.buildSaveInCacheCall
import io.github.stslex.compiler_plugin.utils.fullyQualifiedName
import io.github.stslex.compiler_plugin.utils.generateFields
import io.github.stslex.compiler_plugin.utils.getQualifierValue
import io.github.stslex.compiler_plugin.utils.readQualifier
import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.jvm.ir.fileParentOrNull
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.createExpressionBody
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.util.parentClassOrNull

internal class IrFunctionTransformer(
private val pluginContext: IrPluginContext,
Expand All @@ -25,6 +29,13 @@ internal class IrFunctionTransformer(
val qualifierArgs = pluginContext.readQualifier(declaration, logger)
?: return super.visitSimpleFunction(declaration)

if (
declaration.getQualifierValue("singletonAllow").not() &&
declaration.parentClassOrNull == null
) {
error("singleton is not allowed for ${declaration.name} in ${declaration.fileParentOrNull}")
}

val originalBody = declaration.body ?: return super.visitSimpleFunction(declaration)

logger.i("fullyQualifiedName: ${declaration.fullyQualifiedName}")
Expand All @@ -37,12 +48,16 @@ internal class IrFunctionTransformer(

val argsListExpr = pluginContext.buildArgsListExpression(declaration)
val lambdaExpr = pluginContext.buildLambdaForBody(originalBody, declaration)

val backingField = pluginContext.generateFields(declaration, qualifierArgs, logger)

logger.i("backingField = $backingField")
val memoizeCall = pluginContext.buildSaveInCacheCall(
keyLiteral = keyLiteral,
argsListExpr = argsListExpr,
lambdaExpr = lambdaExpr,
function = declaration,
qualifierArgs = qualifierArgs,
backingField = backingField,
logger = logger
)

Expand All @@ -52,4 +67,3 @@ internal class IrFunctionTransformer(
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,37 @@ package io.github.stslex.compiler_plugin.utils
import io.github.stslex.compiler_plugin.DistinctChangeCache
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
import org.jetbrains.kotlin.backend.jvm.ir.fileParentOrNull
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrConstructor
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrPackageFragment
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.createExpressionBody
import org.jetbrains.kotlin.ir.expressions.IrBody
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionExpressionImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetFieldImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.symbols.impl.IrFieldSymbolImpl
import org.jetbrains.kotlin.ir.types.defaultType
import org.jetbrains.kotlin.ir.types.typeWith
import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.ir.util.dump
import org.jetbrains.kotlin.ir.util.file
import org.jetbrains.kotlin.ir.util.kotlinFqName
import org.jetbrains.kotlin.ir.util.parentClassOrNull
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
Expand Down Expand Up @@ -88,39 +100,117 @@ internal val IrFunction.fullyQualifiedName: String
/**
* Create call for [DistinctChangeCache.invoke]
*/
@OptIn(UnsafeDuringIrConstructionAPI::class)
internal fun IrPluginContext.buildSaveInCacheCall(
keyLiteral: IrExpression,
argsListExpr: IrExpression,
lambdaExpr: IrExpression,
function: IrSimpleFunction,
qualifierArgs: IrExpression,
logger: CompileLogger
logger: CompileLogger,
backingField: IrFieldSymbolImpl
): IrExpression {
logger.i("buildSaveInCacheCall for ${function.name}, args: ${argsListExpr.dump()} with config: ${qualifierArgs.dump()}")
logger.i("buildSaveInCacheCall for ${function.name}, args: ${argsListExpr.dump()}")

val distinctChangeClassSymbol = referenceClass(DistinctChangeCache::class.toClassId())
?: error("Cannot find DistinctChangeCache")

val invokeFunSymbol = distinctChangeClassSymbol.owner.declarations
.filterIsInstance<IrSimpleFunction>()
.firstOrNull { it.name == Name.identifier("invoke") }
?: error("Cannot find DistinctChangeCache.invoke")

val memoizeFunction = referenceFunctions(
CallableId(
classId = DistinctChangeCache::class.toClassId(),
callableName = Name.identifier("invoke")
)
val getDistCacheField = IrGetFieldImpl(
startOffset = function.startOffset,
endOffset = function.endOffset,
symbol = backingField,
type = distinctChangeClassSymbol.owner.defaultType,
receiver = function.dispatchReceiverParameter?.let { thisReceiver ->
IrGetValueImpl(
startOffset = function.startOffset,
endOffset = function.endOffset,
symbol = thisReceiver.symbol,
type = thisReceiver.type
)
},
origin = null
)
.singleOrNull()
?: error("Cannot find function DistinctChangeCache.memorize")

return IrCallImpl(
startOffset = function.startOffset,
endOffset = function.endOffset,
type = function.returnType,
symbol = memoizeFunction,
symbol = invokeFunSymbol.symbol,
typeArgumentsCount = 1,
valueArgumentsCount = 4
valueArgumentsCount = 3,
origin = null
)
.also { call -> call.patchDeclarationParents(function) }
.also { it.patchDeclarationParents(function.parent) }
.apply {
dispatchReceiver = getDistCacheField

putTypeArgument(0, function.returnType)
putValueArgument(0, keyLiteral)
putValueArgument(1, argsListExpr)
putValueArgument(2, lambdaExpr)
putValueArgument(3, qualifierArgs)
}
}

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal fun IrPluginContext.generateFields(
function: IrSimpleFunction,
qualifierArgs: IrExpression,
logger: CompileLogger
): IrFieldSymbolImpl {
logger.i("generateFields for ${function.name} parent: ${function.file}")

val parentClass = function.parentClassOrNull
val parentFile = function.fileParentOrNull

val errorNotFound =
"function ${function.name} in ${function.file} couldn't be used with @DistinctUntilChangeFun"

if (parentClass == null && parentFile == null) error(errorNotFound)


val startOffset = parentClass?.startOffset ?: parentFile?.startOffset ?: error(errorNotFound)
val endOffset = parentClass?.endOffset ?: parentFile?.endOffset ?: error(errorNotFound)

val fieldSymbol = IrFieldSymbolImpl()

val distinctChangeClass = referenceClass(DistinctChangeCache::class.toClassId())
?: error("couldn't find DistinctChangeCache")

val backingField = irFactory.createField(
startOffset = startOffset,
endOffset = endOffset,
origin = IrDeclarationOrigin.PROPERTY_BACKING_FIELD,
symbol = fieldSymbol,
name = Name.identifier("_distinctCache"),
type = distinctChangeClass.defaultType,
visibility = DescriptorVisibilities.PRIVATE,
isFinal = true,
isExternal = false,
isStatic = parentClass == null,
)

val constructorSymbol = distinctChangeClass.owner.declarations
.filterIsInstance<IrConstructor>()
.firstOrNull { it.isPrimary }
?: error("Cannot find primary constructor of DistinctChangeCache")

val callDistInit = IrConstructorCallImpl.fromSymbolOwner(
startOffset = startOffset,
endOffset = endOffset,
type = distinctChangeClass.defaultType,
constructorSymbol = constructorSymbol.symbol
)
.apply {
putValueArgument(0, qualifierArgs)
}

backingField.parent = function.parent
backingField.initializer = irFactory.createExpressionBody(callDistInit)
(function.parentClassOrNull ?: function.fileParentOrNull)?.declarations?.add(backingField)

return fieldSymbol
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.builders.irBoolean
import org.jetbrains.kotlin.ir.builders.irCallConstructor
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrConstKind
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.constructors
import org.jetbrains.kotlin.ir.util.getAnnotation
import org.jetbrains.kotlin.ir.util.getValueArgument
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal fun IrPluginContext.readQualifier(
Expand All @@ -26,8 +30,8 @@ internal fun IrPluginContext.readQualifier(

val irBuilder = createIrBuilder(function)

val currentValue = annotation.getValueArgument(0)
val logging = currentValue ?: irBuilder.irBoolean(LOGGING_DEFAULT)
val logging = annotation.getValueArgument(0)
?: irBuilder.irBoolean(LOGGING_DEFAULT)

val constructorSymbol = referenceClass(DistinctChangeConfig::class.toClassId())
?.constructors
Expand All @@ -43,4 +47,28 @@ internal fun IrPluginContext.readQualifier(
.apply {
putValueArgument(0, logging)
}
}
}

internal fun IrSimpleFunction.getQualifierValue(name: String): Boolean = getAnnotation(
FqName(DistinctUntilChangeFun::class.qualifiedName!!)
)
?.getValueArgument(Name.identifier(name))
?.parseValue<Boolean>()
?: false

private inline fun <reified T> IrExpression.parseValue(): T = when (this) {
is IrConst<*> -> when (kind) {
IrConstKind.Boolean -> value
IrConstKind.Byte -> value
IrConstKind.Char -> value
IrConstKind.Double -> value
IrConstKind.Float -> value
IrConstKind.Int -> value
IrConstKind.Long -> value
IrConstKind.Null -> value
IrConstKind.Short -> value
IrConstKind.String -> value
}

else -> error("Unsupported type")
} as? T ?: error("${T::class} is not as it expected: $value")
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ activity = "1.10.0"
constraintLayout = "2.2.0"
jetbrainsKotlinJvm = "2.0.20"

stslexCompilerPlugin = "0.0.2"
stslexCompilerPlugin = "0.0.3"

[libraries]
android-desugarJdkLibs = { module = "com.android.tools:desugar_jdk_libs", version.ref = "androidDesugarJdkLibs" }
Expand Down