Skip to content

Commit 32b6f56

Browse files
authored
Update UnionGenerator to handle members named Unknown (#4132)
2 parents a9aeece + ac7cedc commit 32b6f56

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

.changelog/unknown-variants.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
applies_to: ["client", "server"]
3+
authors: ["rcoh"]
4+
references: ["smithy-rs#4132"]
5+
breaking: false
6+
new_feature: false
7+
bug_fix: true
8+
---
9+
10+
Smithy unions that contain members named "unknown" will now codegen correctly

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ open class UnionGenerator(
110110
private fun renderImplBlock(unionSymbol: Symbol) {
111111
writer.rustBlock("impl ${unionSymbol.name}") {
112112
sortedMembers.forEach { member ->
113-
val funcNamePart = member.memberName.toSnakeCase()
113+
// We need to get the symbol first because the member can be renamed
114+
val funcNamePart = symbolProvider.toSymbol(member).name.toSnakeCase()
114115
val variantName = symbolProvider.toMemberName(member)
115116

116117
if (sortedMembers.size == 1) {
@@ -219,7 +220,10 @@ private fun RustWriter.renderAsVariant(
219220
targetSymbol,
220221
)
221222
rust("/// Returns `Err(&Self)` if it can't be converted.")
222-
rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{Result}<&${memberSymbol.rustType().render()}, &Self>", *preludeScope) {
223+
rustBlockTemplate(
224+
"pub fn as_$funcNamePart(&self) -> #{Result}<&${memberSymbol.rustType().render()}, &Self>",
225+
*preludeScope,
226+
) {
223227
rustTemplate(
224228
"if let ${unionSymbol.name}::$variantName(val) = &self { #{Ok}(val) } else { #{Err}(self) }",
225229
*preludeScope,

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators
77

88
import io.kotest.matchers.string.shouldContain
99
import org.junit.jupiter.api.Test
10-
import software.amazon.smithy.codegen.core.SymbolProvider
10+
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig
11+
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
1112
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
1213
import software.amazon.smithy.rust.codegen.core.rustlang.rust
14+
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
1315
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
1416
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
1517
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
@@ -42,6 +44,28 @@ class UnionGeneratorTest {
4244
writer.toString() shouldContain "#[non_exhaustive]"
4345
}
4446

47+
@Test
48+
fun `generate basic union with member names Unknown`() {
49+
val writer =
50+
generateUnion(
51+
"""
52+
union MyUnion {
53+
unknown: String
54+
}
55+
""",
56+
)
57+
58+
writer.compileAndTest(
59+
"""
60+
let var_a = MyUnion::UnknownValue("abc".to_string());
61+
let var_b = MyUnion::Unknown;
62+
assert_ne!(var_a, var_b);
63+
assert_eq!(var_a, var_a);
64+
""",
65+
)
66+
writer.toString() shouldContain "#[non_exhaustive]"
67+
}
68+
4569
@Test
4670
fun `generate conversion helper methods`() {
4771
val writer =
@@ -232,9 +256,31 @@ class UnionGeneratorTest {
232256
unknownVariant: Boolean = true,
233257
): RustWriter {
234258
val model = "namespace test\n$modelSmithy".asSmithyModel()
235-
val provider: SymbolProvider = testSymbolProvider(model)
259+
// Reserved words to test generation of renamed members
260+
val reservedWords =
261+
RustReservedWordConfig(
262+
structureMemberMap =
263+
StructureGenerator.structureMemberNameMap,
264+
unionMemberMap =
265+
mapOf(
266+
// Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
267+
// that represent union variants that have been added since this SDK was generated.
268+
UnionGenerator.UNKNOWN_VARIANT_NAME to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value",
269+
"${UnionGenerator.UNKNOWN_VARIANT_NAME}Value" to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value_",
270+
),
271+
enumMemberMap =
272+
mapOf(),
273+
)
274+
val provider: RustSymbolProvider = testSymbolProvider(model)
275+
val reservedWordsProvider = RustReservedWordSymbolProvider(provider, reservedWords)
236276
val writer = RustWriter.forModule("model")
237-
UnionGenerator(model, provider, writer, model.lookup("test#$unionName"), renderUnknownVariant = unknownVariant).render()
277+
UnionGenerator(
278+
model,
279+
reservedWordsProvider,
280+
writer,
281+
model.lookup("test#$unionName"),
282+
renderUnknownVariant = unknownVariant,
283+
).render()
238284
return writer
239285
}
240286
}

0 commit comments

Comments
 (0)