diff --git a/.cargo-deny-config.toml b/.cargo-deny-config.toml index 3a9407f00cb..abb8c8c0676 100644 --- a/.cargo-deny-config.toml +++ b/.cargo-deny-config.toml @@ -25,6 +25,9 @@ exceptions = [ { allow = ["OpenSSL"], name = "ring", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-sys", version = "*" }, { allow = ["OpenSSL"], name = "aws-lc-fips-sys", version = "*" }, + { allow = ["BlueOak-1.0.0"], name = "minicbor", version = "<=0.24.2" }, + # Safe to bump as long as license does not change -------------^ + # See D105255799. ] [[licenses.clarify]] diff --git a/build.gradle.kts b/build.gradle.kts index 20f2d9e4000..5e11e0ab02b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ allprojects { val allowLocalDeps: String by project repositories { if (allowLocalDeps.toBoolean()) { - mavenLocal() + mavenLocal() } mavenCentral() google() diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 3c025288ace..8e0fd36447a 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -26,9 +26,84 @@ fun generateImports(imports: List): String = if (imports.isEmpty()) { "" } else { - "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," + "\"imports\": [${imports.joinToString(", ") { "\"$it\"" }}]," } +val RustKeywords = + setOf( + "as", + "break", + "const", + "continue", + "crate", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + "async", + "await", + "dyn", + "abstract", + "become", + "box", + "do", + "final", + "macro", + "override", + "priv", + "typeof", + "unsized", + "virtual", + "yield", + "try", + ) + +fun toRustCrateName(input: String): String { + if (input.isBlank()) { + throw IllegalArgumentException("Rust crate name cannot be empty") + } + val lowerCased = input.lowercase() + // Replace any sequence of characters that are not lowercase letters, numbers, dashes, or underscores with a single underscore. + val sanitized = lowerCased.replace(Regex("[^a-z0-9_-]+"), "_") + // Trim leading or trailing underscores. + val trimmed = sanitized.trim('_') + // Check if the resulting string is empty, purely numeric, or a reserved name + val finalName = + when { + trimmed.isEmpty() -> throw IllegalArgumentException("Rust crate name after sanitizing cannot be empty.") + trimmed.matches(Regex("\\d+")) -> "n$trimmed" // Prepend 'n' if the name is purely numeric. + trimmed in RustKeywords -> "${trimmed}_" // Append an underscore if the name is reserved. + else -> trimmed + } + return finalName +} + private fun generateSmithyBuild( projectDir: String, pluginName: String, @@ -48,7 +123,7 @@ private fun generateSmithyBuild( ${it.extraCodegenConfig ?: ""} }, "service": "${it.service}", - "module": "${it.module}", + "module": "${toRustCrateName(it.module)}", "moduleVersion": "0.0.1", "moduleDescription": "test", "moduleAuthors": ["protocoltest@example.com"] diff --git a/buildSrc/src/main/kotlin/CrateSet.kt b/buildSrc/src/main/kotlin/CrateSet.kt index bc90115443a..253bfa08ca7 100644 --- a/buildSrc/src/main/kotlin/CrateSet.kt +++ b/buildSrc/src/main/kotlin/CrateSet.kt @@ -56,6 +56,7 @@ object CrateSet { val SMITHY_RUNTIME_COMMON = listOf( "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index 485a656d7b0..3e1f1ec580b 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -27,9 +27,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in event stream -// marshalling/unmarshalling tests. + // marshalling/unmarshalling tests. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index c4aec33b59b..306a439a9ed 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorC import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import java.util.ServiceLoader import java.util.logging.Logger @@ -93,14 +92,6 @@ interface ClientCodegenDecorator : CoreCodegenDecorator, ): List = baseCustomizations - - /** - * Hook to override the protocol test generator - */ - fun protocolTestGenerator( - codegenContext: ClientCodegenContext, - baseGenerator: ProtocolTestGenerator, - ): ProtocolTestGenerator = baseGenerator } /** @@ -176,14 +167,6 @@ open class CombinedClientCodegenDecorator(decorators: List - decorator.protocolTestGenerator(codegenContext, gen) - } - companion object { fun fromClasspath( context: PluginContext, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index ca39264a1c0..992d85dd72a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -28,11 +28,12 @@ class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true } -class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator( +class ClientInstantiator(private val codegenContext: ClientCodegenContext, withinTest: Boolean = false) : Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, ClientBuilderKindBehavior(codegenContext), + withinTest = false, ) { fun renderFluentCall( writer: RustWriter, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt index 5936438adba..85d32c2bf38 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolTestGenerator.kt @@ -114,7 +114,7 @@ class ClientProtocolTestGenerator( get() = AppliesTo.CLIENT override val expectFail: Set get() = ExpectFail - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = emptySet() @@ -128,7 +128,7 @@ class ClientProtocolTestGenerator( private val inputShape = operationShape.inputShape(codegenContext.model) private val outputShape = operationShape.outputShape(codegenContext.model) - private val instantiator = ClientInstantiator(codegenContext) + private val instantiator = ClientInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( @@ -149,6 +149,8 @@ class ClientProtocolTestGenerator( } private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestSerialization) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -234,6 +236,8 @@ class ClientProtocolTestGenerator( testCase: HttpResponseTestCase, expectedShape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + if (!protocolSupport.responseDeserialization || ( !protocolSupport.errorDeserialization && expectedShape.hasTrait( @@ -357,8 +361,8 @@ class ClientProtocolTestGenerator( if (body == "") { rustWriter.rustTemplate( """ - // No body - #{AssertEq}(::std::str::from_utf8(body).unwrap(), ""); + // No body. + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index 5c0f7e6e1a6..f1a01edd6bc 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -28,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : @@ -41,6 +43,7 @@ class ClientProtocolLoader(supportedProtocols: ProtocolMap { + override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2Cbor(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = + OperationGenerator(codegenContext, protocol(codegenContext)) + + override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT +} diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index 2fdd74abfb1..eff612be356 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -28,6 +28,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") } fun gitCommitHash(): String { diff --git a/codegen-core/common-test-models/rpcv2Cbor-extras.smithy b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy new file mode 100644 index 00000000000..c60b93736df --- /dev/null +++ b/codegen-core/common-test-models/rpcv2Cbor-extras.smithy @@ -0,0 +1,349 @@ +$version: "2.0" + +namespace smithy.protocoltests.rpcv2Cbor + +use smithy.framework#ValidationException +use smithy.protocols#rpcv2Cbor +use smithy.test#httpResponseTests +use smithy.test#httpMalformedRequestTests + +@rpcv2Cbor +service RpcV2CborService { + operations: [ + SimpleStructOperation + ErrorSerializationOperation + ComplexStructOperation + EmptyStructOperation + SingleMemberStructOperation + ] +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2326): Smithy should not +// allow HTTP binding traits in this protocol. +@http(uri: "/simple-struct-operation", method: "POST") +operation SimpleStructOperation { + input: SimpleStruct + output: SimpleStruct + errors: [ValidationException] +} + +operation ErrorSerializationOperation { + input: SimpleStruct + output: ErrorSerializationOperationOutput + errors: [ValidationException] +} + +operation ComplexStructOperation { + input: ComplexStruct + output: ComplexStruct + errors: [ValidationException] +} + +operation EmptyStructOperation { + input: EmptyStruct + output: EmptyStruct +} + +operation SingleMemberStructOperation { + input: SingleMemberStruct + output: SingleMemberStruct +} + +apply EmptyStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensEmptyStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/EmptyStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. + contents: "oA==" + } + } + } + + } +]) + +apply SingleMemberStructOperation @httpMalformedRequestTests([ + { + id: "AdditionalTokensSingleMemberStruct", + documentation: """ + When additional tokens are found past where we expect the end of the body, + the request should be rejected with a serialization exception.""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/SingleMemberStructOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // Two empty variable-length encoded CBOR maps back to back. + body: "v/+//w==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + // An empty CBOR map. + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing `__type` because `SerializationException` is not modeled. + contents: "oA==" + } + } + } + } +]) + +apply ErrorSerializationOperation @httpMalformedRequestTests([ + { + id: "ErrorSerializationIncludesTypeField", + documentation: """ + When invalid input is provided the request should be rejected with + a validation exception, and a `__type` field should be included""", + protocol: rpcv2Cbor, + request: { + method: "POST", + uri: "/service/RpcV2CborService/operation/ErrorSerializationOperation", + headers: { + "smithy-protocol": "rpc-v2-cbor", + "Accept": "application/cbor", + "Content-Type": "application/cbor" + } + // An empty CBOR map. We're missing a lot of `@required` members! + body: "oA==" + }, + response: { + code: 400, + body: { + mediaType: "application/cbor", + assertion: { + contents: "v2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleGsxIHZhbGlkYXRpb24gZXJyb3IgZGV0ZWN0ZWQuIFZhbHVlIGF0ICcvcmVxdWlyZWRCbG9iJyBmYWlsZWQgdG8gc2F0aXNmeSBjb25zdHJhaW50OiBNZW1iZXIgbXVzdCBub3QgYmUgbnVsbGlmaWVsZExpc3SBv2RwYXRobS9yZXF1aXJlZEJsb2JnbWVzc2FnZXhOVmFsdWUgYXQgJy9yZXF1aXJlZEJsb2InIGZhaWxlZCB0byBzYXRpc2Z5IGNvbnN0cmFpbnQ6IE1lbWJlciBtdXN0IG5vdCBiZSBudWxs//8=" + } + } + } + } +]) + +apply ErrorSerializationOperation @httpResponseTests([ + { + id: "OperationOutputSerializationQuestionablyIncludesTypeField", + documentation: """ + Despite the operation output being a structure shape with the `@error` trait, + `__type` field should, in a strict interpretation of the spec, not be included, + because we're not serializing a server error response. However, we do, because + there shouldn't™️ be any harm in doing so, and it greatly simplifies the + code generator. This test just pins this behavior in case we ever modify it.""", + protocol: rpcv2Cbor, + code: 200, + params: { + errorShape: { + message: "ValidationException message field" + } + } + bodyMediaType: "application/cbor" + body: "v2plcnJvclNoYXBlv2ZfX3R5cGV4JHNtaXRoeS5mcmFtZXdvcmsjVmFsaWRhdGlvbkV4Y2VwdGlvbmdtZXNzYWdleCFWYWxpZGF0aW9uRXhjZXB0aW9uIG1lc3NhZ2UgZmllbGT//w==" + } +]) + +apply SimpleStructOperation @httpResponseTests([ + { + id: "SimpleStruct", + protocol: rpcv2Cbor, + code: 200, // Not used. + params: { + blob: "blobby blob", + boolean: false, + + string: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + byte: 69, + short: 70, + integer: 71, + long: 72, + + float: 0.69, + double: 0.6969, + + timestamp: 1546300800, + enum: "DIAMOND" + + // With `@required`. + + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + requiredEnum: "DIAMOND" + } + }, + // Same test, but leave optional types empty + { + id: "SimpleStructWithOptionsSetToNone", + protocol: rpcv2Cbor, + code: 200, // Not used. + params: { + requiredBlob: "blobby blob", + requiredBoolean: false, + + requiredString: "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man.", + + requiredByte: 69, + requiredShort: 70, + requiredInteger: 71, + requiredLong: 72, + + requiredFloat: 0.69, + requiredDouble: 0.6969, + + requiredTimestamp: 1546300800, + requiredEnum: "DIAMOND" + } + } +]) + +structure ErrorSerializationOperationOutput { + errorShape: ValidationException +} + +structure SimpleStruct { + blob: Blob + boolean: Boolean + + string: String + + byte: Byte + short: Short + integer: Integer + long: Long + + float: Float + double: Double + + timestamp: Timestamp + enum: Suit + + // With `@required`. + + @required requiredBlob: Blob + @required requiredBoolean: Boolean + + @required requiredString: String + + @required requiredByte: Byte + @required requiredShort: Short + @required requiredInteger: Integer + @required requiredLong: Long + + @required requiredFloat: Float + @required requiredDouble: Double + + @required requiredTimestamp: Timestamp + // @required requiredDocument: MyDocument + @required requiredEnum: Suit +} + +structure ComplexStruct { + structure: SimpleStruct + emptyStructure: EmptyStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion + unitUnion: UnitUnion + + structureList: StructList + + // `@required` for good measure here. + @required complexList: ComplexList + @required complexMap: ComplexMap + @required complexUnion: ComplexUnion +} + +structure EmptyStruct { } + +structure SingleMemberStruct { + message: String +} + +list StructList { + member: SimpleStruct +} + +list SimpleList { + member: String +} + +map SimpleMap { + key: String + value: Integer +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2325): Upstream protocol +// test suite doesn't cover unions. While the generated SDK compiles, we're not +// exercising the (de)serializers with actual values. +union SimpleUnion { + blob: Blob + boolean: Boolean + string: String + unit: Unit +} + +union UnitUnion { + unitA: Unit + unitB: Unit +} + +list ComplexList { + member: ComplexMap +} + +map ComplexMap { + key: String + value: ComplexUnion +} + +union ComplexUnion { + // Recursive path here. + complexStruct: ComplexStruct + + structure: SimpleStruct + list: SimpleList + map: SimpleMap + union: SimpleUnion +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index fbdd0dca11c..f34921dfffb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import java.nio.file.Path @@ -41,6 +42,12 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { ) + dependencies().flatMap { it.dependencies } } + open fun toDevDependency(): RustDependency = + when (this) { + is CargoDependency -> this.toDevDependency() + is InlineDependency -> PANIC("it does not make sense for an inline dependency to be a dev-dependency") + } + companion object { private const val PROPERTY_KEY = "rustdep" @@ -71,9 +78,7 @@ class InlineDependency( return renderer.hashCode().toString() } - override fun dependencies(): List { - return extraDependencies - } + override fun dependencies(): List = extraDependencies fun key() = "${module.fullyQualifiedPath()}::$name" @@ -170,7 +175,7 @@ data class Feature(val name: String, val default: Boolean, val deps: List { * Hook for customizing symbols by inserting an additional symbol provider. */ fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = base + + /** + * Hook to override the protocol test generator. + */ + fun protocolTestGenerator( + codegenContext: CodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator = baseGenerator } /** @@ -199,6 +208,14 @@ abstract class CombinedCoreCodegenDecorator + decorator.protocolTestGenerator(codegenContext, gen) + } + /** * Combines customizations from multiple ordered codegen decorators. * diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index de73ab760d9..c09bc545fc9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -65,6 +65,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.letIf import java.math.BigDecimal +import kotlin.jvm.optionals.getOrNull /** * Class describing an instantiator section that can be used in a customization. @@ -94,6 +95,13 @@ open class Instantiator( private val customizations: List = listOf(), private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER, private val customWritable: CustomWritable = NoCustomWritable(), + /** + * A protocol test may provide data for missing members (because we transformed the model). + * This flag makes it so that it is simply ignored, and code generation continues. + **/ + private val ignoreMissingMembers: Boolean = false, + /** Whether we're rendering within a test, in which case we should use dev-dependencies. */ + private val withinTest: Boolean = false, ) { data class Ctx( // The `http` crate requires that headers be lowercase, but Smithy protocol tests @@ -171,7 +179,7 @@ open class Instantiator( is MemberShape -> renderMember(writer, shape, data, ctx) is SimpleShape -> - PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( + PrimitiveInstantiator(runtimeConfig, symbolProvider, withinTest).instantiate( shape, data, customWritable, @@ -422,8 +430,14 @@ open class Instantiator( } } - data.members.forEach { (key, value) -> - val memberShape = shape.expectMember(key.value) + for ((key, value) in data.members) { + val memberShape = + shape.getMember(key.value).getOrNull() + ?: if (ignoreMissingMembers) { + continue + } else { + throw CodegenException("Protocol test defines data for member shape `${key.value}`, but member shape was not found on structure shape ${shape.id}") + } renderMemberHelper(memberShape, value) } @@ -471,7 +485,27 @@ open class Instantiator( } } -class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { +class PrimitiveInstantiator( + private val runtimeConfig: RuntimeConfig, + private val symbolProvider: SymbolProvider, + withinTest: Boolean = false, +) { + val codegenScope = + listOf( + "DateTime" to RuntimeType.dateTime(runtimeConfig), + "Bytestream" to RuntimeType.byteStream(runtimeConfig), + "Blob" to RuntimeType.blob(runtimeConfig), + "SmithyJson" to RuntimeType.smithyJson(runtimeConfig), + "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), + ).map { + it.first to + if (withinTest) { + it.second.toDevDependencyType() + } else { + it.second + } + }.toTypedArray() + fun instantiate( shape: SimpleShape, data: Node, @@ -485,9 +519,9 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va val num = BigDecimal(node.toString()) val wholePart = num.toInt() val fractionalPart = num.remainder(BigDecimal.ONE) - rust( - "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), + rustTemplate( + "#{DateTime}::from_fractional_secs($wholePart, ${fractionalPart}_f64)", + *codegenScope, ) } @@ -498,14 +532,14 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va */ is BlobShape -> if (shape.hasTrait()) { - rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), + rustTemplate( + "#{Bytestream}::from_static(b${(data as StringNode).value.dq()})", + *codegenScope, ) } else { - rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), + rustTemplate( + "#{Blob}::new(${(data as StringNode).value.dq()})", + *codegenScope, ) } @@ -515,10 +549,10 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is StringNode -> { val numberSymbol = symbolProvider.toSymbol(shape) // support Smithy custom values, such as Infinity - rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + rustTemplate( + """<#{NumberSymbol} as #{SmithyTypes}::primitive::Parse>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + "NumberSymbol" to numberSymbol, + *codegenScope, ) } @@ -533,15 +567,14 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va is BooleanShape -> rust(data.asBooleanNode().get().toString()) is DocumentShape -> rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toDevDependency().toType() rustTemplate( """ let json_bytes = br##"${Node.prettyPrintJson(data)}"##; - let mut tokens = #{json_token_iter}(json_bytes).peekable(); - #{expect_document}(&mut tokens).expect("well formed json") + let mut tokens = #{SmithyJson}::deserialize::json_token_iter(json_bytes).peekable(); + #{SmithyJson}::deserialize::token::expect_document(&mut tokens).expect("well formed json") """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + *codegenScope, ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 20121535a8e..3c7950ef34e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -64,7 +64,7 @@ abstract class ProtocolTestGenerator { abstract val brokenTests: Set /** Only generate these tests; useful to temporarily set and shorten development cycles */ - abstract val runOnly: Set + abstract val generateOnly: Set /** * These tests are not even attempted to be generated, either because they will not compile @@ -89,6 +89,8 @@ abstract class ProtocolTestGenerator { allMatchingTestCases().flatMap { fixBrokenTestCase(it) } + // Filter afterward in case a fixed broken test is disabled. + .filterMatching() if (allTests.isEmpty()) { return } @@ -109,6 +111,8 @@ abstract class ProtocolTestGenerator { if (!it.isBroken()) { listOf(it) } else { + logger.info("Fixing ${it.kind} test case ${it.id}") + assert(it.expectFail()) val brokenTest = it.findInBroken()!! @@ -160,11 +164,11 @@ abstract class ProtocolTestGenerator { /** Filter out test cases that are disabled or don't match the service protocol. */ private fun List.filterMatching(): List = - if (runOnly.isEmpty()) { + if (generateOnly.isEmpty()) { this.filter { testCase -> testCase.protocol == codegenContext.protocol && !disabledTests.contains(testCase.id) } } else { logger.warning("Generating only specified tests") - this.filter { testCase -> runOnly.contains(testCase.id) } + this.filter { testCase -> generateOnly.contains(testCase.id) } } private fun TestCase.toFailingTest(): FailingTest = @@ -191,7 +195,7 @@ abstract class ProtocolTestGenerator { val requestTests = operationShape.getTrait()?.getTestCasesFor(appliesTo).orEmpty() .map { TestCase.RequestTest(it) } - return requestTests.filterMatching() + return requestTests } fun responseTestCases(): List { @@ -209,7 +213,7 @@ abstract class ProtocolTestGenerator { ?.getTestCasesFor(appliesTo).orEmpty().map { TestCase.ResponseTest(it, error) } } - return (responseTestsOnOperations + responseTestsOnErrors).filterMatching() + return (responseTestsOnOperations + responseTestsOnErrors) } fun malformedRequestTestCases(): List { @@ -221,7 +225,7 @@ abstract class ProtocolTestGenerator { } else { emptyList() } - return malformedRequestTests.filterMatching() + return malformedRequestTests } /** @@ -412,6 +416,11 @@ object ServiceShapeId { const val AWS_JSON_10 = "aws.protocoltests.json10#JsonRpc10" const val AWS_JSON_11 = "aws.protocoltests.json#JsonProtocol" const val REST_JSON = "aws.protocoltests.restjson#RestJson" + const val RPC_V2_CBOR = "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol" + const val RPC_V2_CBOR_EXTRAS = "smithy.protocoltests.rpcv2Cbor#RpcV2CborService" + const val REST_XML = "aws.protocoltests.restxml#RestXml" + const val AWS_QUERY = "aws.protocoltests.query#AwsQuery" + const val EC2_QUERY = "aws.protocoltests.ec2#AwsEc2" const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation" } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 486b443a6a2..b44bdfb84dd 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -174,10 +174,10 @@ open class AwsJson( override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt index ad36e791901..afeaf5e1ce1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt @@ -169,6 +169,10 @@ open class HttpTraitHttpBindingResolver( model: Model, ): TimestampFormatTrait.Format = httpIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat) + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation input has no members. + * This is in line with what protocol tests assert. + */ override fun requestContentType(operationShape: OperationShape): String? = httpIndex.determineRequestContentType( operationShape, @@ -176,6 +180,10 @@ open class HttpTraitHttpBindingResolver( contentTypes.eventStreamContentType, ).orNull() + /** + * Note that `null` will be returned and hence `Content-Type` will not be set when operation output has no members. + * This is in line with what protocol tests assert. + */ override fun responseContentType(operationShape: OperationShape): String? = httpIndex.determineResponseContentType( operationShape, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index 236c297db92..c7b139bfd59 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -27,11 +27,28 @@ interface Protocol { /** The timestamp format that should be used if no override is specified in the model */ val defaultTimestampFormat: TimestampFormatTrait.Format - /** Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. */ + /** + * Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ fun additionalRequestHeaders(operationShape: OperationShape): List> = emptyList() + /** + * Returns additional HTTP headers that should be included in HTTP responses for the given operation for this protocol. + * + * These MUST all be lowercase, or the application will panic, as per + * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static + */ + fun additionalResponseHeaders(operationShape: OperationShape): List> = emptyList() + /** * Returns additional HTTP headers that should be included in HTTP responses for the given error shape. + * These headers are added to responses _in addition_ to those returned by `additionalResponseHeaders`; if a header + * added by this function has the same header name as one added by `additionalResponseHeaders`, the one added by + * `additionalResponseHeaders` takes precedence. + * * These MUST all be lowercase, or the application will panic, as per * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt index e40046f1d84..cb9b7667718 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt @@ -39,7 +39,7 @@ class ProtocolFunctions( private val codegenContext: CodegenContext, ) { companion object { - private val serDeModule = RustModule.pubCrate("protocol_serde") + val serDeModule = RustModule.pubCrate("protocol_serde") fun crossOperationFn( fnName: String, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 641548fc116..c4e39806682 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -56,6 +56,12 @@ class RestJsonHttpBindingResolver( } } } + + // The spec does not mention whether we should set the `Content-Type` header when there is no modeled output. + // The protocol tests indicate it's optional: + // + // + // In our implementation, we opt to always set it to `application/json`. return super.responseContentType(operationShape) ?: "application/json" } } @@ -124,10 +130,10 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName -> + // `HeaderMap::new()` doesn't allocate. rustTemplate( """ pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { - // Note: HeaderMap::new() doesn't allocate #{json_errors}::parse_error_metadata(payload, &#{Headers}::new()) } """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt new file mode 100644 index 00000000000..d1af7ae72c4 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -0,0 +1,121 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.outputShape + +class RpcV2CborHttpBindingResolver( + private val model: Model, + private val contentTypes: ProtocolContentTypes, +) : HttpBindingResolver { + private fun bindings(shape: ToShapeId): List { + val members = shape.let { model.expectShape(it.toShapeId()) }.members() + // TODO(https://github.com/awslabs/smithy-rs/issues/2237): support non-streaming members too + if (members.size > 1 && members.any { it.isStreaming(model) }) { + throw CodegenException( + "We only support one payload member if that payload contains a streaming member." + + "Tracking issue to relax this constraint: https://github.com/awslabs/smithy-rs/issues/2237", + ) + } + + return members.map { + if (it.isStreaming(model)) { + HttpBindingDescriptor(it, HttpLocation.PAYLOAD, "document") + } else { + HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") + } + } + .toList() + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + // In the server, this is only used when the protocol actually supports the `@http` trait. + // However, we will have to do this for client support. Perhaps this method deserves a rename. + override fun httpTrait(operationShape: OperationShape) = PANIC("RPC v2 does not support the `@http` trait") + + override fun requestBindings(operationShape: OperationShape) = bindings(operationShape.inputShape) + + override fun responseBindings(operationShape: OperationShape) = bindings(operationShape.outputShape) + + override fun errorResponseBindings(errorShape: ToShapeId) = bindings(errorShape) + + /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#requests + * > Requests for operations with no defined input type MUST NOT contain bodies in their HTTP requests. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun requestContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationInput(operationShape, model)) { + contentTypes.requestDocument + } else { + null + } + + /** + * https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#responses + * > Responses for operations with no defined output type MUST NOT contain bodies in their HTTP responses. + * > The `Content-Type` for the serialization format MUST NOT be set. + */ + override fun responseContentType(operationShape: OperationShape): String? = + if (OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { + contentTypes.responseDocument + } else { + null + } + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") +} + +open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { + override val httpBindingResolver: HttpBindingResolver = + RpcV2CborHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/cbor", + responseDocument = "application/cbor", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/cbor", + ), + ) + + // Note that [CborParserGenerator] and [CborSerializerGenerator] automatically (de)serialize timestamps + // using floating point seconds from the epoch. + override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS + + override fun additionalResponseHeaders(operationShape: OperationShape): List> = + listOf("smithy-protocol" to "rpc-v2-cbor") + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator(codegenContext, httpBindingResolver) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = + CborSerializerGenerator(codegenContext, httpBindingResolver) + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + TODO("rpcv2Cbor client support has not yet been implemented") + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + TODO("rpcv2Cbor event streams have not yet been implemented") +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt new file mode 100644 index 00000000000..99208b0b9a6 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -0,0 +1,666 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** Class describing a CBOR parser section that can be used in a customization. */ +sealed class CborParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") +} + +/** Customization for the CBOR parser. */ +typealias CborParserCustomization = NamedCustomization + +class CborParserGenerator( + private val codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + /** See docs for this parameter in [JsonParserGenerator]. */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = emptyList(), +) : StructuredDataParserGenerator { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + private val codegenTarget = codegenContext.target + private val smithyCbor = CargoDependency.smithyCbor(runtimeConfig).toType() + private val protocolFunctions = ProtocolFunctions(codegenContext) + private val codegenScope = + arrayOf( + "SmithyCbor" to smithyCbor, + "Decoder" to smithyCbor.resolve("Decoder"), + "Error" to smithyCbor.resolve("decode::DeserializeError"), + "HashMap" to RuntimeType.HashMap, + *preludeScope, + ) + + private fun listMemberParserFn( + listSymbol: Symbol, + isSparseList: Boolean, + memberShape: MemberShape, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn member( + mut list: #{ListSymbol}, + decoder: &mut #{Decoder}, + ) -> #{Result}<#{ListSymbol}, #{Error}> + """, + *codegenScope, + "ListSymbol" to listSymbol, + ) { + val deserializeMemberWritable = deserializeMember(memberShape) + if (isSparseList) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeMember:W}?), + }; + """, + *codegenScope, + "DeserializeMember" to deserializeMemberWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeMember:W}?; + """, + "DeserializeMember" to deserializeMemberWritable, + ) + } + + if (returnUnconstrainedType) { + rust("list.0.push(value);") + } else { + rust("list.push(value);") + } + + rust("Ok(list)") + } + } + + private fun mapPairParserFnWritable( + keyTarget: StringShape, + valueShape: MemberShape, + isSparseMap: Boolean, + mapSymbol: Symbol, + returnUnconstrainedType: Boolean, + ) = writable { + rustBlockTemplate( + """ + fn pair( + mut map: #{MapSymbol}, + decoder: &mut #{Decoder}, + ) -> #{Result}<#{MapSymbol}, #{Error}> + """, + *codegenScope, + "MapSymbol" to mapSymbol, + ) { + val deserializeKeyWritable = deserializeString(keyTarget) + rustTemplate( + """ + let key = #{DeserializeKey:W}?; + """, + "DeserializeKey" to deserializeKeyWritable, + ) + val deserializeValueWritable = deserializeMember(valueShape) + if (isSparseMap) { + rustTemplate( + """ + let value = match decoder.datatype()? { + #{SmithyCbor}::data::Type::Null => { + decoder.null()?; + None + } + _ => Some(#{DeserializeValue:W}?), + }; + """, + *codegenScope, + "DeserializeValue" to deserializeValueWritable, + ) + } else { + rustTemplate( + """ + let value = #{DeserializeValue:W}?; + """, + "DeserializeValue" to deserializeValueWritable, + ) + } + + if (returnUnconstrainedType) { + rust("map.0.insert(key, value);") + } else { + rust("map.insert(key, value);") + } + + rust("Ok(map)") + } + } + + private fun structurePairParserFnWritable( + builderSymbol: Symbol, + includedMembers: Collection, + ) = writable { + rustBlockTemplate( + """ + ##[allow(clippy::match_single_binding)] + fn pair( + mut builder: #{Builder}, + decoder: &mut #{Decoder} + ) -> #{Result}<#{Builder}, #{Error}> + """, + *codegenScope, + "Builder" to builderSymbol, + ) { + withBlock("builder = match decoder.str()?.as_ref() {", "};") { + for (member in includedMembers) { + rustBlock("${member.memberName.dq()} =>") { + val callBuilderSetMemberFieldWritable = + writable { + withBlock("builder.${member.setterName()}(", ")") { + conditionalBlock("Some(", ")", symbolProvider.toSymbol(member).isOptional()) { + val symbol = symbolProvider.toSymbol(member) + if (symbol.isRustBoxed()) { + rustBlock("") { + rustTemplate( + "let v = #{DeserializeMember:W}?;", + "DeserializeMember" to deserializeMember(member), + ) + + for (customization in customizations) { + customization.section( + CborParserSection.BeforeBoxingDeserializedMember( + member, + ), + )(this) + } + rust("Box::new(v)") + } + } else { + rustTemplate( + "#{DeserializeMember:W}?", + "DeserializeMember" to deserializeMember(member), + ) + } + } + } + } + + if (member.isOptional) { + // Call `builder.set_member()` only if the value for the field on the wire is not null. + rustTemplate( + """ + #{SmithyCbor}::decode::set_optional(builder, decoder, |builder, decoder| { + Ok(#{MemberSettingWritable:W}) + })? + """, + *codegenScope, + "MemberSettingWritable" to callBuilderSetMemberFieldWritable, + ) + } else { + callBuilderSetMemberFieldWritable.invoke(this) + } + } + } + + rust( + """ + _ => { + decoder.skip()?; + builder + } + """, + ) + } + rust("Ok(builder)") + } + } + + private fun unionPairParserFnWritable(shape: UnionShape) = + writable { + val returnSymbolToParse = returnSymbolToParse(shape) + rustBlockTemplate( + """ + fn pair( + decoder: &mut #{Decoder} + ) -> #{Result}<#{UnionSymbol}, #{Error}> + """, + *codegenScope, + "UnionSymbol" to returnSymbolToParse.symbol, + ) { + withBlock("Ok(match decoder.str()?.as_ref() {", "})") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) + + if (member.isTargetUnit()) { + rust( + """ + ${member.memberName.dq()} => { + decoder.skip()?; + #T::$variantName + } + """, + returnSymbolToParse.symbol, + ) + } else { + withBlock("${member.memberName.dq()} => #T::$variantName(", "?),", returnSymbolToParse.symbol) { + deserializeMember(member).invoke(this) + } + } + } + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. + true -> + rustTemplate( + """ + _ => { + decoder.skip()?; + Some(#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME}) + } + """, + "Union" to returnSymbolToParse.symbol, + *codegenScope, + ) + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + false -> + rustTemplate( + "variant => return Err(#{Error}::unknown_union_variant(variant, decoder.position()))", + *codegenScope, + ) + } + } + } + } + + enum class CollectionKind { + Map, + List, + ; + + /** Method to invoke on the decoder to decode this collection kind. **/ + fun decoderMethodName() = + when (this) { + Map -> "map" + List -> "list" + } + } + + /** + * Decode a collection of homogeneous CBOR data items: a map or an array. + * The first branch of the `match` corresponds to when the collection is encoded using variable-length encoding; + * the second branch corresponds to fixed-length encoding. + * + * https://www.rfc-editor.org/rfc/rfc8949.html#name-indefinite-length-arrays-an + */ + private fun decodeCollectionLoopWritable( + collectionKind: CollectionKind, + variableBindingName: String, + decodeItemFnName: String, + ) = writable { + rustTemplate( + """ + match decoder.${collectionKind.decoderMethodName()}()? { + None => loop { + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + break; + } + _ => { + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; + } + }; + }, + Some(n) => { + for _ in 0..n { + $variableBindingName = $decodeItemFnName($variableBindingName, decoder)?; + } + } + }; + """, + *codegenScope, + ) + } + + private fun decodeStructureMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "builder", "pair") + + private fun decodeMapLoopWritable() = decodeCollectionLoopWritable(CollectionKind.Map, "map", "pair") + + private fun decodeListLoopWritable() = decodeCollectionLoopWritable(CollectionKind.List, "list", "member") + + /** + * Reusable structure parser implementation that can be used to generate parsing code for + * operation, error and structure shapes. + * We still generate the parser symbol even if there are no included members because the server + * generation requires parsers for all input structures. + */ + private fun structureParser( + shape: Shape, + builderSymbol: Symbol, + includedMembers: List, + fnNameSuffix: String? = null, + ): RuntimeType { + return protocolFunctions.deserializeFn(shape, fnNameSuffix) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(value: &[u8], mut builder: #{Builder}) -> #{Result}<#{Builder}, #{Error}> { + #{StructurePairParserFn:W} + + let decoder = &mut #{Decoder}::new(value); + + #{DecodeStructureMapLoop:W} + + if decoder.position() != value.len() { + return Err(#{Error}::expected_end_of_stream(decoder.position())); + } + + Ok(builder) + } + """, + "Builder" to builderSymbol, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + *codegenScope, + ) + } + } + + override fun payloadParser(member: MemberShape): RuntimeType { + UNREACHABLE("No protocol using CBOR serialization supports payload binding") + } + + override fun operationParser(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR deserializer if there is nothing bound to the HTTP body. + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + val outputShape = operationShape.outputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(outputShape), httpDocumentMembers) + } + + override fun errorParser(errorShape: StructureShape): RuntimeType? { + if (errorShape.members().isEmpty()) { + return null + } + return structureParser( + errorShape, + symbolProvider.symbolForBuilder(errorShape), + errorShape.members().toList(), + fnNameSuffix = "cbor_err", + ) + } + + override fun serverInputParser(operationShape: OperationShape): RuntimeType? { + val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (includedMembers.isEmpty()) { + return null + } + val inputShape = operationShape.inputShape(model) + return structureParser(operationShape, symbolProvider.symbolForBuilder(inputShape), includedMembers) + } + + private fun deserializeMember(memberShape: MemberShape) = + writable { + when (val target = model.expectShape(memberShape.target)) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("decoder.blob()") + is BooleanShape -> rust("decoder.boolean()") + + is StringShape -> deserializeString(target).invoke(this) + + is ByteShape -> rust("decoder.byte()") + is ShortShape -> rust("decoder.short()") + is IntegerShape -> rust("decoder.integer()") + is LongShape -> rust("decoder.long()") + + is FloatShape -> rust("decoder.float()") + is DoubleShape -> rust("decoder.double()") + + is TimestampShape -> rust("decoder.timestamp()") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + is StructureShape -> deserializeStruct(target) + is CollectionShape -> deserializeCollection(target) + is MapShape -> deserializeMap(target) + is UnionShape -> deserializeUnion(target) + + // Note that no protocol using CBOR serialization supports `document` shapes. + else -> PANIC("unexpected shape: $target") + } + } + + private fun deserializeString(target: StringShape) = + writable { + when (target.hasTrait()) { + true -> { + if (this@CborParserGenerator.returnSymbolToParse(target).isUnconstrained) { + rust("decoder.string()") + } else { + rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) + } + } + false -> rust("decoder.string()") + } + } + + private fun RustWriter.deserializeCollection(shape: CollectionShape) { + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut list = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{Vec}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { + #{ListMemberParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeListLoop:W} + + Ok(list) + } + """, + "ReturnType" to returnSymbol, + "ListMemberParserFn" to + listMemberParserFn( + returnSymbol, + isSparseList = shape.hasTrait(), + shape.member, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeListLoop" to decodeListLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeMap(shape: MapShape) { + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) + + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + val initContainerWritable = + writable { + withBlock("let mut map = ", ";") { + conditionalBlock("#{T}(", ")", conditional = returnUnconstrainedType, returnSymbol) { + rustTemplate("#{HashMap}::new()", *codegenScope) + } + } + } + + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}> { + #{MapPairParserFn:W} + + #{InitContainerWritable:W} + + #{DecodeMapLoop:W} + + Ok(map) + } + """, + "ReturnType" to returnSymbol, + "MapPairParserFn" to + mapPairParserFnWritable( + keyTarget, + shape.value, + isSparseMap = shape.hasTrait(), + returnSymbol, + returnUnconstrainedType = returnUnconstrainedType, + ), + "InitContainerWritable" to initContainerWritable, + "DecodeMapLoop" to decodeMapLoopWritable(), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeStruct(shape: StructureShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{ReturnType}, #{Error}>", + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + val builderSymbol = symbolProvider.symbolForBuilder(shape) + val includedMembers = shape.members() + + rustTemplate( + """ + #{StructurePairParserFn:W} + + let mut builder = #{Builder}::default(); + + #{DecodeStructureMapLoop:W} + """, + *codegenScope, + "StructurePairParserFn" to structurePairParserFnWritable(builderSymbol, includedMembers), + "Builder" to builderSymbol, + "DecodeStructureMapLoop" to decodeStructureMapLoopWritable(), + ) + + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(builder)") + } else { + rust("Ok(builder.build())") + } + } + } + rust("#T(decoder)", parser) + } + + private fun RustWriter.deserializeUnion(shape: UnionShape) { + val returnSymbolToParse = returnSymbolToParse(shape) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> { + #{UnionPairParserFnWritable} + + match decoder.map()? { + None => { + let variant = pair(decoder)?; + match decoder.datatype()? { + #{SmithyCbor}::data::Type::Break => { + decoder.skip()?; + Ok(variant) + } + ty => Err( + #{Error}::unexpected_union_variant( + ty, + decoder.position(), + ), + ), + } + } + Some(1) => pair(decoder), + Some(_) => Err(#{Error}::mixed_union_variants(decoder.position())) + } + } + """, + "UnionSymbol" to returnSymbolToParse.symbol, + "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + *codegenScope, + ) + } + rust("#T(decoder)", parser) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 4833f808fb0..cf0676ebd0d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -77,8 +77,6 @@ sealed class JsonParserSection(name: String) : Section(name) { */ typealias JsonParserCustomization = NamedCustomization -data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) - class JsonParserGenerator( private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, @@ -339,8 +337,7 @@ class JsonParserGenerator( rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) } } - - else -> rust("u.into_owned()") + false -> rust("u.into_owned()") } } } @@ -447,7 +444,7 @@ class JsonParserGenerator( } private fun RustWriter.deserializeMap(shape: MapShape) { - val keyTarget = model.expectShape(shape.key.target) as StringShape + val keyTarget = model.expectShape(shape.key.target, StringShape::class.java) val isSparse = shape.hasTrait() val returnSymbolToParse = returnSymbolToParse(shape) val parser = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt new file mode 100644 index 00000000000..4b69e873289 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/ReturnSymbolToParse.kt @@ -0,0 +1,14 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse + +import software.amazon.smithy.codegen.core.Symbol + +/** + * Parsers need to know what symbol to parse and return, and whether it's unconstrained or not. + * This data class holds this information that the parsers fill out from a shape. + */ +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt index f8b053d80f7..fd7e6fcb289 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/StructuredDataParserGenerator.kt @@ -12,8 +12,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType interface StructuredDataParserGenerator { /** - * Generate a parse function for a given targeted as a payload. + * Generate a parse function for a given shape targeted with `@httpPayload`. * Entry point for payload-based parsing. + * * Roughly: * ```rust * fn parse_my_struct(input: &[u8]) -> Result { @@ -49,6 +50,7 @@ interface StructuredDataParserGenerator { /** * Generate a parser for a server operation input structure + * * ```rust * fn deser_operation_crate_operation_my_operation_input( * value: &[u8], builder: my_operation_input::Builder diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt new file mode 100644 index 00000000000..f96a8b7cbc2 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -0,0 +1,419 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.isTargetUnit +import software.amazon.smithy.rust.codegen.core.util.isUnit +import software.amazon.smithy.rust.codegen.core.util.outputShape + +/** + * Class describing a CBOR serializer section that can be used in a customization. + */ +sealed class CborSerializerSection(name: String) : Section(name) { + /** + * Mutate the serializer prior to serializing any structure members. Eg: this can be used to inject `__type` + * to record the error type in the case of an error structure. + */ + data class BeforeSerializingStructureMembers( + val structureShape: StructureShape, + val encoderBindingName: String, + ) : CborSerializerSection("BeforeSerializingStructureMembers") + + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context) : + CborSerializerSection("BeforeIteratingOverMapOrCollection") +} + +/** + * Customization for the CBOR serializer. + */ +typealias CborSerializerCustomization = NamedCustomization + +class CborSerializerGenerator( + codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + private val customizations: List = listOf(), +) : StructuredDataSerializerGenerator { + data class Context( + /** Expression representing the value to write to the encoder */ + var valueExpression: ValueExpression, + /** Shape to serialize */ + val shape: T, + ) + + data class MemberContext( + /** Name for the variable bound to the encoder object **/ + val encoderBindingName: String, + /** Expression representing the value to write to the `Encoder` */ + var valueExpression: ValueExpression, + val shape: MemberShape, + /** Whether to serialize null values if the type is optional */ + val writeNulls: Boolean = false, + ) { + companion object { + fun collectionMember( + context: Context, + itemName: String, + ): MemberContext = + MemberContext( + "encoder", + ValueExpression.Reference(itemName), + context.shape.member, + writeNulls = true, + ) + + fun mapMember( + context: Context, + key: String, + value: String, + ): MemberContext = + MemberContext( + "encoder.str($key)", + ValueExpression.Reference(value), + context.shape.value, + writeNulls = true, + ) + + fun structMember( + context: StructContext, + member: MemberShape, + symProvider: RustSymbolProvider, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Value("${context.localName}.${symProvider.toMemberName(member)}"), + member, + ) + + fun unionMember( + variantReference: String, + member: MemberShape, + ): MemberContext = + MemberContext( + encodeKeyExpression(member.memberName), + ValueExpression.Reference(variantReference), + member, + ) + + /** Returns an expression to encode a key member **/ + private fun encodeKeyExpression(name: String): String = "encoder.str(${name.dq()})" + } + } + + data class StructContext( + /** Name of the variable that holds the struct */ + val localName: String, + val shape: StructureShape, + ) + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig + private val protocolFunctions = ProtocolFunctions(codegenContext) + + private val codegenScope = + arrayOf( + "Error" to runtimeConfig.serializationError(), + "Encoder" to RuntimeType.smithyCbor(runtimeConfig).resolve("Encoder"), + *preludeScope, + ) + private val serializerUtil = SerializerUtil(model, symbolProvider) + + /** + * Reusable structure serializer implementation that can be used to generate serializing code for + * operation outputs or errors. + * This function is only used by the server, the client uses directly [serializeStructure]. + */ + private fun serverSerializer( + structureShape: StructureShape, + includedMembers: List, + error: Boolean, + ): RuntimeType { + val suffix = + when (error) { + true -> "error" + else -> "output" + } + return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> + rustBlockTemplate( + "pub fn $fnName(value: &#{target}) -> #{Result}<#{Vec}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(structureShape), + ) { + rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) + // Open a scope in which we can safely shadow the `encoder` variable to bind it to a mutable reference. + rustBlock("") { + rust("let encoder = &mut encoder;") + serializeStructure( + StructContext("value", structureShape), + includedMembers, + ) + } + rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) + } + } + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + override fun payloadSerializer(member: MemberShape): RuntimeType { + TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") + } + + override fun unsetStructure(structure: StructureShape): RuntimeType = + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") + + override fun unsetUnion(union: UnionShape): RuntimeType = + UNREACHABLE("Only clients use this method when serializing an `@httpPayload`. No protocol using CBOR supports this trait, so we don't need to implement this") + + override fun operationInputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there is no CBOR body. + val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) + if (httpDocumentMembers.isEmpty()) { + return null + } + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) + TODO("Client implementation should fill this out") + } + + override fun documentSerializer(): RuntimeType = + UNREACHABLE("No protocol using CBOR supports `document` shapes, so we don't need to implement this") + + override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { + // Don't generate an operation CBOR serializer if there was no operation output shape in the + // original (untransformed) model. + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { + return null + } + + val httpDocumentMembers = httpBindingResolver.responseMembers(operationShape, HttpLocation.DOCUMENT) + val outputShape = operationShape.outputShape(model) + return serverSerializer(outputShape, httpDocumentMembers, error = false) + } + + override fun serverErrorSerializer(shape: ShapeId): RuntimeType { + val errorShape = model.expectShape(shape, StructureShape::class.java) + val includedMembers = + httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } + .map { it.member } + return serverSerializer(errorShape, includedMembers, error = true) + } + + private fun RustWriter.serializeStructure( + context: StructContext, + includedMembers: List? = null, + ) { + if (context.shape.isUnit()) { + rust( + """ + encoder.begin_map(); + encoder.end(); + """, + ) + return + } + + val structureSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> #{Result}<(), #{Error}>", + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + *codegenScope, + ) { + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3745) If all members are non-`Option`-al, + // we know AOT the map's size and can use `.map()` instead of `.begin_map()` for efficiency. + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingStructureMembers( + context.shape, + "encoder", + ), + )(this) + } + context.copy(localName = "input").also { inner -> + val members = includedMembers ?: inner.shape.members() + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider)) + } + } + rust("encoder.end();") + rust("Ok(())") + } + } + rust("#T(encoder, ${context.localName})?;", structureSerializer) + } + + private fun RustWriter.serializeMember(context: MemberContext) { + val targetShape = model.expectShape(context.shape.target) + if (symbolProvider.toSymbol(context.shape).isOptional()) { + safeName().also { local -> + rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { + context.valueExpression = ValueExpression.Reference(local) + serializeMemberValue(context, targetShape) + } + if (context.writeNulls) { + rustBlock("else") { + rust("${context.encoderBindingName}.null();") + } + } + } + } else { + with(serializerUtil) { + ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { + serializeMemberValue(context, targetShape) + } + } + } + } + + private fun RustWriter.serializeMemberValue( + context: MemberContext, + target: Shape, + ) { + val encoder = context.encoderBindingName + val value = context.valueExpression + val containerShape = model.expectShape(context.shape.container) + + when (target) { + // Simple shapes: https://smithy.io/2.0/spec/simple-types.html + is BlobShape -> rust("$encoder.blob(${value.asRef()});") + is BooleanShape -> rust("$encoder.boolean(${value.asValue()});") + + is StringShape -> rust("$encoder.str(${value.name}.as_str());") + + is ByteShape -> rust("$encoder.byte(${value.asValue()});") + is ShortShape -> rust("$encoder.short(${value.asValue()});") + is IntegerShape -> rust("$encoder.integer(${value.asValue()});") + is LongShape -> rust("$encoder.long(${value.asValue()});") + + is FloatShape -> rust("$encoder.float(${value.asValue()});") + is DoubleShape -> rust("$encoder.double(${value.asValue()});") + + is TimestampShape -> rust("$encoder.timestamp(${value.asRef()});") + + is DocumentShape -> UNREACHABLE("Smithy RPC v2 CBOR does not support `document` shapes") + + // Aggregate shapes: https://smithy.io/2.0/spec/aggregate-types.html + else -> { + // This condition is equivalent to `containerShape !is CollectionShape`. + if (containerShape is StructureShape || containerShape is UnionShape || containerShape is MapShape) { + rust("$encoder;") // Encode the member key. + } + when (target) { + is StructureShape -> serializeStructure(StructContext(value.name, target)) + is CollectionShape -> serializeCollection(Context(value, target)) + is MapShape -> serializeMap(Context(value, target)) + is UnionShape -> serializeUnion(Context(value, target)) + else -> UNREACHABLE("Smithy added a new aggregate shape: $target") + } + } + } + } + + private fun RustWriter.serializeCollection(context: Context) { + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust("encoder.array((${context.valueExpression.asValue()}).len());") + val itemName = safeName("item") + rustBlock("for $itemName in ${context.valueExpression.asRef()}") { + serializeMember(MemberContext.collectionMember(context, itemName)) + } + } + + private fun RustWriter.serializeMap(context: Context) { + val keyName = safeName("key") + val valueName = safeName("value") + for (customization in customizations) { + customization.section(CborSerializerSection.BeforeIteratingOverMapOrCollection(context.shape, context))(this) + } + rust("encoder.map((${context.valueExpression.asValue()}).len());") + rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { + val keyExpression = "$keyName.as_str()" + serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) + } + } + + private fun RustWriter.serializeUnion(context: Context) { + val unionSymbol = symbolProvider.toSymbol(context.shape) + val unionSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(encoder: &mut #{Encoder}, input: &#{UnionSymbol}) -> #{Result}<(), #{Error}>", + "UnionSymbol" to unionSymbol, + *codegenScope, + ) { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = + if (member.isTargetUnit()) { + symbolProvider.toMemberName(member) + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + rustBlock("#T::$variantName =>", unionSymbol) { + serializeMember(MemberContext.unionMember("inner", member)) + } + } + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UNKNOWN_VARIANT_NAME} => return #{Err}(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) + } + } + rustTemplate("#{Ok}(())", *codegenScope) + } + } + rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 46341bb09c8..69ec11fdd28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -47,9 +47,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -212,12 +211,21 @@ class JsonSerializerGenerator( *codegenScope, "target" to symbolProvider.toSymbol(structureShape), ) { - rust("let mut out = String::new();") - rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + rustTemplate( + """ + let mut out = #{String}::new(); + let mut object = #{JsonObjectWriter}::new(&mut out); + """, + *codegenScope, + ) serializeStructure(StructContext("object", "value", structureShape), includedMembers) customizations.forEach { it.section(makeSection(structureShape, "object"))(this) } - rust("object.finish();") - rustTemplate("Ok(out)", *codegenScope) + rust( + """ + object.finish(); + Ok(out) + """, + ) } } } @@ -304,8 +312,7 @@ class JsonSerializerGenerator( override fun operationOutputSerializer(operationShape: OperationShape): RuntimeType? { // Don't generate an operation JSON serializer if there was no operation output shape in the // original (untransformed) model. - val syntheticOutputTrait = operationShape.outputShape(model).expectTrait() - if (syntheticOutputTrait.originalId == null) { + if (!OperationNormalizer.hadUserModeledOperationOutput(operationShape, model)) { return null } @@ -485,13 +492,17 @@ class JsonSerializerGenerator( rust("let mut $objectName = ${context.writerExpression}.start_object();") // We call inner only when context's shape is not the Unit type. // If it were, calling inner would generate the following function: - // pub fn serialize_structure_crate_model_unit( - // object: &mut aws_smithy_json::serialize::JsonObjectWriter, - // input: &crate::model::Unit, - // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { - // let (_, _) = (object, input); - // Ok(()) - // } + // + // ```rust + // pub fn serialize_structure_crate_model_unit( + // object: &mut aws_smithy_json::serialize::JsonObjectWriter, + // input: &crate::model::Unit, + // ) -> Result<(), aws_smithy_http::operation::error::SerializationError> { + // let (_, _) = (object, input); + // Ok(()) + // } + // ``` + // // However, this would cause a compilation error at a call site because it cannot // extract data out of the Unit type that corresponds to the variable "input" above. if (!context.shape.isTargetUnit()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt index 92b28d89fcf..a85646673d5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/StructuredDataSerializerGenerator.kt @@ -25,12 +25,15 @@ interface StructuredDataSerializerGenerator { fun payloadSerializer(member: MemberShape): RuntimeType /** - * Generate the correct data when attempting to serialize a structure that is unset + * Generate the correct data when attempting to serialize a structure that is unset. * * ```rust * fn rest_json_unset_struct_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetStructure(structure: StructureShape): RuntimeType @@ -41,6 +44,9 @@ interface StructuredDataSerializerGenerator { * fn rest_json_unset_union_payload() -> Vec { * ... * } + * ``` + * + * This method is only invoked when serializing an `@httpPayload`. */ fun unsetUnion(union: UnionShape): RuntimeType diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt index e4b95f5cde2..b0c61d3d6cf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticInputTrait.kt @@ -12,8 +12,13 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic input (see `OperationNormalizer.kt`) * - * All operations are normalized to have an input, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an input, even when they are defined without one. + * This is NOT done for backwards-compatibility, as adding an operation input is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic inputs. */ class SyntheticInputTrait( val operation: ShapeId, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt index d34ca0e39da..cc310e8a478 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt @@ -12,8 +12,14 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic output (see `OperationNormalizer.kt`) * - * All operations are normalized to have an output, even when they are defined without on. This is done for backwards - * compatibility and to produce a consistent API. + * All operations are normalized to have an output, even when they are defined without one. + * + * This is NOT done for backwards-compatibility, as adding an operation output is a breaking change + * (see ). + * + * It is only done to produce a consistent API. + * TODO(https://github.com/smithy-lang/smithy-rs/issues/3577): In the server, we'd like to stop adding + * these synthetic outputs. */ class SyntheticOutputTrait constructor(val operation: ShapeId, val originalId: ShapeId?) : AnnotationTrait(ID, Node.objectNode()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index 4092174b55e..89d4512007b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -14,23 +14,27 @@ import software.amazon.smithy.model.traits.InputTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.rename import java.util.Optional -import kotlin.streams.toList /** * Generate synthetic Input and Output structures for operations. * - * Operation input/output shapes can be retroactively added. In order to support this while maintaining backwards compatibility, - * we need to generate input/output shapes for all operations in a backwards compatible way. - * * This works by **adding** new shapes to the model for operation inputs & outputs. These new shapes have `SyntheticInputTrait` * and `SyntheticOutputTrait` attached to them as needed. This enables downstream code generators to determine if a shape is * "real" vs. a shape created as a synthetic input/output. * * The trait also tracks the original shape id for certain serialization tasks that require it to exist. + * + * Note that adding/removing operation input/output [is a breaking change]; the only reason why we synthetically add them + * is to produce a consistent API. + * + * [is a breaking change]: */ object OperationNormalizer { // Functions to construct synthetic shape IDs—Don't rely on these in external code. @@ -43,6 +47,30 @@ object OperationNormalizer { private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace + ".synthetic", "${this.id.name}Output") + /** + * Returns `true` if the user had originally modeled an operation input shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationInput( + operation: OperationShape, + model: Model, + ): Boolean { + val syntheticInputTrait = operation.inputShape(model).expectTrait() + return syntheticInputTrait.originalId != null + } + + /** + * Returns `true` if the user had originally modeled an operation output shape on the given [operation]; + * `false` if the transform added a synthetic one. + */ + fun hadUserModeledOperationOutput( + operation: OperationShape, + model: Model, + ): Boolean { + val syntheticOutputTrait = operation.outputShape(model).expectTrait() + return syntheticOutputTrait.originalId != null + } + /** * Add synthetic input & output shapes to every Operation in model. The generated shapes will be marked with * [SyntheticInputTrait] and [SyntheticOutputTrait] respectively. Shapes will be added _even_ if the operation does diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt index 5508611972e..e0279c53c03 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt @@ -178,7 +178,7 @@ object NamingObstacleCourseTestModels { /** * This targets two bug classes: * - operation inputs used as nested outputs - * - operation outputs used as nested outputs + * - operation outputs used as nested inputs */ fun reusedInputOutputShapesModel(protocol: Trait) = """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 13a59f3baef..42997a80127 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.testutil import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.loader.ModelDiscovery import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape @@ -151,7 +152,23 @@ fun String.asSmithyModel( disableValidation: Boolean = false, ): Model { val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } - val assembler = Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed) + val denyModelsContaining = + arrayOf( + // If Smithy protocol test models are in our classpath, don't load them, since they are fairly large and we + // almost never need them. + "smithy-protocol-tests", + ) + val urls = + ModelDiscovery.findModels().filter { modelUrl -> + denyModelsContaining.none { + modelUrl.toString().contains(it) + } + } + val assembler = Model.assembler() + for (url in urls) { + assembler.addImport(url) + } + assembler.addUnparsedModel(sourceLocation ?: "test.smithy", processed) if (disableValidation) { assembler.disableValidation() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index 975416d72fe..f6d6ddf84fa 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -24,19 +24,15 @@ import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait -inline fun Model.lookup(shapeId: String): T { - return this.expectShape(ShapeId.from(shapeId), T::class.java) -} +inline fun Model.lookup(shapeId: String): T = this.expectShape(ShapeId.from(shapeId), T::class.java) -fun OperationShape.inputShape(model: Model): StructureShape { +fun OperationShape.inputShape(model: Model): StructureShape = // The Rust Smithy generator adds an input to all shapes automatically - return model.expectShape(this.input.get(), StructureShape::class.java) -} + model.expectShape(this.input.get(), StructureShape::class.java) -fun OperationShape.outputShape(model: Model): StructureShape { +fun OperationShape.outputShape(model: Model): StructureShape = // The Rust Smithy generator adds an output to all shapes automatically - return model.expectShape(this.output.get(), StructureShape::class.java) -} + model.expectShape(this.output.get(), StructureShape::class.java) fun StructureShape.expectMember(member: String): MemberShape = this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") } @@ -55,43 +51,32 @@ fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait() -} +fun MemberShape.isInputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() -fun MemberShape.isOutputEventStream(model: Model): Boolean { - return isEventStream(model) && model.expectShape(container).hasTrait() -} +fun MemberShape.isOutputEventStream(model: Model): Boolean = + isEventStream(model) && model.expectShape(container).hasTrait() private val unitShapeId = ShapeId.from("smithy.api#Unit") -fun MemberShape.isTargetUnit(): Boolean { - return this.target == unitShapeId -} +fun Shape.isUnit(): Boolean = this.id == unitShapeId -fun Shape.hasEventStreamMember(model: Model): Boolean { - return members().any { it.isEventStream(model) } -} +fun MemberShape.isTargetUnit(): Boolean = this.target == unitShapeId -fun OperationShape.isInputEventStream(model: Model): Boolean { - return input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun Shape.hasEventStreamMember(model: Model): Boolean = members().any { it.isEventStream(model) } -fun OperationShape.isOutputEventStream(model: Model): Boolean { - return output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -} +fun OperationShape.isInputEventStream(model: Model): Boolean = + input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) -fun OperationShape.isEventStream(model: Model): Boolean { - return isInputEventStream(model) || isOutputEventStream(model) -} +fun OperationShape.isOutputEventStream(model: Model): Boolean = + output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false) + +fun OperationShape.isEventStream(model: Model): Boolean = isInputEventStream(model) || isOutputEventStream(model) fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id -> @@ -125,17 +110,13 @@ fun Shape.redactIfNecessary( * * A structure must have at most one streaming member. */ -fun StructureShape.findStreamingMember(model: Model): MemberShape? { - return this.findMemberWithTrait(model) -} +fun StructureShape.findStreamingMember(model: Model): MemberShape? = this.findMemberWithTrait(model) -inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun StructureShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? { - return this.members().find { it.getMemberTrait(model, T::class.java).isPresent } -} +inline fun UnionShape.findMemberWithTrait(model: Model): MemberShape? = + this.members().find { it.getMemberTrait(model, T::class.java).isPresent } /** * If is member shape returns target, otherwise returns self. @@ -156,12 +137,11 @@ inline fun Shape.expectTrait(): T = expectTrait(T::class.jav /** Kotlin sugar for getTrait() check. e.g. shape.getTrait() instead of shape.getTrait(EnumTrait::class.java) */ inline fun Shape.getTrait(): T? = getTrait(T::class.java).orNull() -fun Shape.isPrimitive(): Boolean { - return when (this) { +fun Shape.isPrimitive(): Boolean = + when (this) { is NumberShape, is BooleanShape -> true else -> false } -} /** Convert a string to a ShapeId */ fun String.shapeId() = ShapeId.from(this) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 4134cdd0391..808d476058d 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -24,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-server-test/" dependencies { implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") @@ -43,6 +44,12 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + CodegenTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"), + CodegenTest( + "smithy.protocoltests.rpcv2Cbor#RpcV2CborService", + "rpcv2Cbor_extras", + imports = listOf("$commonModels/rpcv2Cbor-extras.smithy") + ), CodegenTest( "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 0ba262225d1..49e0462888c 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -26,10 +26,14 @@ dependencies { implementation(project(":codegen-core")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used // in `CustomValidationExceptionWithReasonDecoratorTest`. testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") + + // It's handy to re-use protocol test suite models from Smithy in our Kotlin tests. + testImplementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion") } java { diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index 4e41cc8ed37..f50e2363370 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -56,10 +56,10 @@ class PythonServerAfterDeserializedMemberJsonParserCustomization(private val run } /** - * Customization class used to force casting a non primitive type into one overriden by a new symbol provider, + * Customization class used to force casting a non-primitive type into one overridden by a new symbol provider, * by explicitly calling `into()` on it. */ -class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : +class PythonServerAfterDeserializedMemberServerHttpBoundCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 01df0c0a937..a9c488503d6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig */ object ServerCargoDependency { val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1.74")) + val Base64SimdDev: CargoDependency = CargoDependency("base64-simd", CratesIo("0.8"), scope = DependencyScope.Dev) val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1")) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3")) val Mime: CargoDependency = CargoDependency("mime", CratesIo("0.3")) @@ -26,7 +27,7 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) - val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) + val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), scope = DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index ea4eadad84e..49c0e7c5403 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -609,13 +609,20 @@ open class ServerCodegenVisitor( } /** - * Generate protocol tests. This method can be overridden by other languages such has Python. + * Generate protocol tests. This method can be overridden by other languages such as Python. */ open fun protocolTestsForOperation( writer: RustWriter, shape: OperationShape, ) { - ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), shape).render(writer) + codegenDecorator.protocolTestGenerator( + codegenContext, + ServerProtocolTestGenerator( + codegenContext, + protocolGeneratorFactory.support(), + shape, + ), + ).render(writer) } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt new file mode 100644 index 00000000000..464a52dc463 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.escape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * Smithy RPC v2 CBOR requires errors to be serialized in server responses with an additional `__type` field. + * + * Note that we apply this customization when serializing _any_ structure with the `@error` trait, regardless if it's + * an error response or not. Consider this model: + * + * ```smithy + * operation ErrorSerializationOperation { + * input: SimpleStruct + * output: ErrorSerializationOperationOutput + * errors: [ValidationException] + * } + * + * structure ErrorSerializationOperationOutput { + * errorShape: ValidationException + * } + * ``` + * + * `ValidationException` is re-used across the operation output and the operation error. The `__type` field will + * appear when serializing both. + * + * Strictly speaking, the spec says we should only add `__type` when serializing an operation error response, but + * there shouldn't™️ be any harm in always including it, which simplifies the code generator. + */ +class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeSerializingStructureMembers -> + if (section.structureShape.hasTrait()) { + writable { + rust( + """ + ${section.encoderBindingName} + .str("__type") + .str("${escape(section.structureShape.id.toString())}"); + """, + ) + } + } else { + emptySection + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt new file mode 100644 index 00000000000..a01d0076e96 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeEncodingMapOrCollectionCborCustomization.kt @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we encode over a _constrained_ map or collection shape in a CBOR serializer, + * unwrap the wrapper newtype and take a shared reference to the actual value within it. + * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. + */ +class BeforeEncodingMapOrCollectionCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { + override fun section(section: CborSerializerSection): Writable = + when (section) { + is CborSerializerSection.BeforeIteratingOverMapOrCollection -> + writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index d06c21ff70c..5b4860c26db 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -70,18 +70,25 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat codegenContext.symbolProvider.toSymbol(memberShape).isOptional() } -class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable()) : +class ServerInstantiator( + codegenContext: CodegenContext, + customWritable: CustomWritable = NoCustomWritable(), + ignoreMissingMembers: Boolean = false, + withinTest: Boolean = false, +) : Instantiator( - codegenContext.symbolProvider, - codegenContext.model, - codegenContext.runtimeConfig, - ServerBuilderKindBehavior(codegenContext), - defaultsForRequiredFields = true, - customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), - // Construct with direct pattern to more closely replicate actual server customer usage - constructPattern = InstantiatorConstructPattern.DIRECT, - customWritable = customWritable, - ) + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), + defaultsForRequiredFields = true, + customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), + // Construct with direct pattern to more closely replicate actual server customer usage + constructPattern = InstantiatorConstructPattern.DIRECT, + customWritable = customWritable, + ignoreMissingMembers = ignoreMissingMembers, + withinTest = withinTest, + ) class ServerBuilderInstantiator( private val symbolProvider: RustSymbolProvider, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index f31f6d92da5..d4984d65e9e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -24,18 +24,26 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator @@ -298,6 +306,68 @@ class ServerRestXmlProtocol( ) } +class ServerRpcV2CborProtocol( + private val serverCodegenContext: ServerCodegenContext, +) : RpcV2Cbor(serverCodegenContext), ServerProtocol { + val runtimeConfig = codegenContext.runtimeConfig + + override val protocolModulePath = "rpc_v2_cbor" + + override fun structuredDataParser(): StructuredDataParserGenerator = + CborParserGenerator( + serverCodegenContext, httpBindingResolver, returnSymbolToParseFn(serverCodegenContext), + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization( + serverCodegenContext, + ), + ), + ) + + override fun structuredDataSerializer(): StructuredDataSerializerGenerator { + return CborSerializerGenerator( + codegenContext, + httpBindingResolver, + listOf( + BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), + AddTypeFieldToServerErrorsCborCustomization(), + ), + ) + } + + override fun markerStruct() = ServerRuntimeType.protocol("RpcV2Cbor", "rpc_v2_cbor", runtimeConfig) + + override fun routerType() = + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::rpc_v2_cbor::router::RpcV2CborRouter") + + override fun serverRouterRequestSpec( + operationShape: OperationShape, + operationName: String, + serviceName: String, + requestSpecModule: RuntimeType, + ) = writable { + // This is just the key used by the router's map to store and look up operations, it's completely arbitrary. + // We use the same key used by the awsJson1.x routers for simplicity. + // The router will extract the service name and the operation name from the URI, build this key, and lookup the + // operation stored there. + rust("$serviceName.$operationName".dq()) + } + + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.StaticStr + + override fun serverRouterRuntimeConstructor() = "rpc_v2_router" + + override fun serverContentTypeCheckNoModeledInput() = false + + override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType = + deserializePayloadErrorType( + codegenContext, + binding, + requestRejection(runtimeConfig), + RuntimeType.smithyCbor(codegenContext.runtimeConfig).resolve("decode::DeserializeError"), + ) +} + /** Just a common function to keep things DRY. **/ fun deserializePayloadErrorType( codegenContext: CodegenContext, @@ -317,8 +387,8 @@ fun deserializePayloadErrorType( } /** - * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into - * `MaybeConstrained` if the target shape can reach a constrained shape. + * A customization to, just before we box a recursive member that we've deserialized from JSON into `Option`, convert + * it into `MaybeConstrained` if the target shape can reach a constrained shape. */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : JsonParserCustomization() { @@ -338,3 +408,24 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonPa else -> emptySection } } + +/** + * A customization to, just before we box a recursive member that we've deserialized from CBOR into `T` held in a + * variable binding `v`, convert it into `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborParserCustomization(val codegenContext: ServerCodegenContext) : + CborParserCustomization() { + override fun section(section: CborParserSection): Writable = + when (section) { + is CborParserSection.BeforeBoxingDeserializedMember -> + writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust("let v = v.into();") + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index cbcbca16e3a..09e6b635de5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.AppliesTo @@ -38,6 +39,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Servi import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_11 import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON_VALIDATION +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR_EXTRAS import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.util.dq @@ -145,8 +148,15 @@ class ServerProtocolTestGenerator( AWS_JSON_10, "AwsJson10ServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams", ), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): This affects all protocols + FailingTest.MalformedRequestTest(RPC_V2_CBOR_EXTRAS, "AdditionalTokensEmptyStruct"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3339) + FailingTest.ResponseTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsInResponseWhenMissingInParams"), FailingTest.ResponseTest(REST_JSON, "RestJsonServerPopulatesDefaultsInResponseWhenMissingInParams"), FailingTest.ResponseTest(REST_JSON, "RestJsonServerPopulatesNestedDefaultValuesWhenMissingInInResponseParams"), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3743): We need to be able to configure + // instantiator so that it uses default _modeled_ values; `""` is not a valid enum value for `defaultEnum`. + FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborServerPopulatesDefaultsWhenMissingInRequestBody"), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3735): Null `Document` may come through a request even though its shape is `@required` FailingTest.RequestTest(REST_JSON, "RestJsonServerPopulatesDefaultsWhenMissingInRequestBody"), ) @@ -223,7 +233,7 @@ class ServerProtocolTestGenerator( get() = ExpectFail override val brokenTests: Set get() = BrokenTests - override val runOnly: Set + override val generateOnly: Set get() = emptySet() override val disabledTests: Set get() = DisabledTests @@ -258,10 +268,11 @@ class ServerProtocolTestGenerator( inputT to outputT } - private val instantiator = ServerInstantiator(codegenContext) + private val instantiator = ServerInstantiator(codegenContext, withinTest = true) private val codegenScope = arrayOf( + "Base64SimdDev" to ServerCargoDependency.Base64SimdDev.toType(), "Bytes" to RuntimeType.Bytes, "Hyper" to RuntimeType.Hyper, "Tokio" to ServerCargoDependency.TokioDev.toType(), @@ -288,20 +299,31 @@ class ServerProtocolTestGenerator( * an operation's input shape, the resulting shape is of the form we expect, as defined in the test case. */ private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { + logger.info("Generating request test: ${httpRequestTestCase.id}") + if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") return } with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest( + uri, + method, + headers, + body.orNull(), + bodyMediaType.orNull(), + protocol, + queryParams, + host.orNull(), + ) } if (protocolSupport.requestBodyDeserialization) { makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) checkHandlerWasEntered(this) } - // Explicitly warn if the test case defined parameters that we aren't doing anything with + // Explicitly warn if the test case defined parameters that we aren't doing anything with. with(httpRequestTestCase) { if (authScheme.isPresent) { logger.warning("Test case provided authScheme but this was ignored") @@ -322,6 +344,8 @@ class ServerProtocolTestGenerator( testCase: HttpResponseTestCase, shape: StructureShape, ) { + logger.info("Generating response test: ${testCase.id}") + val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -354,6 +378,8 @@ class ServerProtocolTestGenerator( * with the given response. */ private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { + logger.info("Generating malformed request test: ${testCase.id}") + val (_, outputT) = operationInputOutputTypes[operationShape]!! val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" @@ -361,7 +387,18 @@ class ServerProtocolTestGenerator( rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) + // TODO(https://github.com/smithy-lang/smithy/issues/1932): we send `null` for `bodyMediaType` for now but + // the Smithy protocol test should give it to us. + renderHttpRequest( + uri.get(), + method, + headers, + body.orNull(), + bodyMediaType = null, + testCase.protocol, + queryParams, + host.orNull(), + ) } makeRequest( @@ -379,6 +416,8 @@ class ServerProtocolTestGenerator( method: String, headers: Map, body: String?, + bodyMediaType: String?, + protocol: ShapeId, queryParams: List, host: String?, ) { @@ -409,7 +448,26 @@ class ServerProtocolTestGenerator( // We also escape to avoid interactions with templating in the case where the body contains `#`. val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" + // TODO(https://github.com/smithy-lang/smithy/issues/1932): We're using the `protocol` field as a + // proxy for `bodyMediaType`. This works because `rpcv2Cbor` happens to be the only protocol where + // the body is base64-encoded in the protocol test, but checking `bodyMediaType` should be a more + // resilient check. + val encodedBody = + if (protocol.toShapeId() == ShapeId.from("smithy.protocols#rpcv2Cbor")) { + """ + #{Bytes}::from( + #{Base64SimdDev}::STANDARD.decode_to_vec($sanitizedBody).expect( + "`body` field of Smithy protocol test is not correctly base64 encoded" + ) + ) + """ + } else { + """ + #{Bytes}::from_static($sanitizedBody.as_bytes()) + """ + } + + "#{SmithyHttpServer}::body::Body::from($encodedBody)" } else { "#{SmithyHttpServer}::body::Body::empty()" } @@ -426,7 +484,7 @@ class ServerProtocolTestGenerator( } } - /** Returns the body of the request test. */ + /** Returns the body of the operation handler in a request test. */ private fun checkRequestHandler( operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase, @@ -434,7 +492,7 @@ class ServerProtocolTestGenerator( val inputShape = operationShape.inputShape(codegenContext.model) val outputShape = operationShape.outputShape(codegenContext.model) - // Construct expected request. + // Construct expected operation input. withBlock("let expected = ", ";") { instantiator.render(this, inputShape, httpRequestTestCase.params, httpRequestTestCase.headers) } @@ -442,14 +500,14 @@ class ServerProtocolTestGenerator( checkRequestParams(inputShape, this) // Construct a dummy response. - withBlock("let response = ", ";") { + withBlock("let output = ", ";") { instantiator.render(this, outputShape, Node.objectNode()) } if (operationShape.errors.isEmpty()) { - rust("response") + rust("output") } else { - rust("Ok(response)") + rust("Ok(output)") } } @@ -634,13 +692,13 @@ class ServerProtocolTestGenerator( rustWriter.rustTemplate( """ // No body. - #{AssertEq}(std::str::from_utf8(&body).unwrap(), ""); + #{AssertEq}(&body, &bytes::Bytes::new()); """, *codegenScope, ) } else { assertOk(rustWriter) { - rustWriter.rust( + rust( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 3d94bb8821b..8eeab9c22e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -54,7 +54,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional import software.amazon.smithy.rust.codegen.core.util.dq @@ -109,7 +109,7 @@ typealias ServerHttpBoundProtocolCustomization = NamedCustomization + setResponseHeaderIfAbsent(this, "content-type", contentTypeValue) + } + + for ((headerName, headerValue) in protocol.additionalResponseHeaders(operationShape)) { + setResponseHeaderIfAbsent(this, headerName, headerValue) } if (errorShape != null) { for ((headerName, headerValue) in protocol.additionalErrorResponseHeaders(errorShape)) { - rustTemplate( - """ - builder = #{header_util}::set_response_header_if_absent( - builder, - http::header::HeaderName::from_static("$headerName"), - "${escape(headerValue)}" - ); - """, - *codegenScope, - ) + setResponseHeaderIfAbsent(this, headerName, headerValue) } } } @@ -709,6 +720,28 @@ class ServerHttpBoundProtocolTraitImplGenerator( // there's something to parse (i.e. `parser != null`), so `!!` is safe here. val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope) + // Note that the server is being very lenient here. We're accepting an empty body for when there is modeled + // operation input; we simply parse it as empty operation input. + // This behavior applies to all protocols. This might seem like a bug, but it isn't. There's protocol tests + // that assert that the server should be lenient and accept both empty payloads and no payload + // when there is modeled input: + // + // * [restJson1]: clients omit the payload altogether when the input is empty! So services must accept this. + // * [rpcv2Cbor]: services must accept no payload or empty CBOR map for operations with modeled input. + // + // For the AWS JSON 1.x protocols, services are lenient in the case when there is no modeled input: + // + // * [awsJson1_0]: services must accept no payload or empty JSON document payload for operations with no modeled input + // * [awsJson1_1]: services must accept no payload or empty JSON document payload for operations with no modeled input + // + // However, it's true that there are no tests pinning server behavior when there is _empty_ input. There's + // a [consultation with Smithy] to remedy this. Until that gets resolved, in the meantime, we are being lenient. + // + // [restJson1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/restJson1/empty-input-output.smithy#L22 + // [awsJson1_0]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_0/empty-input-output.smithy + // [awsJson1_1]: https://github.com/smithy-lang/smithy/blob/main/smithy-aws-protocol-tests/model/awsJson1_1/empty-operation.smithy + // [rpcv2Cbor]: https://github.com/smithy-lang/smithy/blob/main/smithy-protocol-tests/model/rpcv2Cbor/empty-input-output.smithy + // [consultation with Smithy]: https://github.com/smithy-lang/smithy/issues/2327 rustBlock("if !bytes.is_empty()") { rustTemplate( """ @@ -750,7 +783,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( serverRenderQueryStringParser(this, operationShape) // If there's no modeled operation input, some protocols require that `Content-Type` header not be present. - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null + val noInputs = !OperationNormalizer.hadUserModeledOperationInput(operationShape, model) if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { rustTemplate( """ @@ -760,6 +793,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3723): we should inject a check here that asserts that + // the body contents are valid when there is empty operation input or no operation input. + val err = if (ServerBuilderGenerator.hasFallibleBuilder( inputShape, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ae87ec5723b..a121697eb7f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -20,7 +21,7 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator -class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { +class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { is ServerHttpBoundProtocolSection.WrapStreamPayload -> @@ -79,6 +80,13 @@ class ServerProtocolLoader(supportedProtocols: ProtocolMap = + emptyList(), +) : ProtocolGeneratorFactory { + override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRpcV2CborProtocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRpcV2CborProtocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + ) + + override fun support(): ProtocolSupport { + return ProtocolSupport( + // Client support + requestSerialization = false, + requestBodySerialization = false, + responseDeserialization = false, + errorDeserialization = false, + // Server support + requestDeserialization = true, + requestBodyDeserialization = true, + responseSerialization = true, + errorSerialization = true, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt new file mode 100644 index 00000000000..2e92cde4e29 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest.kt @@ -0,0 +1,358 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocoltests.traits.AppliesTo +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRpcV2CborProtocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRpcV2CborFactory +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.function.Predicate +import java.util.logging.Logger + +/** + * This lives in `codegen-server` because we want to run a full integration test for convenience, + * but there's really nothing server-specific here. We're just testing that the CBOR (de)serializers work like + * the ones generated by `serde_cbor`. This is a good exhaustive litmus test for correctness, since `serde_cbor` + * is battle-tested. + */ +internal class CborSerializerAndParserGeneratorSerdeRoundTripIntegrationTest { + class DeriveSerdeSerializeDeserializeSymbolMetadataProvider( + private val base: RustSymbolProvider, + ) : SymbolMetadataProvider(base) { + private val serdeDeserialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Deserialize") + private val serdeSerialize = + CargoDependency.Serde.copy(scope = DependencyScope.Compile).toType().resolve("Serialize") + + private fun addDeriveSerdeSerializeDeserialize(shape: Shape): RustMetadata { + check(shape !is MemberShape) + + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + return baseMetadata.withDerives(serdeSerialize, serdeDeserialize) + } + + override fun memberMeta(memberShape: MemberShape): RustMetadata { + val baseMetadata = base.toSymbol(memberShape).expectRustMetadata() + return baseMetadata.copy( + additionalAttributes = + baseMetadata.additionalAttributes + + Attribute( + """serde(rename = "${memberShape.memberName}")""", + isDeriveHelper = true, + ), + ) + } + + override fun structureMeta(structureShape: StructureShape) = addDeriveSerdeSerializeDeserialize(structureShape) + + override fun unionMeta(unionShape: UnionShape) = addDeriveSerdeSerializeDeserialize(unionShape) + + override fun enumMeta(stringShape: StringShape) = addDeriveSerdeSerializeDeserialize(stringShape) + + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveSerdeSerializeDeserialize(listShape) + + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveSerdeSerializeDeserialize(mapShape) + + override fun stringMeta(stringShape: StringShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(stringShape) + + override fun numberMeta(numberShape: NumberShape): RustMetadata = + addDeriveSerdeSerializeDeserialize(numberShape) + + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveSerdeSerializeDeserialize(blobShape) + } + + fun prepareRpcV2CborModel(): Model { + var model = Model.assembler().discoverModels().assemble().result.get() + + // Filter out `timestamp` and `blob` shapes: those map to runtime types in `aws-smithy-types` on + // which we can't `#[derive(serde::Deserialize)]`. + // Note we can't use `ModelTransformer.removeShapes` because it will leave the model in an inconsistent state + // when removing list/set shape member shapes. + val removeTimestampAndBlobShapes: Predicate = + Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + targetShape is BlobShape || targetShape is TimestampShape + } + is BlobShape, is TimestampShape -> true + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + targetShape is BlobShape || targetShape is TimestampShape + } + else -> false + } + } + + fun removeShapesByShapeId(shapeIds: Set): Predicate { + val predicate: Predicate = + Predicate { shape -> + when (shape) { + is MemberShape -> { + val targetShape = model.expectShape(shape.target) + shapeIds.contains(targetShape.id) + } + is CollectionShape -> { + val targetShape = model.expectShape(shape.member.target) + shapeIds.contains(targetShape.id) + } + else -> { + shapeIds.contains(shape.id) + } + } + } + return predicate + } + + val modelTransformer = ModelTransformer.create() + model = + modelTransformer.removeShapesIf( + modelTransformer.removeShapesIf(model, removeTimestampAndBlobShapes), + // These enums do not serialize their variants using the Rust members' names. + // We'd have to tack on `#[serde(rename = "name")]` using the proper name defined in the Smithy enum definition. + // But we have no way of injecting that attribute on Rust enum variants in the code generator. + // So we just remove these problematic shapes. + removeShapesByShapeId( + setOf( + ShapeId.from("smithy.protocoltests.shared#FooEnum"), + ShapeId.from("smithy.protocoltests.rpcv2Cbor#TestEnum"), + ), + ), + ) + + return model + } + + @Test + fun `serde_cbor round trip`() { + val addDeriveSerdeSerializeDeserializeDecorator = + object : ServerCodegenDecorator { + override val name: String = "Add `#[derive(serde::Serialize, serde::Deserialize)]`" + override val order: Byte = 0 + + override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider = + DeriveSerdeSerializeDeserializeSymbolMetadataProvider(base) + } + + // Don't generate protocol tests, because it'll attempt to pull out `params` for member shapes we'll remove + // from the model. + val noProtocolTestsDecorator = + object : ServerCodegenDecorator { + override val name: String = "Don't generate protocol tests" + override val order: Byte = 0 + + override fun protocolTestGenerator( + codegenContext: ServerCodegenContext, + baseGenerator: ProtocolTestGenerator, + ): ProtocolTestGenerator { + val noOpProtocolTestsGenerator = + object : ProtocolTestGenerator() { + override val codegenContext: CodegenContext + get() = baseGenerator.codegenContext + override val protocolSupport: ProtocolSupport + get() = baseGenerator.protocolSupport + override val operationShape: OperationShape + get() = baseGenerator.operationShape + override val appliesTo: AppliesTo + get() = baseGenerator.appliesTo + override val logger: Logger + get() = Logger.getLogger(javaClass.name) + override val expectFail: Set + get() = baseGenerator.expectFail + override val brokenTests: Set + get() = emptySet() + override val generateOnly: Set + get() = baseGenerator.generateOnly + override val disabledTests: Set + get() = baseGenerator.disabledTests + + override fun RustWriter.renderAllTestCases(allTests: List) { + // No-op. + } + } + return noOpProtocolTestsGenerator + } + } + + val model = prepareRpcV2CborModel() + val serviceShape = model.expectShape(ShapeId.from(RPC_V2_CBOR)) + serverIntegrationTest( + model, + additionalDecorators = listOf(addDeriveSerdeSerializeDeserializeDecorator, noProtocolTestsDecorator), + params = IntegrationTestParams(service = serviceShape.id.toString()), + ) { codegenContext, rustCrate -> + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147): NaN != NaN. Ideally we when we address + // this issue, we'd re-use the structure shape comparison code that both client and server protocol test + // generators would use. + val expectFail = setOf("RpcV2CborSupportsNaNFloatInputs", "RpcV2CborSupportsNaNFloatOutputs") + + val codegenScope = + arrayOf( + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "SerdeCbor" to CargoDependency.SerdeCbor.toType(), + ) + + val instantiator = ServerInstantiator(codegenContext, ignoreMissingMembers = true, withinTest = true) + val rpcv2Cbor = ServerRpcV2CborProtocol(codegenContext) + + for (operationShape in codegenContext.model.operationShapes) { + val serverProtocolTestGenerator = + ServerProtocolTestGenerator(codegenContext, ServerRpcV2CborFactory().support(), operationShape) + + rustCrate.withModule(ProtocolFunctions.serDeModule) { + // The SDK can only serialize operation outputs, so we only ask for response tests. + val responseTests = + serverProtocolTestGenerator.responseTestCases() + + for (test in responseTests) { + when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.ResponseTest -> { + val targetShape = test.targetShape + val params = test.testCase.params + + val serializeFn = + if (targetShape.hasTrait()) { + rpcv2Cbor.structuredDataSerializer().serverErrorSerializer(targetShape.id) + } else { + rpcv2Cbor.structuredDataSerializer().operationOutputSerializer(operationShape) + } + + if (serializeFn == null) { + // Skip if there's nothing to serialize. + continue + } + + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("we_serialize_and_serde_cbor_deserializes_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes = #{SerializeFn}(&expected) + .expect("our generated CBOR serializer failed"); + let actual = #{SerdeCbor}::from_slice(&bytes) + .expect("serde_cbor failed deserializing from bytes"); + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "SerializeFn" to serializeFn, + *codegenScope, + ) + } + } + } + } + + // The SDK can only deserialize operation inputs, so we only ask for request tests. + val requestTests = + serverProtocolTestGenerator.requestTestCases() + val inputShape = operationShape.inputShape(codegenContext.model) + val err = + if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + codegenContext.model, + codegenContext.symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { + """.expect("builder failed to build")""" + } else { + "" + } + + for (test in requestTests) { + when (test) { + is TestCase.MalformedRequestTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.ResponseTest -> UNREACHABLE("we did not ask for tests of this kind") + is TestCase.RequestTest -> { + val targetShape = operationShape.inputShape(codegenContext.model) + val params = test.testCase.params + + val deserializeFn = + rpcv2Cbor.structuredDataParser().serverInputParser(operationShape) + ?: // Skip if there's nothing to serialize. + continue + + if (expectFail.contains(test.id)) { + writeWithNoFormatting("#[should_panic]") + } + unitTest("serde_cbor_serializes_and_we_deserialize_${test.id.toSnakeCase()}_${test.kind.toString().toSnakeCase()}") { + rustTemplate( + """ + let expected = #{InstantiateShape:W}; + let bytes: Vec = #{SerdeCbor}::to_vec(&expected) + .expect("serde_cbor failed serializing to `Vec`"); + let input = #{InputBuilder}::default(); + let input = #{DeserializeFn}(&bytes, input) + .expect("our generated CBOR deserializer failed"); + let actual = input.build()$err; + #{AssertEq}(expected, actual); + """, + "InstantiateShape" to instantiator.generate(targetShape, params), + "DeserializeFn" to deserializeFn, + "InputBuilder" to inputShape.serverBuilderSymbol(codegenContext), + *codegenScope, + ) + } + } + } + } + } + } + } + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d92869a661d..a374adf6f0e 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -9,7 +9,6 @@ members = [ "pokemon-service-server-sdk", "pokemon-service-client", "pokemon-service-client-usage", - ] [profile.release] diff --git a/rust-runtime/Cargo.lock b/rust-runtime/Cargo.lock index 1e929f7b588..602ebd4502d 100644 --- a/rust-runtime/Cargo.lock +++ b/rust-runtime/Cargo.lock @@ -302,6 +302,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "aws-smithy-cbor" +version = "0.60.6" +dependencies = [ + "aws-smithy-types 1.2.0", + "criterion", + "minicbor", +] + [[package]] name = "aws-smithy-checksums" version = "0.60.10" @@ -398,7 +407,7 @@ name = "aws-smithy-experimental" version = "0.1.3" dependencies = [ "aws-smithy-async 1.2.1", - "aws-smithy-runtime 1.6.1", + "aws-smithy-runtime 1.6.2", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "h2 0.4.5", @@ -465,8 +474,9 @@ version = "0.60.3" [[package]] name = "aws-smithy-http-server" -version = "0.63.0" +version = "0.63.2" dependencies = [ + "aws-smithy-cbor", "aws-smithy-http 0.60.9", "aws-smithy-json 0.60.7", "aws-smithy-runtime-api 1.7.1", @@ -495,7 +505,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-server-python" -version = "0.62.1" +version = "0.63.1" dependencies = [ "aws-smithy-http 0.60.9", "aws-smithy-http-server", @@ -582,14 +592,17 @@ dependencies = [ [[package]] name = "aws-smithy-protocol-test" -version = "0.60.8" +version = "0.62.0" dependencies = [ "assert-json-diff", "aws-smithy-runtime-api 1.7.1", + "base64-simd", + "cbor-diag", "http 0.2.12", "pretty_assertions", "regex-lite", "roxmltree", + "serde_cbor", "serde_json", "thiserror", ] @@ -635,12 +648,12 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.6.1" +version = "1.6.2" dependencies = [ "approx", "aws-smithy-async 1.2.1", "aws-smithy-http 0.60.9", - "aws-smithy-protocol-test 0.60.8", + "aws-smithy-protocol-test 0.62.0", "aws-smithy-runtime-api 1.7.1", "aws-smithy-types 1.2.0", "bytes", @@ -789,7 +802,7 @@ dependencies = [ name = "aws-smithy-xml" version = "0.60.8" dependencies = [ - "aws-smithy-protocol-test 0.60.8", + "aws-smithy-protocol-test 0.62.0", "base64 0.13.1", "proptest", "xmlparser", @@ -954,6 +967,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bs58" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" +dependencies = [ + "tinyvec", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -985,6 +1007,25 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbor-diag" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc245b6ecd09b23901a4fbad1ad975701fd5061ceaef6afa93a2d70605a64429" +dependencies = [ + "bs58", + "chrono", + "data-encoding", + "half 2.4.1", + "nom", + "num-bigint", + "num-rational", + "num-traits", + "separator", + "url", + "uuid", +] + [[package]] name = "cc" version = "1.0.99" @@ -1044,7 +1085,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" dependencies = [ "ciborium-io", - "half", + "half 2.4.1", ] [[package]] @@ -1282,6 +1323,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.6.1" @@ -1634,6 +1681,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -2168,6 +2221,27 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minicbor" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f8e213c36148d828083ae01948eed271d03f95f7e72571fa242d78184029af2" +dependencies = [ + "half 2.4.1", + "minicbor-derive", +] + +[[package]] +name = "minicbor-derive" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6bdc119b1a405df86a8cde673295114179dbd0ebe18877c26ba89fb080365c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.67", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2220,6 +2294,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2235,6 +2319,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -3032,6 +3127,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "separator" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f97841a747eef040fcd2e7b3b9a220a7205926e60488e673d9e4926d27772ce5" + [[package]] name = "serde" version = "1.0.203" @@ -3041,6 +3142,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.203" diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 3a618e6bbe9..d7853b00996 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "inlineable", "aws-smithy-async", + "aws-smithy-cbor", "aws-smithy-checksums", "aws-smithy-compression", "aws-smithy-client", diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml new file mode 100644 index 00000000000..b87366d6ef4 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "aws-smithy-cbor" +version = "0.60.6" +authors = [ + "AWS Rust SDK Team ", + "David Pérez ", +] +description = "CBOR utilities for smithy-rs." +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/awslabs/smithy-rs" + +[dependencies.minicbor] +version = "0.24.2" +features = [ + # To write to a `Vec`: https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E + "alloc", + # To support reading `f16` to accomodate fewer bytes transmitted that fit the value. + "half", +] + +[dependencies] +aws-smithy-types = { path = "../aws-smithy-types" } + +[dev-dependencies] +criterion = "0.5.1" + +[[bench]] +name = "string" +harness = false + +[[bench]] +name = "blob" +harness = false + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata diff --git a/rust-runtime/aws-smithy-cbor/LICENSE b/rust-runtime/aws-smithy-cbor/LICENSE new file mode 100644 index 00000000000..67db8588217 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/rust-runtime/aws-smithy-cbor/README.md b/rust-runtime/aws-smithy-cbor/README.md new file mode 100644 index 00000000000..367577b3e58 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/README.md @@ -0,0 +1,8 @@ +# aws-smithy-cbor + +CBOR serialization and deserialization primitives for clients and servers +generated by [smithy-rs](https://github.com/smithy-lang/smithy-rs). + + +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/smithy-lang/smithy-rs) code generator. In most cases, it should not be used directly. + diff --git a/rust-runtime/aws-smithy-cbor/benches/blob.rs b/rust-runtime/aws-smithy-cbor/benches/blob.rs new file mode 100644 index 00000000000..221940bb98e --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/blob.rs @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn blob_benchmark(c: &mut Criterion) { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + let blob_indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, 0x79, + 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, 0x4e, 0x20, + 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0xff, + ]; + + c.bench_function("blob", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&blob_indefinite_bytes); + let _ = black_box(decoder.blob()); + }) + }); +} + +criterion_group!(benches, blob_benchmark); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/benches/string.rs b/rust-runtime/aws-smithy-cbor/benches/string.rs new file mode 100644 index 00000000000..f60ff353e00 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/benches/string.rs @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::borrow::Cow; + +use aws_smithy_cbor::decode::Decoder; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn str_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("definite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let x = black_box(decoder.str()); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); + + c.bench_function("indefinite str_alt", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let x = black_box(str_alt(&mut decoder)); + assert!(matches!(x.unwrap().as_ref(), "thisIsAKey")); + }) + }); +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +pub fn string_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result { + decoder.str_iter()?.collect() +} + +// The following seems to be a bit slower than the implementation that we have +// kept in the `aws_smithy_cbor::Decoder`. +fn str_alt<'b>( + decoder: &'b mut minicbor::Decoder<'b>, +) -> Result, minicbor::decode::Error> { + // This implementation uses `next` twice to see if there is + // another str chunk. If there is, it returns a owned `String`. + let mut chunks_iter = decoder.str_iter()?; + let head = match chunks_iter.next() { + Some(Ok(head)) => head, + None => return Ok(Cow::Borrowed("")), + Some(Err(e)) => return Err(e), + }; + + match chunks_iter.next() { + None => Ok(Cow::Borrowed(head)), + Some(Err(e)) => Err(e), + Some(Ok(next)) => { + let mut concatenated_string = String::from(head); + concatenated_string.push_str(next); + for chunk in chunks_iter { + concatenated_string.push_str(chunk?); + } + Ok(Cow::Owned(concatenated_string)) + } + } +} + +// We have two `string` implementations. One uses `collect` the other +// uses `String::new` followed by `string::push`. +pub fn string_benchmark(c: &mut Criterion) { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, 0x79, + 0xff, + ]; + + c.bench_function("definite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&definite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("definite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); + + c.bench_function("indefinite string()", |b| { + b.iter(|| { + let mut decoder = Decoder::new(&indefinite_bytes); + let _ = black_box(decoder.string()); + }) + }); + + c.bench_function("indefinite string_alt()", |b| { + b.iter(|| { + let mut decoder = minicbor::decode::Decoder::new(&indefinite_bytes); + let _ = black_box(string_alt(&mut decoder)); + }) + }); +} + +criterion_group!(benches, string_benchmark, str_benchmark,); +criterion_main!(benches); diff --git a/rust-runtime/aws-smithy-cbor/src/data.rs b/rust-runtime/aws-smithy-cbor/src/data.rs new file mode 100644 index 00000000000..e3bfdad2d98 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/data.rs @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] +pub enum Type { + Bool, + Null, + Undefined, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + Int, + F16, + F32, + F64, + Simple, + Bytes, + BytesIndef, + String, + StringIndef, + Array, + ArrayIndef, + Map, + MapIndef, + Tag, + Break, + Unknown(u8), +} + +impl Type { + pub(crate) fn new(ty: minicbor::data::Type) -> Self { + match ty { + minicbor::data::Type::Bool => Self::Bool, + minicbor::data::Type::Null => Self::Null, + minicbor::data::Type::Undefined => Self::Undefined, + minicbor::data::Type::U8 => Self::U8, + minicbor::data::Type::U16 => Self::U16, + minicbor::data::Type::U32 => Self::U32, + minicbor::data::Type::U64 => Self::U64, + minicbor::data::Type::I8 => Self::I8, + minicbor::data::Type::I16 => Self::I16, + minicbor::data::Type::I32 => Self::I32, + minicbor::data::Type::I64 => Self::I64, + minicbor::data::Type::Int => Self::Int, + minicbor::data::Type::F16 => Self::F16, + minicbor::data::Type::F32 => Self::F32, + minicbor::data::Type::F64 => Self::F64, + minicbor::data::Type::Simple => Self::Simple, + minicbor::data::Type::Bytes => Self::Bytes, + minicbor::data::Type::BytesIndef => Self::BytesIndef, + minicbor::data::Type::String => Self::String, + minicbor::data::Type::StringIndef => Self::StringIndef, + minicbor::data::Type::Array => Self::Array, + minicbor::data::Type::ArrayIndef => Self::ArrayIndef, + minicbor::data::Type::Map => Self::Map, + minicbor::data::Type::MapIndef => Self::MapIndef, + minicbor::data::Type::Tag => Self::Tag, + minicbor::data::Type::Break => Self::Break, + minicbor::data::Type::Unknown(byte) => Self::Unknown(byte), + } + } + + // This is just the reverse mapping of `new`. + pub(crate) fn into_minicbor_type(self) -> minicbor::data::Type { + match self { + Type::Bool => minicbor::data::Type::Bool, + Type::Null => minicbor::data::Type::Null, + Type::Undefined => minicbor::data::Type::Undefined, + Type::U8 => minicbor::data::Type::U8, + Type::U16 => minicbor::data::Type::U16, + Type::U32 => minicbor::data::Type::U32, + Type::U64 => minicbor::data::Type::U64, + Type::I8 => minicbor::data::Type::I8, + Type::I16 => minicbor::data::Type::I16, + Type::I32 => minicbor::data::Type::I32, + Type::I64 => minicbor::data::Type::I64, + Type::Int => minicbor::data::Type::Int, + Type::F16 => minicbor::data::Type::F16, + Type::F32 => minicbor::data::Type::F32, + Type::F64 => minicbor::data::Type::F64, + Type::Simple => minicbor::data::Type::Simple, + Type::Bytes => minicbor::data::Type::Bytes, + Type::BytesIndef => minicbor::data::Type::BytesIndef, + Type::String => minicbor::data::Type::String, + Type::StringIndef => minicbor::data::Type::StringIndef, + Type::Array => minicbor::data::Type::Array, + Type::ArrayIndef => minicbor::data::Type::ArrayIndef, + Type::Map => minicbor::data::Type::Map, + Type::MapIndef => minicbor::data::Type::MapIndef, + Type::Tag => minicbor::data::Type::Tag, + Type::Break => minicbor::data::Type::Break, + Type::Unknown(byte) => minicbor::data::Type::Unknown(byte), + } + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs new file mode 100644 index 00000000000..3cfe070397b --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -0,0 +1,341 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::borrow::Cow; + +use aws_smithy_types::{Blob, DateTime}; +use minicbor::decode::Error; + +use crate::data::Type; + +/// Provides functions for decoding a CBOR object with a known schema. +/// +/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema +/// is known in advance. Therefore, the caller can determine which object key exists at the current +/// position by calling `str` method, and call the relevant function based on the predetermined schema +/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip +/// over the element. +#[derive(Debug, Clone)] +pub struct Decoder<'b> { + decoder: minicbor::Decoder<'b>, +} + +/// When any of the decode methods are called they look for that particular data type at the current +/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned. +#[derive(Debug)] +pub struct DeserializeError { + #[allow(dead_code)] + _inner: Error, +} + +impl std::fmt::Display for DeserializeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self._inner.fmt(f) + } +} + +impl std::error::Error for DeserializeError {} + +impl DeserializeError { + pub(crate) fn new(inner: Error) -> Self { + Self { _inner: inner } + } + + /// More than one union variant was detected: `unexpected_type` was unexpected. + pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self { + Self { + _inner: Error::type_mismatch(unexpected_type.into_minicbor_type()) + .with_message("encountered unexpected union variant; expected end of union") + .at(at), + } + } + + /// Unknown union variant was detected. Servers reject unknown union varaints. + pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self { + Self { + _inner: Error::message(format!( + "encountered unknown union variant {}", + variant_name + )) + .at(at), + } + } + + /// More than one union variant was detected, but we never even got to parse the first one. + /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR + /// map whose length (specified upfront) is a value different than 1. + pub fn mixed_union_variants(at: usize) -> Self { + Self { + _inner: Error::message( + "encountered mixed variants in union; expected a single union variant to be set", + ) + .at(at), + } + } + + /// Expected end of stream but more data is available. + pub fn expected_end_of_stream(at: usize) -> Self { + Self { + _inner: Error::message("encountered additional data; expected end of stream").at(at), + } + } + + /// An unexpected type was encountered. + // We handle this one when decoding sparse collections: we have to expect either a `null` or an + // item, so we try decoding both. + pub fn is_type_mismatch(&self) -> bool { + self._inner.is_type_mismatch() + } +} + +/// Macro for delegating method calls to the decoder. +/// +/// This macro generates wrapper methods for calling specific methods on the decoder and returning +/// the result with error handling. +/// +/// # Example +/// +/// ```ignore +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the decoder. +/// encode_str_wrapper => encode_str(String); +/// /// Wrapper method for encoding method `encode_int` on the decoder. +/// encode_int_wrapper => encode_int(i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => { + $( + pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> { + self.decoder.$encoder_name().map_err(DeserializeError::new) + } + )+ + }; +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { + decoder: minicbor::Decoder::new(bytes), + } + } + + pub fn datatype(&self) -> Result { + self.decoder + .datatype() + .map(Type::new) + .map_err(DeserializeError::new) + } + + delegate_method! { + /// Skips the current CBOR element. + skip => skip(()); + /// Reads a boolean at the current position. + boolean => bool(bool); + /// Reads a byte at the current position. + byte => i8(i8); + /// Reads a short at the current position. + short => i16(i16); + /// Reads a integer at the current position. + integer => i32(i32); + /// Reads a long at the current position. + long => i64(i64); + /// Reads a float at the current position. + float => f32(f32); + /// Reads a double at the current position. + double => f64(f64); + /// Reads a null CBOR element at the current position. + null => null(()); + /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`. + list => array(Option); + /// Returns the number of elements in a definite map. For indefinite map it returns a `None`. + map => map(Option); + } + + /// Returns the current position of the buffer, which will be decoded when any of the methods is called. + pub fn position(&self) -> usize { + self.decoder.position() + } + + /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite + /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an + /// indefinite-length string. An error is returned if the element is neither a definite length nor an + /// indefinite-length string. + pub fn str(&mut self) -> Result, DeserializeError> { + let bookmark = self.decoder.position(); + match self.decoder.str() { + Ok(str_value) => Ok(Cow::Borrowed(str_value)), + Err(e) if e.is_type_mismatch() => { + // Move the position back to the start of the CBOR element and then try + // decoding it as an indefinite length string. + self.decoder.set_position(bookmark); + Ok(Cow::Owned(self.string()?)) + } + Err(e) => Err(DeserializeError::new(e)), + } + } + + /// Allocates and returns a `String` if the element at the current position in the buffer is either a + /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type. + pub fn string(&mut self) -> Result { + let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?; + let head = iter.next(); + + let decoded_string = match head { + None => String::new(), + Some(head) => { + let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?); + for chunk in iter { + combined_chunks.push_str(chunk.map_err(DeserializeError::new)?); + } + combined_chunks + } + }; + + Ok(decoded_string) + } + + /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise, + /// a `DeserializeError` error is returned. + pub fn blob(&mut self) -> Result { + let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?; + let parts: Vec<&[u8]> = iter + .collect::>() + .map_err(DeserializeError::new)?; + + Ok(if parts.len() == 1 { + Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part. + } else { + Blob::new(parts.concat()) // Concatenate all parts into a single Blob. + }) + } + + /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise, + /// a `DeserializeError` error is returned. + pub fn timestamp(&mut self) -> Result { + let tag = self.decoder.tag().map_err(DeserializeError::new)?; + let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp); + + if tag != timestamp_tag { + Err(DeserializeError::new(Error::message( + "expected timestamp tag", + ))) + } else { + let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?; + Ok(DateTime::from_secs_f64(epoch_seconds)) + } + } +} + +#[derive(Debug)] +pub struct ArrayIter<'a, 'b, T> { + inner: minicbor::decode::ArrayIter<'a, 'b, T>, +} + +impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> { + type Item = Result; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +#[derive(Debug)] +pub struct MapIter<'a, 'b, K, V> { + inner: minicbor::decode::MapIter<'a, 'b, K, V>, +} + +impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V> +where + K: minicbor::Decode<'b, ()>, + V: minicbor::Decode<'b, ()>, +{ + type Item = Result<(K, V), DeserializeError>; + + fn next(&mut self) -> Option { + self.inner + .next() + .map(|opt| opt.map_err(DeserializeError::new)) + } +} + +pub fn set_optional(builder: B, decoder: &mut Decoder, f: F) -> Result +where + F: Fn(B, &mut Decoder) -> Result, +{ + match decoder.datatype()? { + crate::data::Type::Null => { + decoder.null()?; + Ok(builder) + } + _ => f(builder, decoder), + } +} + +#[cfg(test)] +mod tests { + use crate::Decoder; + + #[test] + fn test_definite_str_is_cow_borrowed() { + // Definite length key `thisIsAKey`. + let definite_bytes = [ + 0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79, + ]; + let mut decoder = Decoder::new(&definite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Borrowed(_))); + } + + #[test] + fn test_indefinite_str_is_cow_owned() { + // Indefinite length key `this`, `Is`, `A` and `Key`. + let indefinite_bytes = [ + 0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65, + 0x79, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.str().expect("could not decode str"); + assert_eq!(member, "thisIsAKey"); + assert!(matches!(member, std::borrow::Cow::Owned(_))); + } + + #[test] + fn test_empty_str_works() { + let bytes = [0x60]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.str().expect("could not decode empty str"); + assert_eq!(member, ""); + } + + #[test] + fn test_empty_blob_works() { + let bytes = [0x40]; + let mut decoder = Decoder::new(&bytes); + let member = decoder.blob().expect("could not decode an empty blob"); + assert_eq!(member, aws_smithy_types::Blob::new(&[])); + } + + #[test] + fn test_indefinite_length_blob() { + // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`. + // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff + let indefinite_bytes = [ + 0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62, + 0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c, + 0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0xff, + ]; + let mut decoder = Decoder::new(&indefinite_bytes); + let member = decoder.blob().expect("could not decode blob"); + assert_eq!( + member, + aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes()) + ); + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/encode.rs b/rust-runtime/aws-smithy-cbor/src/encode.rs new file mode 100644 index 00000000000..1651c37f9b2 --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/encode.rs @@ -0,0 +1,117 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_types::{Blob, DateTime}; + +/// Macro for delegating method calls to the encoder. +/// +/// This macro generates wrapper methods for calling specific encoder methods on the encoder +/// and returning a mutable reference to self for method chaining. +/// +/// # Example +/// +/// ```ignore +/// delegate_method! { +/// /// Wrapper method for encoding method `encode_str` on the encoder. +/// encode_str_wrapper => encode_str(data: &str); +/// /// Wrapper method for encoding method `encode_int` on the encoder. +/// encode_int_wrapper => encode_int(value: i32); +/// } +/// ``` +macro_rules! delegate_method { + ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($($param_name:ident : $param_type:ty),*);)+) => { + $( + pub fn $wrapper_name(&mut self, $($param_name: $param_type),*) -> &mut Self { + self.encoder.$encoder_name($($param_name)*).expect(INFALLIBLE_WRITE); + self + } + )+ + }; +} + +#[derive(Debug, Clone)] +pub struct Encoder { + encoder: minicbor::Encoder>, +} + +/// We always write to a `Vec`, which is infallible in `minicbor`. +/// +const INFALLIBLE_WRITE: &str = "write failed"; + +impl Encoder { + pub fn new(writer: Vec) -> Self { + Self { + encoder: minicbor::Encoder::new(writer), + } + } + + delegate_method! { + /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more + /// `Option`al members. + begin_map => begin_map(); + /// Writes a definite length string. + str => str(x: &str); + /// Writes a boolean value. + boolean => bool(x: bool); + /// Writes a byte value. + byte => i8(x: i8); + /// Writes a short value. + short => i16(x: i16); + /// Writes an integer value. + integer => i32(x: i32); + /// Writes an long value. + long => i64(x: i64); + /// Writes an float value. + float => f32(x: f32); + /// Writes an double value. + double => f64(x: f64); + /// Writes a null tag. + null => null(); + /// Writes an end tag. + end => end(); + } + + pub fn blob(&mut self, x: &Blob) -> &mut Self { + self.encoder.bytes(x.as_ref()).expect(INFALLIBLE_WRITE); + self + } + + /// Writes a fixed length array of given length. + pub fn array(&mut self, len: usize) -> &mut Self { + self.encoder + // `.expect()` safety: `From for usize` is not in the standard library, + // but the conversion should be infallible (unless we ever have 128-bit machines I + // guess). . + .array(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + + /// Writes a fixed length map of given length. + /// Used when we know the size in advance, i.e.: + /// - when a struct has all non-`Option`al members. + /// - when serializing `union` shapes (they can only have one member set). + /// - when serializing a `map` shape. + pub fn map(&mut self, len: usize) -> &mut Self { + self.encoder + .map(len.try_into().expect("`usize` to `u64` conversion failed")) + .expect(INFALLIBLE_WRITE); + self + } + + pub fn timestamp(&mut self, x: &DateTime) -> &mut Self { + self.encoder + .tag(minicbor::data::Tag::from( + minicbor::data::IanaTag::Timestamp, + )) + .expect(INFALLIBLE_WRITE); + self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE); + self + } + + pub fn into_writer(self) -> Vec { + self.encoder.into_writer() + } +} diff --git a/rust-runtime/aws-smithy-cbor/src/lib.rs b/rust-runtime/aws-smithy-cbor/src/lib.rs new file mode 100644 index 00000000000..6db4813980a --- /dev/null +++ b/rust-runtime/aws-smithy-cbor/src/lib.rs @@ -0,0 +1,17 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! CBOR abstractions for Smithy. + +/* Automatically managed default lints */ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +/* End of automatically managed default lints */ + +pub mod data; +pub mod decode; +pub mod encode; + +pub use decode::Decoder; +pub use encode::Encoder; diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index a63a125941e..73c2ba42d1b 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http-server" -version = "0.63.1" +version = "0.63.2" authors = ["Smithy Rust Server "] edition = "2021" license = "Apache-2.0" @@ -23,6 +23,7 @@ aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["http-02x"] } aws-smithy-types = { path = "../aws-smithy-types", features = ["http-body-0-4-x", "hyper-0-14-x"] } aws-smithy-xml = { path = "../aws-smithy-xml" } +aws-smithy-cbor = { path = "../aws-smithy-cbor" } bytes = "1.1" futures-util = { version = "0.3.29", default-features = false } http = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs index ac7a645f265..cd9e333bb2a 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs @@ -175,6 +175,8 @@ where type Future = UpgradeFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // The check that the inner service is ready is done by `Oneshot` in `UpgradeFuture`'s + // implementation. Poll::Ready(Ok(())) } diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs index df304d823f5..38538fe1e9a 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json/router.rs @@ -39,9 +39,9 @@ pub enum Error { // This constant determines when the `TinyMap` implementation switches from being a `Vec` to a // `HashMap`. This is chosen to be 15 as a result of the discussion around // https://github.com/smithy-lang/smithy-rs/pull/1429#issuecomment-1147516546 -const ROUTE_CUTOFF: usize = 15; +pub(crate) const ROUTE_CUTOFF: usize = 15; -/// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols. +/// A [`Router`] supporting [AWS JSON 1.0] and [AWS JSON 1.1] protocols. /// /// [AWS JSON 1.0]: https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html /// [AWS JSON 1.1]: https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs index 8dfd48e4661..a3e8f2c9192 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html). +/// [AWS JSON 1.0](https://smithy.io/2.0/aws/protocols/aws-json-1_0-protocol.html) protocol. pub struct AwsJson1_0; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs index 30a28d6255a..ac963ffe512 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_10/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_0; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs index 6fb09920a0c..697aae52d3e 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/mod.rs @@ -5,5 +5,5 @@ pub mod router; -/// [AWS JSON 1.1 Protocol](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html). +/// [AWS JSON 1.1](https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html) protocol. pub struct AwsJson1_1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs index 5ebd1002f28..2e3e16d8ad4 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/aws_json_11/router.rs @@ -12,6 +12,8 @@ use super::AwsJson1_1; pub use crate::protocol::aws_json::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs index 27a16d9f18e..6d6bbf3b650 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; +pub mod rpc_v2_cbor; use crate::rejection::MissingContentTypeReason; use aws_smithy_runtime_api::http::Headers as SmithyHeaders; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs index 19a644e7e45..94f99a98dfe 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest/router.rs @@ -26,10 +26,10 @@ pub enum Error { MethodNotAllowed, } -/// A [`Router`] supporting [`AWS REST JSON 1.0`] and [`AWS REST XML`] protocols. +/// A [`Router`] supporting [AWS restJson1] and [AWS restXml] protocols. /// -/// [AWS REST JSON 1.0]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html -/// [AWS REST XML]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html +/// [AWS restJson1]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html +/// [AWS restXml]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html #[derive(Debug, Clone)] pub struct RestRouter { routes: Vec<(RequestSpec, S)>, diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs index f8384578d23..695d995ce18 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST JSON 1.0 Protocol](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html). +/// [AWS restJson1](https://smithy.io/2.0/aws/protocols/aws-restjson1-protocol.html) protocol. pub struct RestJson1; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs index 023b43031c8..939b1bb6ec3 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_json_1/router.rs @@ -12,6 +12,8 @@ use super::RestJson1; pub use crate::protocol::rest::router::*; +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs index 0b16df11e32..e16570567ea 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/mod.rs @@ -7,5 +7,5 @@ pub mod rejection; pub mod router; pub mod runtime_error; -/// [AWS REST XML Protocol](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html). +/// [AWS restXml](https://smithy.io/2.0/aws/protocols/aws-restxml-protocol.html) protocol. pub struct RestXml; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs index 529a3d19a2a..e684ced4dec 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rest_xml/router.rs @@ -13,7 +13,8 @@ use super::RestXml; pub use crate::protocol::rest::router::*; -/// An AWS REST routing error. +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! impl IntoResponse for Error { fn into_response(self) -> http::Response { match self { diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs new file mode 100644 index 00000000000..287a756446b --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub mod rejection; +pub mod router; +pub mod runtime_error; + +/// [Smithy RPC v2 CBOR](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) +/// protocol. +pub struct RpcV2Cbor; diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs new file mode 100644 index 00000000000..2ec8b957af5 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/rejection.rs @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::num::TryFromIntError; + +use crate::rejection::MissingContentTypeReason; +use aws_smithy_runtime_api::http::HttpError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ResponseRejection { + #[error("invalid bound HTTP status code; status codes must be inside the 100-999 range: {0}")] + InvalidHttpStatusCode(TryFromIntError), + #[error("error serializing CBOR-encoded body: {0}")] + Serialization(#[from] aws_smithy_types::error::operation::SerializationError), + #[error("error building HTTP response: {0}")] + HttpBuild(#[from] http::Error), +} + +#[derive(Debug, Error)] +pub enum RequestRejection { + #[error("error converting non-streaming body to bytes: {0}")] + BufferHttpBodyBytes(crate::Error), + #[error("request contains invalid value for `Accept` header")] + NotAcceptable, + #[error("expected `Content-Type` header not found: {0}")] + MissingContentType(#[from] MissingContentTypeReason), + #[error("error deserializing request HTTP body as CBOR: {0}")] + CborDeserialize(#[from] aws_smithy_cbor::decode::DeserializeError), + // Unlike the other protocols, RPC v2 uses CBOR, a binary serialization format, so we take in a + // `Vec` here instead of `String`. + #[error("request does not adhere to modeled constraints")] + ConstraintViolation(Vec), + + /// Typically happens when the request has headers that are not valid UTF-8. + #[error("failed to convert request: {0}")] + HttpConversion(#[from] HttpError), +} + +impl From for RequestRejection { + fn from(_err: std::convert::Infallible) -> Self { + match _err {} + } +} + +convert_to_request_rejection!(hyper::Error, BufferHttpBodyBytes); +convert_to_request_rejection!(Box, BufferHttpBodyBytes); diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs new file mode 100644 index 00000000000..53d6e314831 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/router.rs @@ -0,0 +1,406 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::convert::Infallible; +use std::str::FromStr; + +use http::header::ToStrError; +use http::HeaderMap; +use once_cell::sync::Lazy; +use regex::Regex; +use thiserror::Error; +use tower::Layer; +use tower::Service; + +use crate::body::empty; +use crate::body::BoxBody; +use crate::extension::RuntimeErrorExtension; +use crate::protocol::aws_json_11::router::ROUTE_CUTOFF; +use crate::response::IntoResponse; +use crate::routing::tiny_map::TinyMap; +use crate::routing::Route; +use crate::routing::Router; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; + +use super::RpcV2Cbor; + +pub use crate::protocol::rest::router::*; + +/// An RPC v2 CBOR routing error. +#[derive(Debug, Error)] +pub enum Error { + /// Method was not `POST`. + #[error("method not POST")] + MethodNotAllowed, + /// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` + /// header. + #[error("contains forbidden headers")] + ForbiddenHeaders, + /// Unable to parse `smithy-protocol` header into a valid wire format value. + #[error("failed to parse `smithy-protocol` header into a valid wire format value")] + InvalidWireFormatHeader(#[from] WireFormatError), + /// Operation not found. + #[error("operation not found")] + NotFound, +} + +/// A [`Router`] supporting the [Smithy RPC v2 CBOR] protocol. +/// +/// [Smithy RPC v2 CBOR]: https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html +#[derive(Debug, Clone)] +pub struct RpcV2CborRouter { + routes: TinyMap<&'static str, S, ROUTE_CUTOFF>, +} + +/// Requests for the `rpcv2Cbor` protocol MUST NOT contain an `x-amz-target` or `x-amzn-target` +/// header. An `rpcv2Cbor` request is malformed if it contains either of these headers. Server-side +/// implementations MUST reject such requests for security reasons. +const FORBIDDEN_HEADERS: &[&str] = &["x-amz-target", "x-amzn-target"]; + +/// Matches the `Identifier` ABNF rule in +/// . +const IDENTIFIER_PATTERN: &str = r#"((_+([A-Za-z]|[0-9]))|[A-Za-z])[A-Za-z0-9_]*"#; + +impl RpcV2CborRouter { + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3748) Consider building a nom parser. + fn uri_path_regex() -> &'static Regex { + // Every request for the `rpcv2Cbor` protocol MUST be sent to a URL with the + // following form: `{prefix?}/service/{serviceName}/operation/{operationName}` + // + // * The optional `prefix` segment may span multiple path segments and is not + // utilized by the Smithy RPC v2 CBOR protocol. For example, a service could + // use a `v1` prefix for the following URL path: `v1/service/FooService/operation/BarOperation` + // * The `serviceName` segment MUST be replaced by the [`shape + // name`](https://smithy.io/2.0/spec/model.html#grammar-token-smithy-Identifier) + // of the service's [Shape ID](https://smithy.io/2.0/spec/model.html#shape-id) + // in the Smithy model. The `serviceName` produced by client implementations + // MUST NOT contain the namespace of the `service` shape. Service + // implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it + // the same as if only the name was specified. For example, if the `service`'s + // absolute shape ID is `com.example#TheService`, a service should accept both + // `TheService` and `com.example.TheService` as values for the `serviceName` + // segment. + static PATH_REGEX: Lazy = Lazy::new(|| { + Regex::new(&format!( + r#"/service/({}\.)*(?P{})/operation/(?P{})$"#, + IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, IDENTIFIER_PATTERN, + )) + .unwrap() + }); + + &PATH_REGEX + } + + pub fn wire_format_regex() -> &'static Regex { + static SMITHY_PROTOCOL_REGEX: Lazy = Lazy::new(|| Regex::new(r#"^rpc-v2-(?P\w+)$"#).unwrap()); + + &SMITHY_PROTOCOL_REGEX + } + + pub fn boxed(self) -> RpcV2CborRouter> + where + S: Service, Response = http::Response, Error = Infallible>, + S: Send + Clone + 'static, + S::Future: Send + 'static, + { + RpcV2CborRouter { + routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(), + } + } + + /// Applies a [`Layer`] uniformly to all routes. + pub fn layer(self, layer: L) -> RpcV2CborRouter + where + L: Layer, + { + RpcV2CborRouter { + routes: self + .routes + .into_iter() + .map(|(key, route)| (key, layer.layer(route))) + .collect(), + } + } +} + +// TODO(https://github.com/smithy-lang/smithy/issues/2348): We're probably non-compliant here, but +// we have no tests to pin our implemenation against! +impl IntoResponse for Error { + fn into_response(self) -> http::Response { + match self { + Error::MethodNotAllowed => method_disallowed(), + _ => http::Response::builder() + .status(http::StatusCode::NOT_FOUND) + .header(http::header::CONTENT_TYPE, "application/cbor") + .extension(RuntimeErrorExtension::new( + UNKNOWN_OPERATION_EXCEPTION.to_string(), + )) + .body(empty()) + .expect("invalid HTTP response for RPCv2 CBOR routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"), + } + } +} + +/// Errors that can happen when parsing the wire format from the `smithy-protocol` header. +#[derive(Debug, Error)] +pub enum WireFormatError { + /// Header not found. + #[error("`smithy-protocol` header not found")] + HeaderNotFound, + /// Header value is not visible ASCII. + #[error("`smithy-protocol` header not visible ASCII")] + HeaderValueNotVisibleAscii(ToStrError), + /// Header value does not match the `rpc-v2-{format}` pattern. The actual parsed header value + /// is stored in the tuple struct. + // https://doc.rust-lang.org/std/fmt/index.html#escaping + #[error("`smithy-protocol` header does not match the `rpc-v2-{{format}}` pattern: `{0}`")] + HeaderValueNotValid(String), + /// Header value matches the `rpc-v2-{format}` pattern, but the `format` is not supported. The + /// actual parsed header value is stored in the tuple struct. + #[error("found unsupported `smithy-protocol` wire format: `{0}`")] + WireFormatNotSupported(String), +} + +/// Smithy RPC V2 requests have a `smithy-protocol` header with the value +/// `"rpc-v2-{format}"`, where `format` is one of the supported wire formats +/// by the protocol (see [`WireFormat`]). +fn parse_wire_format_from_header(headers: &HeaderMap) -> Result { + let header = headers.get("smithy-protocol").ok_or(WireFormatError::HeaderNotFound)?; + let header = header.to_str().map_err(WireFormatError::HeaderValueNotVisibleAscii)?; + let captures = RpcV2CborRouter::<()>::wire_format_regex() + .captures(header) + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let format = captures + .name("format") + .ok_or_else(|| WireFormatError::HeaderValueNotValid(header.to_owned()))?; + + let wire_format_parse_res: Result = format.as_str().parse(); + wire_format_parse_res.map_err(|_| WireFormatError::WireFormatNotSupported(header.to_owned())) +} + +/// Supported wire formats by RPC V2. +enum WireFormat { + Cbor, +} + +struct WireFormatFromStrError; + +impl FromStr for WireFormat { + type Err = WireFormatFromStrError; + + fn from_str(format: &str) -> Result { + match format { + "cbor" => Ok(Self::Cbor), + _ => Err(WireFormatFromStrError), + } + } +} + +impl Router for RpcV2CborRouter { + type Service = S; + + type Error = Error; + + fn match_route(&self, request: &http::Request) -> Result { + // Only `Method::POST` is allowed. + if request.method() != http::Method::POST { + return Err(Error::MethodNotAllowed); + } + + // Some headers are not allowed. + let request_has_forbidden_header = FORBIDDEN_HEADERS + .iter() + .any(|&forbidden_header| request.headers().contains_key(forbidden_header)); + if request_has_forbidden_header { + return Err(Error::ForbiddenHeaders); + } + + // Wire format has to be specified and supported. + let _wire_format = parse_wire_format_from_header(request.headers())?; + + // Extract the service name and the operation name from the request URI. + let request_path = request.uri().path(); + let regex = Self::uri_path_regex(); + + tracing::trace!(%request_path, "capturing service and operation from URI"); + let captures = regex.captures(request_path).ok_or(Error::NotFound)?; + let (service, operation) = (&captures["service"], &captures["operation"]); + tracing::trace!(%service, %operation, "captured service and operation from URI"); + + // Lookup in the `TinyMap` for a route for the target. + let route = self + .routes + .get((format!("{service}.{operation}")).as_str()) + .ok_or(Error::NotFound)?; + Ok(route.clone()) + } +} + +impl FromIterator<(&'static str, S)> for RpcV2CborRouter { + #[inline] + fn from_iter>(iter: T) -> Self { + Self { + routes: iter.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod tests { + use http::{HeaderMap, HeaderValue, Method}; + use regex::Regex; + + use crate::protocol::test_helpers::req; + + use super::{Error, Router, RpcV2CborRouter}; + + fn identifier_regex() -> Regex { + Regex::new(&format!("^{}$", super::IDENTIFIER_PATTERN)).unwrap() + } + + #[test] + fn valid_identifiers() { + let valid_identifiers = vec!["a", "_a", "_0", "__0", "variable123", "_underscored_variable"]; + + for id in &valid_identifiers { + assert!(identifier_regex().is_match(id), "'{}' is incorrectly rejected", id); + } + } + + #[test] + fn invalid_identifiers() { + let invalid_identifiers = vec![ + "0", + "123starts_with_digit", + "@invalid_start_character", + " space_in_identifier", + "invalid-character", + "invalid@character", + "no#hashes", + ]; + + for id in &invalid_identifiers { + assert!(!identifier_regex().is_match(id), "'{}' is incorrectly accepted", id); + } + } + + #[test] + fn uri_regex_works_accepts() { + let regex = RpcV2CborRouter::<()>::uri_path_regex(); + + for uri in [ + "/service/Service/operation/Operation", + "prefix/69/service/Service/operation/Operation", + // Here the prefix is up to the last occurrence of the string `/service`. + "prefix/69/service/Service/operation/Operation/service/Service/operation/Operation", + // Service implementations SHOULD accept an absolute shape ID as the content of this + // segment with the `#` character replaced with a `.` character, routing it the same as + // if only the name was specified. For example, if the `service`'s absolute shape ID is + // `com.example#TheService`, a service should accept both `TheService` and + // `com.example.TheService` as values for the `serviceName` segment. + "/service/aws.protocoltests.rpcv2Cbor.Service/operation/Operation", + "/service/namespace.Service/operation/Operation", + ] { + let captures = regex.captures(uri).unwrap(); + assert_eq!("Service", &captures["service"], "uri: {}", uri); + assert_eq!("Operation", &captures["operation"], "uri: {}", uri); + } + } + + #[test] + fn uri_regex_works_rejects() { + let regex = RpcV2CborRouter::<()>::uri_path_regex(); + + for uri in [ + "", + "foo", + "/servicee/Service/operation/Operation", + "/service/Service", + "/service/Service/operation/", + "/service/Service/operation/Operation/", + "/service/Service/operation/Operation/invalid-suffix", + "/service/namespace.foo#Service/operation/Operation", + "/service/namespace-Service/operation/Operation", + "/service/.Service/operation/Operation", + ] { + assert!(regex.captures(uri).is_none(), "uri: {}", uri); + } + } + + #[test] + fn wire_format_regex_works() { + let regex = RpcV2CborRouter::<()>::wire_format_regex(); + + let captures = regex.captures("rpc-v2-something").unwrap(); + assert_eq!("something", &captures["format"]); + + let captures = regex.captures("rpc-v2-SomethingElse").unwrap(); + assert_eq!("SomethingElse", &captures["format"]); + + let invalid = regex.captures("rpc-v1-something"); + assert!(invalid.is_none()); + } + + /// Helper function returning the only strictly required header. + fn headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static("rpc-v2-cbor")); + headers + } + + #[test] + fn simple_routing() { + let router: RpcV2CborRouter<_> = ["Service.Operation"].into_iter().map(|op| (op, ())).collect(); + let good_uri = "/prefix/service/Service/operation/Operation"; + + // The request should match. + let routing_result = router.match_route(&req(&Method::POST, good_uri, Some(headers()))); + assert!(routing_result.is_ok()); + + // The request would be valid if it used `Method::POST`. + let invalid_request = req(&Method::GET, good_uri, Some(headers())); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::MethodNotAllowed) + )); + + // The request would be valid if it did not have forbidden headers. + for forbidden_header_name in ["x-amz-target", "x-amzn-target"] { + let mut headers = headers(); + headers.insert(forbidden_header_name, HeaderValue::from_static("Service.Operation")); + let invalid_request = req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::ForbiddenHeaders) + )); + } + + for bad_uri in [ + // These requests would be valid if they used correct URIs. + "/prefix/Service/Service/operation/Operation", + "/prefix/service/Service/operation/Operation/suffix", + // These requests would be valid if their URI matched an existing operation. + "/prefix/service/ThisServiceDoesNotExist/operation/Operation", + "/prefix/service/Service/operation/ThisOperationDoesNotExist", + ] { + let invalid_request = &req(&Method::POST, bad_uri, Some(headers())); + assert!(matches!(router.match_route(&invalid_request), Err(Error::NotFound))); + } + + // The request would be valid if it specified a supported wire format in the + // `smithy-protocol` header. + for header_name in ["bad-header", "rpc-v2-json", "foo-rpc-v2-cbor", "rpc-v2-cbor-foo"] { + let mut headers = HeaderMap::new(); + headers.insert("smithy-protocol", HeaderValue::from_static(header_name)); + let invalid_request = &req(&Method::POST, good_uri, Some(headers)); + assert!(matches!( + router.match_route(&invalid_request), + Err(Error::InvalidWireFormatHeader(_)) + )); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs new file mode 100644 index 00000000000..b3f01da3511 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/protocol/rpc_v2_cbor/runtime_error.rs @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::response::IntoResponse; +use crate::runtime_error::{InternalFailureException, INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE}; +use crate::{extension::RuntimeErrorExtension, protocol::rpc_v2_cbor::RpcV2Cbor}; +use bytes::Bytes; +use http::StatusCode; + +use super::rejection::{RequestRejection, ResponseRejection}; + +#[derive(Debug, thiserror::Error)] +pub enum RuntimeError { + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Serialization`] + #[error("request failed to deserialize or response failed to serialize: {0}")] + Serialization(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::InternalFailure`] + #[error("internal failure: {0}")] + InternalFailure(crate::Error), + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::NotAcceptable`] + #[error("not acceptable request: request contains an `Accept` header with a MIME type, and the server cannot return a response body adhering to that MIME type")] + NotAcceptable, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::UnsupportedMediaType`] + #[error("unsupported media type: request does not contain the expected `Content-Type` header value")] + UnsupportedMediaType, + /// See: [`crate::protocol::rest_json_1::runtime_error::RuntimeError::Validation`] + #[error( + "validation failure: operation input contains data that does not adhere to the modeled constraints: {0:?}" + )] + Validation(Vec), +} + +impl RuntimeError { + pub fn name(&self) -> &'static str { + match self { + Self::Serialization(_) => "SerializationException", + Self::InternalFailure(_) => "InternalFailureException", + Self::NotAcceptable => "NotAcceptableException", + Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", + } + } + + pub fn status_code(&self) -> StatusCode { + match self { + Self::Serialization(_) => StatusCode::BAD_REQUEST, + Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, + Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, + } + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + IntoResponse::::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new()))) + } +} + +impl IntoResponse for RuntimeError { + fn into_response(self) -> http::Response { + let res = http::Response::builder() + .status(self.status_code()) + .header("Content-Type", "application/cbor") + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + // https://cbor.nemo157.com/#type=hex&value=a0 + const EMPTY_CBOR_MAP: Bytes = Bytes::from_static(&[0xa0]); + + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3716): we're not serializing + // `__type`. + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + _ => crate::body::to_boxed(EMPTY_CBOR_MAP), + }; + + res.body(body) + .expect(INVALID_HTTP_RESPONSE_FOR_RUNTIME_ERROR_PANIC_MESSAGE) + } +} + +impl From for RuntimeError { + fn from(err: ResponseRejection) -> Self { + Self::Serialization(crate::Error::new(err)) + } +} + +impl From for RuntimeError { + fn from(err: RequestRejection) -> Self { + match err { + RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), + _ => Self::Serialization(crate::Error::new(err)), + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 14f124f6874..ede1f5117b0 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -37,7 +37,6 @@ use futures_util::{ use http::Response; use http_body::Body as HttpBody; use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; use crate::{ body::{boxed, BoxBody}, @@ -191,12 +190,13 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { + tracing::debug!("inside routing service call"); match self.router.match_route(&req) { // Successfully routed, use the routes `Service::call`. Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), // Failed to route, use the `R::Error`s `IntoResponse

`. Err(error) => { - debug!(%error, "failed to route"); + tracing::debug!(%error, "failed to route"); RoutingFuture::from_response(error.into_response()) } } diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index e674b66dfea..9f5189079a0 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-protocol-test" -version = "0.60.8" +version = "0.62.0" authors = ["AWS Rust SDK Team ", "Russell Cohen "] description = "A collection of library functions to validate HTTP requests against Smithy protocol tests." edition = "2021" @@ -10,6 +10,9 @@ repository = "https://github.com/smithy-lang/smithy-rs" [dependencies] # Not perfect for our needs, but good for now assert-json-diff = "1.1" +base64-simd = "0.8" +cbor-diag = "0.1.12" +serde_cbor = "0.11" http = "0.2.1" pretty_assertions = "1.3" regex-lite = "0.1.5" @@ -18,7 +21,6 @@ serde_json = "1" thiserror = "1.0.40" aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } - [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 7fca5463d6b..06cdbc2ff23 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -306,10 +306,12 @@ pub fn require_headers( #[derive(Clone)] pub enum MediaType { - /// Json media types are deserialized and compared + /// JSON media types are deserialized and compared Json, /// XML media types are normalized and compared Xml, + /// CBOR media types are decoded from base64 to binary and compared + Cbor, /// For x-www-form-urlencoded, do some map order comparison shenanigans UrlEncodedForm, /// Other media types are compared literally @@ -322,13 +324,14 @@ impl> From for MediaType { "application/json" => MediaType::Json, "application/x-amz-json-1.1" => MediaType::Json, "application/xml" => MediaType::Xml, + "application/cbor" => MediaType::Cbor, "application/x-www-form-urlencoded" => MediaType::UrlEncodedForm, other => MediaType::Other(other.to_string()), } } } -pub fn validate_body>( +pub fn validate_body + Debug>( actual_body: T, expected_body: &str, media_type: MediaType, @@ -336,11 +339,11 @@ pub fn validate_body>( let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => try_json_eq(expected_body, actual_body), - (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(expected_body, actual_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Xml, Ok(actual_body)) => try_xml_equivalent(actual_body, expected_body), (MediaType::Xml, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "XML".to_owned(), found: "input was not valid UTF-8".to_owned(), @@ -352,6 +355,7 @@ pub fn validate_body>( expected: "x-www-form-urlencoded".to_owned(), found: "input was not valid UTF-8".to_owned(), }), + (MediaType::Cbor, _) => try_cbor_eq(actual_body, expected_body), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { @@ -410,6 +414,66 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure> } } +fn try_cbor_eq + Debug>( + actual_body: T, + expected_body: &str, +) -> Result<(), ProtocolTestFailure> { + let decoded = base64_simd::STANDARD + .decode_to_vec(expected_body) + .expect("smithy protocol test `body` property is not properly base64 encoded"); + let expected_cbor_value: serde_cbor::Value = + serde_cbor::from_slice(decoded.as_slice()).expect("expected value must be valid CBOR"); + let actual_cbor_value: serde_cbor::Value = serde_cbor::from_slice(actual_body.as_ref()) + .map_err(|e| ProtocolTestFailure::InvalidBodyFormat { + expected: "cbor".to_owned(), + found: format!("{} {:?}", e, actual_body), + })?; + let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body); + + if expected_cbor_value != actual_cbor_value { + let expected_body_annotated_hex: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_hex(); + let expected_body_diag: String = cbor_diag::parse_bytes(&decoded) + .expect("smithy protocol test `body` property is not valid CBOR") + .to_diag_pretty(); + let actual_body_annotated_hex: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_hex(); + let actual_body_diag: String = cbor_diag::parse_bytes(&actual_body) + .expect("actual body is not valid CBOR") + .to_diag_pretty(); + + Err(ProtocolTestFailure::BodyDidNotMatch { + comparison: PrettyString(format!( + "{}", + Comparison::new(&expected_cbor_value, &actual_cbor_value) + )), + // The last newline is important because the panic message ends with a `.` + hint: format!( + "expected body in diagnostic format: +{} +actual body in diagnostic format: +{} +expected body in annotated hex: +{} +actual body in annotated hex: +{} +actual body in base64 (useful to update the protocol test): +{} +", + expected_body_diag, + actual_body_diag, + expected_body_annotated_hex, + actual_body_annotated_hex, + actual_body_base64, + ), + }) + } else { + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::{