Skip to content

Commit 4ad631a

Browse files
drganjoo and Fahad Zubair
authored
Add support for customizing union variants in JSON/CBOR serialization/de-serialization (#3970)
This PR introduces customization points in JSON and CBOR serialization / deserialization logic to support customization of the wire format of Union variant keys. --------- Co-authored-by: Fahad Zubair <fahadzub@amazon.com>
1 parent 1f9c608 commit 4ad631a

File tree

10 files changed

+314
-28
lines changed

10 files changed

+314
-28
lines changed

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,30 @@ operator fun Writable.plus(other: Writable): Writable =
4545

4646
/**
4747
* Helper allowing a `Iterable<Writable>` to be joined together using a `String` separator.
48+
* @param separator The string to use as a separator between elements
49+
* @param prefix An optional string written once before the first element; nothing is written when the iterable is empty (defaults to null)
50+
* @return A Writable containing the optionally prefixed, joined elements
4851
*/
49-
fun Iterable<Writable>.join(separator: String) = join(writable(separator))
52+
fun Iterable<Writable>.join(
53+
separator: String,
54+
prefix: String? = null,
55+
) = join(writable(separator), prefix?.let { writable(it) })
5056

5157
/**
5258
* Helper allowing a `Iterable<Writable>` to be joined together using a `Writable` separator.
59+
* @param separator The Writable to use as a separator between elements
60+
* @param prefix An optional Writable invoked once before the first element; it is skipped when the iterable is empty (defaults to null)
61+
* @return A Writable containing the optionally prefixed, joined elements
5362
*/
54-
fun Iterable<Writable>.join(separator: Writable): Writable {
63+
fun Iterable<Writable>.join(
64+
separator: Writable,
65+
prefix: Writable? = null,
66+
): Writable {
5567
val iter = this.iterator()
5668
return writable {
69+
if (iter.hasNext() && prefix != null) {
70+
prefix()
71+
}
5772
iter.forEach { value ->
5873
value()
5974
if (iter.hasNext()) {

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
1919
import software.amazon.smithy.rust.codegen.core.rustlang.writable
2020
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
2121
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
22+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization
2223
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator
2324
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
25+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization
2426
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator
2527
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator
2628
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
@@ -92,7 +94,11 @@ class RpcV2CborHttpBindingResolver(
9294
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor")
9395
}
9496

95-
open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
97+
open class RpcV2Cbor(
98+
val codegenContext: CodegenContext,
99+
private val serializeCustomization: List<CborSerializerCustomization> = listOf(),
100+
private val parserCustomization: List<CborParserCustomization> = listOf(),
101+
) : Protocol {
96102
private val runtimeConfig = codegenContext.runtimeConfig
97103

98104
override val httpBindingResolver: HttpBindingResolver =
@@ -134,10 +140,11 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
134140
)
135141
}
136142
},
143+
customizations = parserCustomization,
137144
)
138145

139146
override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
140-
CborSerializerGenerator(codegenContext, httpBindingResolver)
147+
CborSerializerGenerator(codegenContext, httpBindingResolver, customizations = serializeCustomization)
141148

142149
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
143150
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
2929
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
3030
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
3131
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
32+
import software.amazon.smithy.rust.codegen.core.rustlang.join
3233
import software.amazon.smithy.rust.codegen.core.rustlang.rust
3334
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
3435
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -57,10 +58,29 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
5758
/** Class describing a CBOR parser section that can be used in a customization. */
5859
sealed class CborParserSection(name: String) : Section(name) {
5960
data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember")
61+
62+
/**
63+
* Represents a customization point in union deserialization that occurs before decoding the map structure.
64+
* This allows for custom handling of union variants before the standard map decoding logic is applied.
65+
* @property shape The union shape being deserialized.
66+
*/
67+
data class UnionParserBeforeDecodingMap(val shape: UnionShape) : CborParserSection("UnionParserBeforeDecodingMap")
6068
}
6169

62-
/** Customization for the CBOR parser. */
63-
typealias CborParserCustomization = NamedCustomization<CborParserSection>
70+
/**
71+
* Customization class for CBOR parser generation that allows modification of union type deserialization behavior.
72+
* Union variant discrimination was previously hardcoded to `decoder.str()`; that is now only the default.
73+
* Override [getUnionVariantDiscriminator] to supply a different decoder type or discriminating expression.
74+
*/
75+
abstract class CborParserCustomization : NamedCustomization<CborParserSection>() {
76+
/**
77+
* Allows customization of how union variants are discriminated during deserialization.
78+
* @param defaultContext The context to start from: the generator's default, or the context returned by a customization applied earlier in the list.
79+
* @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
80+
*/
81+
open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) =
82+
defaultContext
83+
}
6484

6585
class CborParserGenerator(
6686
private val codegenContext: CodegenContext,
@@ -75,6 +95,16 @@ class CborParserGenerator(
7595
private val shouldWrapBuilderMemberSetterInputWithOption: (MemberShape) -> Boolean = { _ -> true },
7696
private val customizations: List<CborParserCustomization> = emptyList(),
7797
) : StructuredDataParserGenerator {
98+
/**
99+
* Context class that encapsulates the information needed to discriminate union variants during deserialization.
100+
* @property decoderSymbol The symbol used as the type of the `decoder` parameter in the generated `pair` function.
101+
* @property variantDiscriminatorExpression The Rust expression emitted as the scrutinee of the generated `match` that selects the union variant.
102+
*/
103+
data class UnionVariantDiscriminatorContext(
104+
val decoderSymbol: Symbol,
105+
val variantDiscriminatorExpression: Writable,
106+
)
107+
78108
private val model = codegenContext.model
79109
private val symbolProvider = codegenContext.symbolProvider
80110
private val runtimeConfig = codegenContext.runtimeConfig
@@ -298,16 +328,26 @@ class CborParserGenerator(
298328
private fun unionPairParserFnWritable(shape: UnionShape) =
299329
writable {
300330
val returnSymbolToParse = returnSymbolToParse(shape)
331+
// Get actual decoder type to use and the discriminating function to call to extract
332+
// the variant of the union that has been encoded in the data.
333+
val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()")
334+
301335
rustBlockTemplate(
302336
"""
303337
fn pair(
304-
decoder: &mut #{Decoder}
338+
decoder: &mut #{DecoderSymbol}
305339
) -> #{Result}<#{UnionSymbol}, #{Error}>
306340
""",
307341
*codegenScope,
342+
"DecoderSymbol" to discriminatorContext.decoderSymbol,
308343
"UnionSymbol" to returnSymbolToParse.symbol,
309344
) {
310-
withBlock("Ok(match decoder.str()?.as_ref() {", "})") {
345+
rustTemplate(
346+
"""
347+
Ok(match #{VariableDiscriminatingExpression} {
348+
""",
349+
"VariableDiscriminatingExpression" to discriminatorContext.variantDiscriminatorExpression,
350+
).run {
311351
for (member in shape.members()) {
312352
val variantName = symbolProvider.toMemberName(member)
313353

@@ -349,9 +389,24 @@ class CborParserGenerator(
349389
)
350390
}
351391
}
392+
rust("})")
352393
}
353394
}
354395

396+
private fun getUnionDiscriminatorContext(
397+
decoderType: String,
398+
callMethod: String,
399+
): UnionVariantDiscriminatorContext {
400+
val defaultUnionPairContext =
401+
UnionVariantDiscriminatorContext(
402+
smithyCbor.resolve(decoderType).toSymbol(),
403+
writable { rustTemplate(callMethod) },
404+
)
405+
return customizations.fold(defaultUnionPairContext) { context, customization ->
406+
customization.getUnionVariantDiscriminator(context)
407+
}
408+
}
409+
355410
enum class CollectionKind {
356411
Map,
357412
List,
@@ -677,12 +732,22 @@ class CborParserGenerator(
677732

678733
private fun RustWriter.deserializeUnion(shape: UnionShape) {
679734
val returnSymbolToParse = returnSymbolToParse(shape)
735+
val beforeDecoderMapCustomization =
736+
customizations.map { customization ->
737+
customization.section(
738+
CborParserSection.UnionParserBeforeDecodingMap(
739+
shape,
740+
),
741+
)
742+
}.join("")
743+
680744
val parser =
681745
protocolFunctions.deserializeFn(shape) { fnName ->
682746
rustTemplate(
683747
"""
684748
pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> {
685749
#{UnionPairParserFnWritable}
750+
#{BeforeDecoderMapCustomization:W}
686751
687752
match decoder.map()? {
688753
None => {
@@ -707,6 +772,7 @@ class CborParserGenerator(
707772
""",
708773
"UnionSymbol" to returnSymbolToParse.symbol,
709774
"UnionPairParserFnWritable" to unionPairParserFnWritable(shape),
775+
"BeforeDecoderMapCustomization" to beforeDecoderMapCustomization,
710776
*codegenScope,
711777
)
712778
}

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ sealed class JsonParserSection(name: String) : Section(name) {
7070

7171
data class AfterDocumentDeserializedMember(val shape: MemberShape) :
7272
JsonParserSection("AfterDocumentDeserializedMember")
73+
74+
/**
75+
* Represents a customization point at the beginning of union deserialization, before any token
76+
* processing occurs.
77+
*/
78+
data class BeforeUnionDeserialize(val shape: UnionShape) :
79+
JsonParserSection("BeforeUnionDeserialize")
7380
}
7481

7582
/**
@@ -548,6 +555,12 @@ class JsonParserGenerator(
548555
*codegenScope,
549556
"Shape" to returnSymbolToParse.symbol,
550557
) {
558+
// Apply any custom union deserialization logic before processing tokens.
559+
// This allows for customization of how union variants are handled,
560+
// particularly their discrimination mechanism.
561+
for (customization in customizations) {
562+
customization.section(JsonParserSection.BeforeUnionDeserialize(shape))(this)
563+
}
551564
rust("let mut variant = None;")
552565
val checkValueSet = !shape.members().all { it.isTargetUnit() } && !codegenTarget.renderUnknownVariant()
553566
rustBlock("match tokens.next().transpose()?") {

0 commit comments

Comments
 (0)