Skip to content

Commit 0a65def

Browse files
Fix custom serde factory usage in kotlin (#477)
1 parent 8a254a8 commit 0a65def

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk.kotlin.gen
1010

11+
import com.google.devtools.ksp.KSTypeNotPresentException
1112
import com.google.devtools.ksp.KspExperimental
1213
import com.google.devtools.ksp.getAnnotationsByType
1314
import com.google.devtools.ksp.getVisibility
@@ -104,7 +105,7 @@ class KElementConverter(
104105
val customSerdeFactory: CustomSerdeFactory? =
105106
classDeclaration.getAnnotationsByType(CustomSerdeFactory::class).firstOrNull()
106107
if (customSerdeFactory != null) {
107-
serdeFactoryDecl = "new " + customSerdeFactory.value + "()"
108+
serdeFactoryDecl = parseAnnotationClassParameter { customSerdeFactory.value } + "()"
108109
}
109110
data.withSerdeFactoryDecl(serdeFactoryDecl)
110111
}
@@ -335,4 +336,18 @@ class KElementConverter(
335336

336337
return typeName
337338
}
339+
340+
@OptIn(KspExperimental::class)
341+
private fun parseAnnotationClassParameter(block: () -> KClass<*>): String? {
342+
return try { // KSTypeNotPresentException will be thrown
343+
block.invoke().qualifiedName
344+
} catch (e: KSTypeNotPresentException) {
345+
var res: String? = null
346+
val declaration = e.ksType.declaration
347+
if (declaration is KSClassDeclaration) {
348+
declaration.qualifiedName?.asString()?.let { res = it }
349+
}
350+
res
351+
}
352+
}
338353
}

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/KotlinSerializationSerdeFactory.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
4242

4343
@PublishedApi
4444
internal class KtTypeTag<T>(
45-
internal val type: KClass<*>,
45+
val type: KClass<*>,
4646
/** Reified type */
47-
internal val kotlinType: KType?
47+
val kotlinType: KType?
4848
) : TypeTag<T>
4949

5050
override fun <T : Any?> create(typeTag: TypeTag<T>): Serde<T> {

sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import dev.restate.sdk.core.statemachine.ProtoUtils.*
1919
import dev.restate.sdk.kotlin.*
2020
import dev.restate.sdk.kotlin.serialization.*
2121
import dev.restate.serde.Serde
22+
import dev.restate.serde.SerdeFactory
23+
import dev.restate.serde.TypeRef
24+
import dev.restate.serde.TypeTag
2225
import java.util.stream.Stream
2326
import kotlinx.serialization.Serializable
2427

@@ -212,6 +215,33 @@ class CodegenTest : TestDefinitions.TestSuite {
212215
}
213216
}
214217

218+
class MyCustomSerdeFactory : SerdeFactory {
219+
override fun <T : Any?> create(typeTag: TypeTag<T?>): Serde<T?> {
220+
check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag)
221+
check(typeTag.type == Byte::class)
222+
return Serde.using<Byte>({ b -> byteArrayOf(b) }, { it[0] }) as Serde<T?>
223+
}
224+
225+
override fun <T : Any?> create(typeRef: TypeRef<T?>): Serde<T?> {
226+
check(typeRef.type == Byte::class)
227+
return Serde.using<Byte>({ b -> byteArrayOf(b) }, { it[0] }) as Serde<T?>
228+
}
229+
230+
override fun <T : Any?> create(clazz: Class<T?>?): Serde<T?> {
231+
check(clazz == Byte::class.java)
232+
return Serde.using<Byte>({ b -> byteArrayOf(b) }, { it[0] }) as Serde<T?>
233+
}
234+
}
235+
236+
@CustomSerdeFactory(MyCustomSerdeFactory::class)
237+
@Service(name = "CustomSerdeService")
238+
class CustomSerdeService {
239+
@Handler
240+
suspend fun echo(context: Context, input: Byte): Byte {
241+
return input
242+
}
243+
}
244+
215245
override fun definitions(): Stream<TestDefinition> {
216246
return Stream.of(
217247
testInvocation({ ServiceGreeter() }, "greet")
@@ -358,6 +388,10 @@ class CodegenTest : TestDefinitions.TestSuite {
358388
Slice.EMPTY),
359389
outputCmd(),
360390
END_MESSAGE),
391+
testInvocation({ CustomSerdeService() }, "echo")
392+
.withInput(startMessage(1), inputCmd(byteArrayOf(1)))
393+
.onlyBidiStream()
394+
.expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE),
361395
)
362396
}
363397
}

0 commit comments

Comments
 (0)