@@ -7,9 +7,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators
7
7
8
8
import io.kotest.matchers.string.shouldContain
9
9
import org.junit.jupiter.api.Test
10
- import software.amazon.smithy.codegen.core.SymbolProvider
10
+ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig
11
+ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
11
12
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
12
13
import software.amazon.smithy.rust.codegen.core.rustlang.rust
14
+ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
13
15
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
14
16
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
15
17
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
@@ -42,6 +44,28 @@ class UnionGeneratorTest {
42
44
writer.toString() shouldContain " #[non_exhaustive]"
43
45
}
44
46
47
+ @Test
48
+ fun `generate basic union with member names Unknown` () {
49
+ val writer =
50
+ generateUnion(
51
+ """
52
+ union MyUnion {
53
+ unknown: String
54
+ }
55
+ """ ,
56
+ )
57
+
58
+ writer.compileAndTest(
59
+ """
60
+ let var_a = MyUnion::UnknownValue("abc".to_string());
61
+ let var_b = MyUnion::Unknown;
62
+ assert_ne!(var_a, var_b);
63
+ assert_eq!(var_a, var_a);
64
+ """ ,
65
+ )
66
+ writer.toString() shouldContain " #[non_exhaustive]"
67
+ }
68
+
45
69
@Test
46
70
fun `generate conversion helper methods` () {
47
71
val writer =
@@ -232,9 +256,31 @@ class UnionGeneratorTest {
232
256
unknownVariant : Boolean = true,
233
257
): RustWriter {
234
258
val model = " namespace test\n $modelSmithy " .asSmithyModel()
235
- val provider: SymbolProvider = testSymbolProvider(model)
259
+ // Reserved words to test generation of renamed members
260
+ val reservedWords =
261
+ RustReservedWordConfig (
262
+ structureMemberMap =
263
+ StructureGenerator .structureMemberNameMap,
264
+ unionMemberMap =
265
+ mapOf (
266
+ // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
267
+ // that represent union variants that have been added since this SDK was generated.
268
+ UnionGenerator .UNKNOWN_VARIANT_NAME to " ${UnionGenerator .UNKNOWN_VARIANT_NAME } Value" ,
269
+ " ${UnionGenerator .UNKNOWN_VARIANT_NAME } Value" to " ${UnionGenerator .UNKNOWN_VARIANT_NAME } Value_" ,
270
+ ),
271
+ enumMemberMap =
272
+ mapOf (),
273
+ )
274
+ val provider: RustSymbolProvider = testSymbolProvider(model)
275
+ val reservedWordsProvider = RustReservedWordSymbolProvider (provider, reservedWords)
236
276
val writer = RustWriter .forModule(" model" )
237
- UnionGenerator (model, provider, writer, model.lookup(" test#$unionName " ), renderUnknownVariant = unknownVariant).render()
277
+ UnionGenerator (
278
+ model,
279
+ reservedWordsProvider,
280
+ writer,
281
+ model.lookup(" test#$unionName " ),
282
+ renderUnknownVariant = unknownVariant,
283
+ ).render()
238
284
return writer
239
285
}
240
286
}
0 commit comments