@@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
29
29
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
30
30
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
31
31
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
32
+ import software.amazon.smithy.rust.codegen.core.rustlang.join
32
33
import software.amazon.smithy.rust.codegen.core.rustlang.rust
33
34
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
34
35
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -57,10 +58,29 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
57
58
/* * Class describing a CBOR parser section that can be used in a customization. */
58
59
sealed class CborParserSection (name : String ) : Section(name) {
59
60
data class BeforeBoxingDeserializedMember (val shape : MemberShape ) : CborParserSection(" BeforeBoxingDeserializedMember" )
61
+
62
+ /* *
63
+ * Represents a customization point in union deserialization that occurs before decoding the map structure.
64
+ * This allows for custom handling of union variants before the standard map decoding logic is applied.
65
+ * @property shape The union shape being deserialized.
66
+ */
67
+ data class UnionParserBeforeDecodingMap (val shape : UnionShape ) : CborParserSection(" UnionParserBeforeDecodingMap" )
60
68
}
61
69
62
- /* * Customization for the CBOR parser. */
63
- typealias CborParserCustomization = NamedCustomization <CborParserSection >
70
+ /* *
71
+ * Customization class for CBOR parser generation that allows modification of union type deserialization behavior.
72
+ * Previously, union variant discrimination was hardcoded to use `decoder.str()`. This has been made more flexible
73
+ * to support different decoder implementations and discrimination methods.
74
+ */
75
+ abstract class CborParserCustomization : NamedCustomization <CborParserSection >() {
76
+ /* *
77
+ * Allows customization of how union variants are discriminated during deserialization.
78
+ * @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
79
+ * @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
80
+ */
81
+ open fun getUnionVariantDiscriminator (defaultContext : CborParserGenerator .UnionVariantDiscriminatorContext ) =
82
+ defaultContext
83
+ }
64
84
65
85
class CborParserGenerator (
66
86
private val codegenContext : CodegenContext ,
@@ -75,6 +95,16 @@ class CborParserGenerator(
75
95
private val shouldWrapBuilderMemberSetterInputWithOption : (MemberShape ) -> Boolean = { _ -> true },
76
96
private val customizations : List <CborParserCustomization > = emptyList(),
77
97
) : StructuredDataParserGenerator {
98
+ /* *
99
+ * Context class that encapsulates the information needed to discriminate union variants during deserialization.
100
+ * @property decoderSymbol The symbol representing the decoder type.
101
+ * @property variantDiscriminatorExpression The method call expression to determine the union variant.
102
+ */
103
+ data class UnionVariantDiscriminatorContext (
104
+ val decoderSymbol : Symbol ,
105
+ val variantDiscriminatorExpression : Writable ,
106
+ )
107
+
78
108
private val model = codegenContext.model
79
109
private val symbolProvider = codegenContext.symbolProvider
80
110
private val runtimeConfig = codegenContext.runtimeConfig
@@ -298,16 +328,26 @@ class CborParserGenerator(
298
328
private fun unionPairParserFnWritable (shape : UnionShape ) =
299
329
writable {
300
330
val returnSymbolToParse = returnSymbolToParse(shape)
331
+ // Get actual decoder type to use and the discriminating function to call to extract
332
+ // the variant of the union that has been encoded in the data.
333
+ val discriminatorContext = getUnionDiscriminatorContext(" Decoder" , " decoder.str()?.as_ref()" )
334
+
301
335
rustBlockTemplate(
302
336
"""
303
337
fn pair(
304
- decoder: &mut #{Decoder }
338
+ decoder: &mut #{DecoderSymbol }
305
339
) -> #{Result}<#{UnionSymbol}, #{Error}>
306
340
""" ,
307
341
* codegenScope,
342
+ " DecoderSymbol" to discriminatorContext.decoderSymbol,
308
343
" UnionSymbol" to returnSymbolToParse.symbol,
309
344
) {
310
- withBlock(" Ok(match decoder.str()?.as_ref() {" , " })" ) {
345
+ rustTemplate(
346
+ """
347
+ Ok(match #{VariableDiscriminatingExpression} {
348
+ """ ,
349
+ " VariableDiscriminatingExpression" to discriminatorContext.variantDiscriminatorExpression,
350
+ ).run {
311
351
for (member in shape.members()) {
312
352
val variantName = symbolProvider.toMemberName(member)
313
353
@@ -349,9 +389,24 @@ class CborParserGenerator(
349
389
)
350
390
}
351
391
}
392
+ rust(" })" )
352
393
}
353
394
}
354
395
396
+ private fun getUnionDiscriminatorContext (
397
+ decoderType : String ,
398
+ callMethod : String ,
399
+ ): UnionVariantDiscriminatorContext {
400
+ val defaultUnionPairContext =
401
+ UnionVariantDiscriminatorContext (
402
+ smithyCbor.resolve(decoderType).toSymbol(),
403
+ writable { rustTemplate(callMethod) },
404
+ )
405
+ return customizations.fold(defaultUnionPairContext) { context, customization ->
406
+ customization.getUnionVariantDiscriminator(context)
407
+ }
408
+ }
409
+
355
410
enum class CollectionKind {
356
411
Map ,
357
412
List ,
@@ -677,12 +732,22 @@ class CborParserGenerator(
677
732
678
733
private fun RustWriter.deserializeUnion (shape : UnionShape ) {
679
734
val returnSymbolToParse = returnSymbolToParse(shape)
735
+ val beforeDecoderMapCustomization =
736
+ customizations.map { customization ->
737
+ customization.section(
738
+ CborParserSection .UnionParserBeforeDecodingMap (
739
+ shape,
740
+ ),
741
+ )
742
+ }.join(" " )
743
+
680
744
val parser =
681
745
protocolFunctions.deserializeFn(shape) { fnName ->
682
746
rustTemplate(
683
747
"""
684
748
pub(crate) fn $fnName (decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> {
685
749
#{UnionPairParserFnWritable}
750
+ #{BeforeDecoderMapCustomization:W}
686
751
687
752
match decoder.map()? {
688
753
None => {
@@ -707,6 +772,7 @@ class CborParserGenerator(
707
772
""" ,
708
773
" UnionSymbol" to returnSymbolToParse.symbol,
709
774
" UnionPairParserFnWritable" to unionPairParserFnWritable(shape),
775
+ " BeforeDecoderMapCustomization" to beforeDecoderMapCustomization,
710
776
* codegenScope,
711
777
)
712
778
}
0 commit comments