Skip to content

Commit 78b6cb6

Browse files
Generate Json Schemas for Kotlin (#495)
1 parent 2564866 commit 78b6cb6

File tree

8 files changed

+391
-85
lines changed

8 files changed

+391
-85
lines changed

examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class CounterKt {
4242

4343
@Handler
4444
@Shared
45-
suspend fun get(ctx: SharedObjectContext): Long {
46-
return ctx.get(TOTAL) ?: 0L
45+
suspend fun get(ctx: SharedObjectContext): Long? {
46+
return ctx.get(TOTAL)
4747
}
4848

4949
@Handler

gradle/libs.versions.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@
181181
[libraries.victools-jsonschema-module-jackson.version]
182182
ref = 'victools-json-schema'
183183

184+
[libraries.schema-kenerator-core]
185+
module = 'io.github.smiley4:schema-kenerator-core'
186+
187+
[libraries.schema-kenerator-core.version]
188+
ref = 'schema-kenerator'
189+
190+
[libraries.schema-kenerator-serialization]
191+
module = 'io.github.smiley4:schema-kenerator-serialization'
192+
193+
[libraries.schema-kenerator-serialization.version]
194+
ref = 'schema-kenerator'
195+
196+
[libraries.schema-kenerator-jsonschema]
197+
module = 'io.github.smiley4:schema-kenerator-jsonschema'
198+
199+
[libraries.schema-kenerator-jsonschema.version]
200+
ref = 'schema-kenerator'
201+
184202
[plugins]
185203
aggregate-javadoc = 'io.freefair.aggregate-javadoc:8.6'
186204
dependency-license-report = 'com.github.jk1.dependency-license-report:2.0'
@@ -213,3 +231,4 @@
213231
spring-boot = '3.4.4'
214232
vertx = '4.5.11'
215233
victools-json-schema = '4.37.0'
234+
schema-kenerator = '2.1.2'

sdk-serde-jackson/src/test/java/dev/restate/serde/jackson/JacksonSerdesTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@
1313
import com.fasterxml.jackson.annotation.JsonCreator;
1414
import com.fasterxml.jackson.annotation.JsonProperty;
1515
import com.fasterxml.jackson.core.type.TypeReference;
16+
import com.fasterxml.jackson.databind.node.ObjectNode;
1617
import dev.restate.serde.Serde;
1718
import java.util.List;
1819
import java.util.Objects;
1920
import java.util.Set;
2021
import java.util.stream.Stream;
22+
import org.junit.jupiter.api.Test;
2123
import org.junit.jupiter.params.ParameterizedTest;
2224
import org.junit.jupiter.params.provider.Arguments;
2325
import org.junit.jupiter.params.provider.MethodSource;
2426

2527
class JacksonSerdesTest {
2628

29+
record Recursive(String value, Recursive rec) {}
30+
2731
public static class Person {
2832

2933
private final String name;
@@ -75,4 +79,11 @@ private static Stream<Arguments> roundtripTestCases() {
7579
<T> void roundtrip(T value, Serde<T> serde) {
7680
assertThat(serde.deserialize(serde.serialize(value))).isEqualTo(value);
7781
}
82+
83+
@Test
84+
void schemaGenWorksWithRecursion() {
85+
ObjectNode node =
86+
(ObjectNode) ((Serde.JsonSchema) JacksonSerdes.of(Recursive.class).jsonSchema()).schema();
87+
assertThat(node.at("/properties/rec/$ref").textValue()).isEqualTo("#");
88+
}
7889
}

sdk-serde-kotlinx/build.gradle.kts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ description = "Restate SDK Kotlinx Serialization integration"
88
dependencies {
99
api(libs.kotlinx.serialization.json)
1010
implementation(libs.kotlinx.serialization.core)
11+
implementation(libs.schema.kenerator.core)
12+
implementation(libs.schema.kenerator.serialization)
13+
implementation(libs.schema.kenerator.jsonschema)
1114

1215
implementation(project(":common"))
16+
17+
testImplementation(libs.junit.jupiter)
18+
testImplementation(libs.assertj)
1319
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
2+
//
3+
// This file is part of the Restate Java SDK,
4+
// which is released under the MIT license.
5+
//
6+
// You can find a copy of the license in file LICENSE in the root
7+
// directory of this repository or package, or at
8+
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
9+
package dev.restate.serde.kotlinx
10+
11+
import dev.restate.serde.Serde
12+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps
13+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.compileReferencing
14+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.generateJsonSchema
15+
import io.github.smiley4.schemakenerator.jsonschema.TitleBuilder
16+
import io.github.smiley4.schemakenerator.jsonschema.data.IntermediateJsonSchemaData
17+
import io.github.smiley4.schemakenerator.jsonschema.data.RefType
18+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonArray
19+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonNode
20+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonObject
21+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonTextValue
22+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.array
23+
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization
24+
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.initial
25+
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.renameMembers
26+
import kotlin.collections.set
27+
import kotlinx.serialization.ExperimentalSerializationApi
28+
import kotlinx.serialization.KSerializer
29+
import kotlinx.serialization.json.Json
30+
31+
object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFactory {
32+
@OptIn(ExperimentalSerializationApi::class)
33+
override fun generateSchema(json: Json, serializer: KSerializer<*>) =
34+
Serde.StringifiedJsonSchema(
35+
runCatching {
36+
var initialStep =
37+
initial(serializer.descriptor).analyzeTypeUsingKotlinxSerialization {
38+
serializersModule = json.serializersModule
39+
}
40+
41+
if (json.configuration.namingStrategy != null) {
42+
initialStep = initialStep.renameMembers(json.configuration.namingStrategy!!)
43+
}
44+
45+
val intermediateStep =
46+
initialStep.generateJsonSchema {
47+
optionalHandling = JsonSchemaSteps.OptionalHandling.NON_REQUIRED
48+
}
49+
intermediateStep.writeTitles()
50+
val compiledSchema = intermediateStep.compileReferencing(RefType.SIMPLE)
51+
52+
// In case of nested schemas, compileReferencing also contains self schema...
53+
val rootSchemaName =
54+
TitleBuilder.BUILDER_SIMPLE(
55+
compiledSchema.typeData, intermediateStep.typeDataById)
56+
57+
// If schema is not json object, then it's boolean, so we're good no need for
58+
// additional manipulation
59+
if (compiledSchema.json !is JsonObject) {
60+
return@runCatching compiledSchema.json
61+
}
62+
63+
// Assemble the final schema now
64+
val rootNode = compiledSchema.json as JsonObject
65+
// Add $schema
66+
rootNode.properties.put(
67+
"\$schema", JsonTextValue("https://json-schema.org/draft/2020-12/schema"))
68+
// Add $defs
69+
val definitions =
70+
compiledSchema.definitions.filter { it.key != rootSchemaName }.toMutableMap()
71+
if (definitions.isNotEmpty()) {
72+
rootNode.properties.put("\$defs", JsonObject(definitions))
73+
}
74+
// Replace all $refs
75+
rootNode.fixRefsPrefix("#/definitions/$rootSchemaName")
76+
// If the root type is nullable, it should be in the schema too
77+
if (serializer.descriptor.isNullable) {
78+
val oldTypeProperty = rootNode.properties["type"]
79+
if (oldTypeProperty is JsonTextValue) {
80+
rootNode.properties["type"] = array {
81+
item(oldTypeProperty.value)
82+
item(JsonTextValue("null"))
83+
}
84+
} else if (oldTypeProperty is JsonArray) {
85+
oldTypeProperty.items.add(JsonTextValue("null"))
86+
}
87+
}
88+
89+
return@runCatching rootNode
90+
}
91+
.getOrDefault(JsonObject(mutableMapOf()))
92+
.prettyPrint())
93+
94+
private fun IntermediateJsonSchemaData.writeTitles() {
95+
this.entries.forEach { schema ->
96+
if (schema.json is JsonObject) {
97+
if ((schema.typeData.isMap ||
98+
schema.typeData.isCollection ||
99+
schema.typeData.isEnum ||
100+
schema.typeData.isInlineValue ||
101+
schema.typeData.typeParameters.isNotEmpty() ||
102+
schema.typeData.members.isNotEmpty()) &&
103+
(schema.json as JsonObject).properties["title"] == null) {
104+
(schema.json as JsonObject).properties["title"] =
105+
JsonTextValue(TitleBuilder.BUILDER_SIMPLE(schema.typeData, this.typeDataById))
106+
}
107+
}
108+
}
109+
}
110+
111+
private fun JsonNode.fixRefsPrefix(rootDefinition: String) {
112+
when (this) {
113+
is JsonArray -> this.items.forEach { it.fixRefsPrefix(rootDefinition) }
114+
is JsonObject -> this.fixRefsPrefix(rootDefinition)
115+
else -> {}
116+
}
117+
}
118+
119+
private fun JsonObject.fixRefsPrefix(rootDefinition: String) {
120+
this.properties.computeIfPresent("\$ref") { key, node ->
121+
if (node is JsonTextValue) {
122+
if (node.value.startsWith(rootDefinition)) {
123+
JsonTextValue("#/" + node.value.removePrefix(rootDefinition))
124+
} else {
125+
JsonTextValue("#/\$defs/" + node.value.removePrefix("#/definitions/"))
126+
}
127+
} else {
128+
node
129+
}
130+
}
131+
this.properties.values.forEach { it.fixRefsPrefix(rootDefinition) }
132+
}
133+
}

sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactory.kt

Lines changed: 30 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,17 @@ package dev.restate.serde.kotlinx
1010

1111
import dev.restate.common.Slice
1212
import dev.restate.serde.Serde
13+
import dev.restate.serde.Serde.Schema
1314
import dev.restate.serde.SerdeFactory
1415
import dev.restate.serde.TypeRef
1516
import dev.restate.serde.TypeTag
1617
import java.nio.charset.StandardCharsets
1718
import kotlin.reflect.KClass
1819
import kotlin.reflect.KType
1920
import kotlinx.serialization.*
20-
import kotlinx.serialization.builtins.*
21-
import kotlinx.serialization.descriptors.PrimitiveKind
22-
import kotlinx.serialization.descriptors.SerialDescriptor
23-
import kotlinx.serialization.descriptors.StructureKind
24-
import kotlinx.serialization.encodeToString
21+
import kotlinx.serialization.builtins.nullable
2522
import kotlinx.serialization.json.Json
26-
import kotlinx.serialization.json.JsonArray
27-
import kotlinx.serialization.json.JsonElement
2823
import kotlinx.serialization.json.JsonNull
29-
import kotlinx.serialization.json.JsonTransformingSerializer
3024
import kotlinx.serialization.modules.SerializersModule
3125

3226
/**
@@ -38,7 +32,22 @@ import kotlinx.serialization.modules.SerializersModule
3832
*/
3933
open class KotlinSerializationSerdeFactory
4034
@JvmOverloads
41-
constructor(private val json: Json = Json.Default) : SerdeFactory {
35+
constructor(
36+
private val json: Json = Json.Default,
37+
private val jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory
38+
) : SerdeFactory {
39+
40+
/** Factory to generate json schemas. */
41+
interface JsonSchemaFactory {
42+
fun generateSchema(json: Json, serializer: KSerializer<*>): Schema?
43+
44+
companion object {
45+
val NOOP =
46+
object : JsonSchemaFactory {
47+
override fun generateSchema(json: Json, serializer: KSerializer<*>): Schema? = null
48+
}
49+
}
50+
}
4251

4352
@PublishedApi
4453
internal class KtTypeTag<T>(
@@ -61,7 +70,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
6170
}
6271
val serializer: KSerializer<T> =
6372
json.serializersModule.serializer(typeRef.type) as KSerializer<T>
64-
return jsonSerde(json, serializer)
73+
return jsonSerde(json, jsonSchemaFactory, serializer)
6574
}
6675

6776
@Suppress("UNCHECKED_CAST")
@@ -70,7 +79,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
7079
return UNIT as Serde<T>
7180
}
7281
val serializer: KSerializer<T> = json.serializersModule.serializer(clazz) as KSerializer<T>
73-
return jsonSerde(json, serializer)
82+
return jsonSerde(json, jsonSchemaFactory, serializer)
7483
}
7584

7685
@Suppress("UNCHECKED_CAST")
@@ -81,7 +90,7 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
8190
}
8291
val serializer: KSerializer<T> =
8392
json.serializersModule.serializerForKtTypeInfo(ktSerdeInfo) as KSerializer<T>
84-
return jsonSerde(json, serializer)
93+
return jsonSerde(json, jsonSchemaFactory, serializer)
8594
}
8695

8796
companion object {
@@ -103,7 +112,13 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
103112
}
104113

105114
/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */
106-
fun <T : Any?> jsonSerde(json: Json = Json.Default, serializer: KSerializer<T>): Serde<T> {
115+
fun <T : Any?> jsonSerde(
116+
json: Json = Json.Default,
117+
jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory,
118+
serializer: KSerializer<T>
119+
): Serde<T> {
120+
val schema = jsonSchemaFactory.generateSchema(json, serializer)
121+
107122
return object : Serde<T> {
108123
@Suppress("WRONG_NULLABILITY_FOR_JAVA_OVERRIDE")
109124
override fun serialize(value: T?): Slice {
@@ -123,77 +138,11 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
123138
return "application/json"
124139
}
125140

126-
override fun jsonSchema(): Serde.Schema {
127-
val schema: JsonSchema = serializer.descriptor.jsonSchema()
128-
return Serde.StringifiedJsonSchema(Json.encodeToString(schema))
141+
override fun jsonSchema(): Schema? {
142+
return schema
129143
}
130144
}
131145
}
132-
133-
@Serializable
134-
@PublishedApi
135-
internal data class JsonSchema(
136-
@Serializable(with = StringListSerializer::class) val type: List<String>? = null,
137-
val format: String? = null,
138-
) {
139-
companion object {
140-
val INT = JsonSchema(type = listOf("number"), format = "int32")
141-
142-
val LONG = JsonSchema(type = listOf("number"), format = "int64")
143-
144-
val DOUBLE = JsonSchema(type = listOf("number"), format = "double")
145-
146-
val FLOAT = JsonSchema(type = listOf("number"), format = "float")
147-
148-
val STRING = JsonSchema(type = listOf("string"))
149-
150-
val BOOLEAN = JsonSchema(type = listOf("boolean"))
151-
152-
val OBJECT = JsonSchema(type = listOf("object"))
153-
154-
val LIST = JsonSchema(type = listOf("array"))
155-
156-
val ANY = JsonSchema()
157-
}
158-
}
159-
160-
object StringListSerializer :
161-
JsonTransformingSerializer<List<String>>(ListSerializer(String.Companion.serializer())) {
162-
override fun transformSerialize(element: JsonElement): JsonElement {
163-
require(element is JsonArray)
164-
return element.singleOrNull() ?: element
165-
}
166-
}
167-
168-
/**
169-
* Super simplistic json schema generation. We should replace this with an appropriate library.
170-
*/
171-
@OptIn(ExperimentalSerializationApi::class)
172-
@PublishedApi
173-
internal fun SerialDescriptor.jsonSchema(): JsonSchema {
174-
var schema =
175-
when (this.kind) {
176-
PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN
177-
PrimitiveKind.BYTE -> JsonSchema.INT
178-
PrimitiveKind.CHAR -> JsonSchema.STRING
179-
PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE
180-
PrimitiveKind.FLOAT -> JsonSchema.FLOAT
181-
PrimitiveKind.INT -> JsonSchema.INT
182-
PrimitiveKind.LONG -> JsonSchema.LONG
183-
PrimitiveKind.SHORT -> JsonSchema.INT
184-
PrimitiveKind.STRING -> JsonSchema.STRING
185-
StructureKind.LIST -> JsonSchema.LIST
186-
StructureKind.MAP -> JsonSchema.OBJECT
187-
else -> JsonSchema.ANY
188-
}
189-
190-
// Add nullability constraint
191-
if (this.isNullable && schema.type != null) {
192-
schema = schema.copy(type = schema.type.plus("null"))
193-
}
194-
195-
return schema
196-
}
197146
}
198147

199148
@InternalSerializationApi

0 commit comments

Comments
 (0)