diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index 81a39472ad..8750fa0c96 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -21,6 +21,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.Non import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import java.util.logging.Level diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt index 926259e42e..be75256799 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt @@ -7,18 +7,18 @@ package software.amazon.smithy.rust.codegen.client.smithy.customize import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customizations.AllowLintsGenerator -import software.amazon.smithy.rust.codegen.client.smithy.customizations.CrateVersionGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpChecksumRequiredGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVersionListCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdempotencyTokenGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization -import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization @@ -52,8 +52,7 @@ class RequiredCustomizations : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List = - baseCustomizations + CrateVersionGenerator() + - AllowLintsGenerator() + baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization() override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { // Add rt-tokio feature for `ByteStream::from_path` diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt index a6660f4387..5fb38e58e3 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt new file mode 100644 index 0000000000..3f6a7a8c56 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt @@ -0,0 +1,59 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.testutil.clientTestRustSettings +import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements +import java.util.stream.Stream + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() +} + +abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements { + override fun createCodegenContext( + model: Model, + serviceShape: ServiceShape, + protocolShapeId: ShapeId, + codegenTarget: CodegenTarget, + ): ClientCodegenContext = ClientCodegenContext( + model, + testSymbolProvider(model), + serviceShape, + protocolShapeId, + clientTestRustSettings(), + CombinedClientCodegenDecorator(emptyList()), + ) + + override fun renderBuilderForShape( + writer: RustWriter, + codegenContext: ClientCodegenContext, + shape: StructureShape, + ) { + BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply { + render(writer) + writer.implBlock(shape, codegenContext.symbolProvider) { + renderConvenienceMethod(writer) + } + } + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt new file mode 100644 index 0000000000..936d3b6324 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt @@ -0,0 +1,46 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety +import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject +import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig + +class ClientEventStreamMarshallerGeneratorTest { + @ParameterizedTest + @ArgumentsSource(TestCasesProvider::class) + fun test(testCase: EventStreamTestModels.TestCase) { + EventStreamTestTools.runTestCase( + testCase, + object : ClientEventStreamBaseRequirements() { + override fun renderGenerator( + codegenContext: ClientCodegenContext, + project: TestEventStreamProject, + protocol: Protocol, + ): RuntimeType = EventStreamMarshallerGenerator( + project.model, + CodegenTarget.CLIENT, + TestRuntimeConfig, + project.symbolProvider, + project.streamShape, + protocol.structuredDataSerializer(project.operationShape), + testCase.requestContentType, + ).render() + }, + CodegenTarget.CLIENT, + EventStreamTestVariety.Marshall, + ) + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt new file mode 100644 index 0000000000..f9be7b3bf4 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety +import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject + +class ClientEventStreamUnmarshallerGeneratorTest { + @ParameterizedTest + @ArgumentsSource(TestCasesProvider::class) + fun test(testCase: EventStreamTestModels.TestCase) { + EventStreamTestTools.runTestCase( + testCase, + object : ClientEventStreamBaseRequirements() { + override fun renderGenerator( + codegenContext: ClientCodegenContext, + project: TestEventStreamProject, + protocol: Protocol, + ): RuntimeType { + fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) + return EventStreamUnmarshallerGenerator( + protocol, + codegenContext, + project.operationShape, + project.streamShape, + ::builderSymbol, + ).render() + } + }, + CodegenTarget.CLIENT, + EventStreamTestVariety.Unmarshall, + ) + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt similarity index 89% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProvider.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt index 7a885d91ae..54b43f722b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy +package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model @@ -14,13 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol -import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingTraitSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt similarity index 86% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingTraitSymbolProvider.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt index 3203e96a22..2e91ee5fba 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingTraitSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy +package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model @@ -14,13 +14,6 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.smithy.Default -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider -import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata -import software.amazon.smithy.rust.codegen.core.smithy.setDefault import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/AllowLintsGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt similarity index 86% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/AllowLintsGenerator.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt index f5755d9a73..0fc88f7a1c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/AllowLintsGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt @@ -3,19 +3,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy.customizations +package software.amazon.smithy.rust.codegen.core.smithy.customizations import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection -val AllowedRustcLints = listOf( +private val allowedRustcLints = listOf( // Deprecated items should be safe to compile, so don't block the compilation. "deprecated", ) -val AllowedClippyLints = listOf( +private val allowedClippyLints = listOf( // Sometimes operations are named the same as our module e.g. output leading to `output::output`. "module_inception", @@ -54,16 +54,16 @@ val AllowedClippyLints = listOf( // "result_large_err", ) -val AllowedRustdocLints = listOf( +private val allowedRustdocLints = listOf( // Rust >=1.53.0 requires links to be wrapped in ``. This is extremely hard to enforce for // docs that come from the modeled documentation, so we need to disable this lint "bare_urls", ) -class AllowLintsGenerator( - private val rustcLints: List = AllowedRustcLints, - private val clippyLints: List = AllowedClippyLints, - private val rustdocLints: List = AllowedRustdocLints, +class AllowLintsCustomization( + private val rustcLints: List = allowedRustcLints, + private val clippyLints: List = allowedClippyLints, + private val rustdocLints: List = allowedRustdocLints, ) : LibRsCustomization() { override fun section(section: LibRsSection) = when (section) { is LibRsSection.Attributes -> writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/CrateVersionGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt similarity index 87% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/CrateVersionGenerator.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt index d3b350ed02..eca5503050 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/CrateVersionGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy.customizations +package software.amazon.smithy.rust.codegen.core.smithy.customizations import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -13,7 +13,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection /** * Add `PGK_VERSION` const in lib.rs to enable knowing the version of the current module */ -class CrateVersionGenerator : LibRsCustomization() { +class CrateVersionCustomization : LibRsCustomization() { override fun section(section: LibRsSection) = writable { if (section is LibRsSection.Body) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt similarity index 75% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 2f2c72ab15..02a8843c19 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -3,9 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy.customizations +package software.amazon.smithy.rust.codegen.core.smithy.customizations import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -30,23 +31,19 @@ private fun hasStreamingOperations(model: Model): Boolean { } } -/** Returns true if the model has any blob shapes or members */ -private fun hasBlobs(model: Model): Boolean { - return model.structureShapes.any { structure -> - structure.members().any { member -> model.expectShape(member.target).isBlobShape } +// TODO(https://github.com/awslabs/smithy-rs/issues/2111): Fix this logic to consider collection/map shapes +private fun structUnionMembersMatchPredicate(model: Model, predicate: (Shape) -> Boolean): Boolean = + model.structureShapes.any { structure -> + structure.members().any { member -> predicate(model.expectShape(member.target)) } } || model.unionShapes.any { union -> - union.members().any { member -> model.expectShape(member.target).isBlobShape } + union.members().any { member -> predicate(model.expectShape(member.target)) } } -} -/** Returns true if the model has any timestamp shapes or members */ -private fun hasDateTimes(model: Model): Boolean { - return model.structureShapes.any { structure -> - structure.members().any { member -> model.expectShape(member.target).isTimestampShape } - } || model.unionShapes.any { union -> - union.members().any { member -> model.expectShape(member.target).isTimestampShape } - } -} +/** Returns true if the model uses any blob shapes */ +private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isBlobShape) + +/** Returns true if the model uses any timestamp shapes */ +private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape) /** Returns a list of types that should be re-exported for the given model */ internal fun pubUseTypes(runtimeConfig: RuntimeConfig, model: Model): List { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt new file mode 100644 index 0000000000..6e82fc1b2c --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -0,0 +1,208 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.testutil + +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.dq + +internal object EventStreamMarshallTestCases { + internal fun RustWriter.writeMarshallTestCases( + testCase: EventStreamTestModels.TestCase, + generator: RuntimeType, + ) { + val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) + .copy(scope = DependencyScope.Compile) + rustTemplate( + """ + use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage}; + use std::collections::HashMap; + use aws_smithy_types::{Blob, DateTime}; + use crate::error::*; + use crate::model::*; + + use #{validate_body}; + use #{MediaType}; + + fn headers_to_map<'a>(headers: &'a [Header]) -> HashMap { + let mut map = HashMap::new(); + for header in headers { + map.insert(header.name().as_str().to_string(), header.value()); + } + map + } + + fn str_header(value: &'static str) -> HeaderValue { + HeaderValue::String(value.into()) + } + """, + "validate_body" to protocolTestHelpers.toType().resolve("validate_body"), + "MediaType" to protocolTestHelpers.toType().resolve("MediaType"), + ) + + unitTest( + "message_with_blob", + """ + let event = TestStream::MessageWithBlob( + MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + ) + + unitTest( + "message_with_string", + """ + let event = TestStream::MessageWithString( + MessageWithString::builder().data("hello, world!").build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + ) + + unitTest( + "message_with_struct", + """ + let event = TestStream::MessageWithStruct( + MessageWithStruct::builder().some_struct( + TestStruct::builder() + .some_string("hello") + .some_int(5) + .build() + ).build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestStruct.dq()}, + MediaType::from(${testCase.requestContentType.dq()}) + ).unwrap(); + """, + ) + + unitTest( + "message_with_union", + """ + let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( + TestUnion::Foo("hello".into()) + ).build()); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestUnion.dq()}, + MediaType::from(${testCase.requestContentType.dq()}) + ).unwrap(); + """, + ) + + unitTest( + "message_with_headers", + """ + let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(Blob::new(&b"test"[..])) + .boolean(true) + .byte(55i8) + .int(100_000i32) + .long(9_000_000_000i64) + .short(16_000i16) + .string("test") + .timestamp(DateTime::from_secs(5)) + .build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b""[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + assert_eq!(expected_message, actual_message); + """, + ) + + unitTest( + "message_with_header_and_payload", + """ + let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header("header") + .payload(Blob::new(&b"payload"[..])) + .build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b"payload"[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) + .add_header(Header::new("header", HeaderValue::String("header".into()))) + .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); + assert_eq!(expected_message, actual_message); + """, + ) + + unitTest( + "message_with_no_header_payload_traits", + """ + let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(5) + .some_string("hello") + .build() + ); + let result = ${format(generator)}().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, + MediaType::from(${testCase.requestContentType.dq()}) + ).unwrap(); + """, + ) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt new file mode 100644 index 0000000000..58ab85eec6 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -0,0 +1,179 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.testutil + +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson +import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml + +private fun fillInBaseModel( + protocolName: String, + extraServiceAnnotations: String = "", +): String = """ + namespace test + + use aws.protocols#$protocolName + + union TestUnion { + Foo: String, + Bar: Integer, + } + structure TestStruct { + someString: String, + someInt: Integer, + } + + @error("client") + structure SomeError { + Message: String, + } + + structure MessageWithBlob { @eventPayload data: Blob } + structure MessageWithString { @eventPayload data: String } + structure MessageWithStruct { @eventPayload someStruct: TestStruct } + structure MessageWithUnion { @eventPayload someUnion: TestUnion } + structure MessageWithHeaders { + @eventHeader blob: Blob, + @eventHeader boolean: Boolean, + @eventHeader byte: Byte, + @eventHeader int: Integer, + @eventHeader long: Long, + @eventHeader short: Short, + @eventHeader string: String, + @eventHeader timestamp: Timestamp, + } + structure MessageWithHeaderAndPayload { + @eventHeader header: String, + @eventPayload payload: Blob, + } + structure MessageWithNoHeaderPayloadTraits { + someInt: Integer, + someString: String, + } + + @streaming + union TestStream { + MessageWithBlob: MessageWithBlob, + MessageWithString: MessageWithString, + MessageWithStruct: MessageWithStruct, + MessageWithUnion: MessageWithUnion, + MessageWithHeaders: MessageWithHeaders, + MessageWithHeaderAndPayload: MessageWithHeaderAndPayload, + MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits, + SomeError: SomeError, + } + structure TestStreamInputOutput { @httpPayload @required value: TestStream } + operation TestStreamOp { + input: TestStreamInputOutput, + output: TestStreamInputOutput, + errors: [SomeError], + } + $extraServiceAnnotations + @$protocolName + service TestService { version: "123", operations: [TestStreamOp] } +""" + +object EventStreamTestModels { + private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() + private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() + private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() + private fun awsQuery(): Model = + fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + private fun ec2Query(): Model = + fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + + data class TestCase( + val protocolShapeId: String, + val model: Model, + val requestContentType: String, + val responseContentType: String, + val validTestStruct: String, + val validMessageWithNoHeaderPayloadTraits: String, + val validTestUnion: String, + val validSomeError: String, + val validUnmodeledError: String, + val protocolBuilder: (CodegenContext) -> Protocol, + ) { + override fun toString(): String = protocolShapeId + } + + val TEST_CASES = listOf( + // + // restJson1 + // + TestCase( + protocolShapeId = "aws.protocols#restJson1", + model = restJson1(), + requestContentType = "application/json", + responseContentType = "application/json", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { RestJson(it) }, + + // + // awsJson1_1 + // + TestCase( + protocolShapeId = "aws.protocols#awsJson1_1", + model = awsJson11(), + requestContentType = "application/x-amz-json-1.1", + responseContentType = "application/x-amz-json-1.1", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { AwsJson(it, AwsJsonVersion.Json11) }, + + // + // restXml + // + TestCase( + protocolShapeId = "aws.protocols#restXml", + model = restXml(), + requestContentType = "application/xml", + responseContentType = "application/xml", + validTestStruct = """ + + hello + 5 + + """.trimIndent(), + validMessageWithNoHeaderPayloadTraits = """ + + hello + 5 + + """.trimIndent(), + validTestUnion = "hello", + validSomeError = """ + + + SomeError + SomeError + some error + + + """.trimIndent(), + validUnmodeledError = """ + + + UnmodeledError + UnmodeledError + unmodeled error + + + """.trimIndent(), + ) { RestXml(it) }, + ) +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt new file mode 100644 index 0000000000..8340ff3ee3 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt @@ -0,0 +1,174 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.testutil + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.CombinedErrorGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.core.util.outputShape +import kotlin.streams.toList + +data class TestEventStreamProject( + val model: Model, + val serviceShape: ServiceShape, + val operationShape: OperationShape, + val streamShape: UnionShape, + val symbolProvider: RustSymbolProvider, + val project: TestWriterDelegator, +) + +enum class EventStreamTestVariety { + Marshall, + Unmarshall +} + +interface EventStreamTestRequirements { + /** Create a codegen context for the tests */ + fun createCodegenContext( + model: Model, + serviceShape: ServiceShape, + protocolShapeId: ShapeId, + codegenTarget: CodegenTarget, + ): C + + /** Render the event stream marshall/unmarshall code generator */ + fun renderGenerator( + codegenContext: C, + project: TestEventStreamProject, + protocol: Protocol, + ): RuntimeType + + /** Render a builder for the given shape */ + fun renderBuilderForShape( + writer: RustWriter, + codegenContext: C, + shape: StructureShape, + ) +} + +object EventStreamTestTools { + fun runTestCase( + testCase: EventStreamTestModels.TestCase, + requirements: EventStreamTestRequirements, + codegenTarget: CodegenTarget, + variety: EventStreamTestVariety, + ) { + val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model)) + val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape + val codegenContext = requirements.createCodegenContext( + model, + serviceShape, + ShapeId.from(testCase.protocolShapeId), + codegenTarget, + ) + val test = generateTestProject(requirements, codegenContext, codegenTarget) + val protocol = testCase.protocolBuilder(codegenContext) + val generator = requirements.renderGenerator(codegenContext, test, protocol) + + test.project.lib { + when (variety) { + EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator) + EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator) + } + } + test.project.compileAndTest() + } + + private fun generateTestProject( + requirements: EventStreamTestRequirements, + codegenContext: C, + codegenTarget: CodegenTarget, + ): TestEventStreamProject { + val model = codegenContext.model + val symbolProvider = codegenContext.symbolProvider + val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape + val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape + + val project = TestWorkspace.testProject(symbolProvider) + val operationSymbol = symbolProvider.toSymbol(operationShape) + project.withModule(ErrorsModule) { + val errors = model.shapes() + .filter { shape -> shape.isStructureShape && shape.hasTrait() } + .map { it.asStructureShape().get() } + .toList() + when (codegenTarget) { + CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this) + CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this) + } + for (shape in model.shapes().filter { shape -> shape is StructureShape && shape.hasTrait() }) { + StructureGenerator(model, symbolProvider, this, shape as StructureShape).render(codegenTarget) + requirements.renderBuilderForShape(this, codegenContext, shape) + } + } + project.withModule(ModelsModule) { + val inputOutput = model.lookup("test#TestStreamInputOutput") + recursivelyGenerateModels(model, symbolProvider, inputOutput, this, codegenTarget) + } + project.withModule(RustModule.Output) { + operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) + } + return TestEventStreamProject( + model, + codegenContext.serviceShape, + operationShape, + unionShape, + symbolProvider, + project, + ) + } + + private fun recursivelyGenerateModels( + model: Model, + symbolProvider: RustSymbolProvider, + shape: Shape, + writer: RustWriter, + mode: CodegenTarget, + ) { + for (member in shape.members()) { + if (member.target.namespace == "smithy.api") { + continue + } + val target = model.expectShape(member.target) + when (target) { + is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, writer) + is UnionShape -> UnionGenerator( + model, + symbolProvider, + writer, + target, + renderUnknownVariant = mode.renderUnknownVariant(), + ).render() + else -> TODO("EventStreamTestTools doesn't support rendering $target") + } + recursivelyGenerateModels(model, symbolProvider, target, writer, mode) + } + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt new file mode 100644 index 0000000000..bb27c724e0 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -0,0 +1,274 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.testutil + +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType + +internal object EventStreamUnmarshallTestCases { + internal fun RustWriter.writeUnmarshallTestCases( + testCase: EventStreamTestModels.TestCase, + codegenTarget: CodegenTarget, + generator: RuntimeType, + ) { + rust( + """ + use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage}; + use aws_smithy_types::{Blob, DateTime}; + use crate::error::*; + use crate::model::*; + + fn msg( + message_type: &'static str, + event_type: &'static str, + content_type: &'static str, + payload: &'static [u8], + ) -> Message { + let message = Message::new(payload) + .add_header(Header::new(":message-type", HeaderValue::String(message_type.into()))) + .add_header(Header::new(":content-type", HeaderValue::String(content_type.into()))); + if message_type == "event" { + message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into()))) + } else { + message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into()))) + } + } + fn expect_event(unmarshalled: UnmarshalledMessage) -> T { + match unmarshalled { + UnmarshalledMessage::Event(event) => event, + _ => panic!("expected event, got: {:?}", unmarshalled), + } + } + fn expect_error(unmarshalled: UnmarshalledMessage) -> E { + match unmarshalled { + UnmarshalledMessage::Error(error) => error, + _ => panic!("expected error, got: {:?}", unmarshalled), + } + } + """, + ) + + unitTest( + name = "message_with_blob", + test = """ + let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!"); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithBlob( + MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() + ), + expect_event(result.unwrap()) + ); + """, + ) + + if (codegenTarget == CodegenTarget.CLIENT) { + unitTest( + "unknown_message", + """ + let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::Unknown, + expect_event(result.unwrap()) + ); + """, + ) + } + + unitTest( + "message_with_string", + """ + let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()), + expect_event(result.unwrap()) + ); + """, + ) + + unitTest( + "message_with_struct", + """ + let message = msg( + "event", + "MessageWithStruct", + "${testCase.responseContentType}", + br#"${testCase.validTestStruct}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct( + TestStruct::builder() + .some_string("hello") + .some_int(5) + .build() + ).build()), + expect_event(result.unwrap()) + ); + """, + ) + + unitTest( + "message_with_union", + """ + let message = msg( + "event", + "MessageWithUnion", + "${testCase.responseContentType}", + br#"${testCase.validTestUnion}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( + TestUnion::Foo("hello".into()) + ).build()), + expect_event(result.unwrap()) + ); + """, + ) + + unitTest( + "message_with_headers", + """ + let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(Blob::new(&b"test"[..])) + .boolean(true) + .byte(55i8) + .int(100_000i32) + .long(9_000_000_000i64) + .short(16_000i16) + .string("test") + .timestamp(DateTime::from_secs(5)) + .build() + ), + expect_event(result.unwrap()) + ); + """, + ) + + unitTest( + "message_with_header_and_payload", + """ + let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") + .add_header(Header::new("header", HeaderValue::String("header".into()))); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header("header") + .payload(Blob::new(&b"payload"[..])) + .build() + ), + expect_event(result.unwrap()) + ); + """, + ) + + unitTest( + "message_with_no_header_payload_traits", + """ + let message = msg( + "event", + "MessageWithNoHeaderPayloadTraits", + "${testCase.responseContentType}", + br#"${testCase.validMessageWithNoHeaderPayloadTraits}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(5) + .some_string("hello") + .build() + ), + expect_event(result.unwrap()) + ); + """, + ) + + val (someError, kindSuffix) = when (codegenTarget) { + CodegenTarget.CLIENT -> "TestStreamErrorKind::SomeError" to ".kind" + CodegenTarget.SERVER -> "TestStreamError::SomeError" to "" + } + unitTest( + "some_error", + """ + let message = msg( + "exception", + "SomeError", + "${testCase.responseContentType}", + br#"${testCase.validSomeError}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + match expect_error(result.unwrap())$kindSuffix { + $someError(err) => assert_eq!(Some("some error"), err.message()), + kind => panic!("expected SomeError, but got {:?}", kind), + } + """, + ) + + if (codegenTarget == CodegenTarget.CLIENT) { + unitTest( + "generic_error", + """ + let message = msg( + "exception", + "UnmodeledError", + "${testCase.responseContentType}", + br#"${testCase.validUnmodeledError}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + match expect_error(result.unwrap())$kindSuffix { + TestStreamErrorKind::Unhandled(err) => { + let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); + let expected = "message: \"unmodeled error\""; + assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); + } + kind => panic!("expected generic error, but got {:?}", kind), + } + """, + ) + } + + unitTest( + "bad_content_type", + """ + let message = msg( + "event", + "MessageWithBlob", + "wrong-content-type", + br#"${testCase.validTestStruct}"# + ); + let result = ${format(generator)}().unmarshall(&message); + assert!(result.is_err(), "expected error, got: {:?}", result); + assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be")); + """, + ) + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt similarity index 97% rename from codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt rename to codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt index 6a7ca5be25..c147567d71 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/SmithyTypesPubUseGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt @@ -3,11 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.customizations +package software.amazon.smithy.rust.codegen.core.smithy.customizations import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model -import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseTypes import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 50aa275492..8dd6dc5d18 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -24,7 +24,6 @@ val smithyVersion: String by project dependencies { implementation(project(":codegen-core")) - implementation(project(":codegen-client")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index 446260ae96..3d8a57ef6b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -10,10 +10,10 @@ import software.amazon.smithy.build.SmithyBuildPlugin import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.rust.codegen.client.smithy.EventStreamSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index 1da7d944fa..98e875f3cc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -10,12 +10,12 @@ import software.amazon.smithy.build.SmithyBuildPlugin import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.rust.codegen.client.smithy.EventStreamSymbolProvider -import software.amazon.smithy.rust.codegen.client.smithy.StreamingShapeMetadataProvider -import software.amazon.smithy.rust.codegen.client.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt index 2949d605de..90b3550b98 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt @@ -5,11 +5,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.customizations -import software.amazon.smithy.rust.codegen.client.smithy.customizations.AllowLintsGenerator -import software.amazon.smithy.rust.codegen.client.smithy.customizations.CrateVersionGenerator -import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator @@ -29,7 +29,7 @@ class ServerRequiredCustomizations : ServerCodegenDecorator { codegenContext: ServerCodegenContext, baseCustomizations: List, ): List = - baseCustomizations + CrateVersionGenerator() + AllowLintsGenerator() + baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization() override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { // Add rt-tokio feature for `ByteStream::from_path` diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index 86bb4439bc..9d49dfb9fe 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -39,12 +39,20 @@ private fun testServiceShapeFor(model: Model) = fun serverTestSymbolProvider(model: Model, serviceShape: ServiceShape? = null) = serverTestSymbolProviders(model, serviceShape).symbolProvider -fun serverTestSymbolProviders(model: Model, serviceShape: ServiceShape? = null) = +fun serverTestSymbolProviders( + model: Model, + serviceShape: ServiceShape? = null, + settings: ServerRustSettings? = null, +) = ServerSymbolProviders.from( model, serviceShape ?: testServiceShapeFor(model), ServerTestSymbolVisitorConfig, - serverTestRustSettings((serviceShape ?: testServiceShapeFor(model)).id).codegenConfig.publicConstrainedTypes, + ( + settings ?: serverTestRustSettings( + (serviceShape ?: testServiceShapeFor(model)).id, + ) + ).codegenConfig.publicConstrainedTypes, RustCodegenServerPlugin::baseSymbolProvider, ) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt deleted file mode 100644 index 71e425eb38..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt +++ /dev/null @@ -1,407 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols - -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.ErrorTrait -import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.CombinedErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant -import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson -import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion -import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryProtocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Ec2QueryProtocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson -import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml -import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer -import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer -import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace -import software.amazon.smithy.rust.codegen.core.testutil.TestWriterDelegator -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder -import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider -import java.util.stream.Stream -import kotlin.streams.toList - -private fun fillInBaseModel( - protocolName: String, - extraServiceAnnotations: String = "", -): String = """ - namespace test - - use aws.protocols#$protocolName - - union TestUnion { - Foo: String, - Bar: Integer, - } - structure TestStruct { - someString: String, - someInt: Integer, - } - - @error("client") - structure SomeError { - Message: String, - } - - structure MessageWithBlob { @eventPayload data: Blob } - structure MessageWithString { @eventPayload data: String } - structure MessageWithStruct { @eventPayload someStruct: TestStruct } - structure MessageWithUnion { @eventPayload someUnion: TestUnion } - structure MessageWithHeaders { - @eventHeader blob: Blob, - @eventHeader boolean: Boolean, - @eventHeader byte: Byte, - @eventHeader int: Integer, - @eventHeader long: Long, - @eventHeader short: Short, - @eventHeader string: String, - @eventHeader timestamp: Timestamp, - } - structure MessageWithHeaderAndPayload { - @eventHeader header: String, - @eventPayload payload: Blob, - } - structure MessageWithNoHeaderPayloadTraits { - someInt: Integer, - someString: String, - } - - @streaming - union TestStream { - MessageWithBlob: MessageWithBlob, - MessageWithString: MessageWithString, - MessageWithStruct: MessageWithStruct, - MessageWithUnion: MessageWithUnion, - MessageWithHeaders: MessageWithHeaders, - MessageWithHeaderAndPayload: MessageWithHeaderAndPayload, - MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits, - SomeError: SomeError, - } - structure TestStreamInputOutput { @httpPayload @required value: TestStream } - operation TestStreamOp { - input: TestStreamInputOutput, - output: TestStreamInputOutput, - errors: [SomeError], - } - $extraServiceAnnotations - @$protocolName - service TestService { version: "123", operations: [TestStreamOp] } -""" - -object EventStreamTestModels { - private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() - private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() - private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() - private fun awsQuery(): Model = - fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() - private fun ec2Query(): Model = - fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() - - data class TestCase( - val protocolShapeId: String, - val model: Model, - val requestContentType: String, - val responseContentType: String, - val validTestStruct: String, - val validMessageWithNoHeaderPayloadTraits: String, - val validTestUnion: String, - val validSomeError: String, - val validUnmodeledError: String, - val target: CodegenTarget = CodegenTarget.CLIENT, - val protocolBuilder: (CodegenContext) -> Protocol, - ) { - override fun toString(): String = protocolShapeId - } - - private val testCases = listOf( - // - // restJson1 - // - TestCase( - protocolShapeId = "aws.protocols#restJson1", - model = restJson1(), - requestContentType = "application/json", - responseContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, - - // - // restJson1, server mode - // - TestCase( - protocolShapeId = "aws.protocols#restJson1", - model = restJson1(), - requestContentType = "application/json", - responseContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, - - // - // awsJson1_1 - // - TestCase( - protocolShapeId = "aws.protocols#awsJson1_1", - model = awsJson11(), - requestContentType = "application/x-amz-json-1.1", - responseContentType = "application/x-amz-json-1.1", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { AwsJson(it, AwsJsonVersion.Json11) }, - - // - // restXml - // - TestCase( - protocolShapeId = "aws.protocols#restXml", - model = restXml(), - requestContentType = "application/xml", - responseContentType = "application/xml", - validTestStruct = """ - - hello - 5 - - """.trimIndent(), - validMessageWithNoHeaderPayloadTraits = """ - - hello - 5 - - """.trimIndent(), - validTestUnion = "hello", - validSomeError = """ - - - SomeError - SomeError - some error - - - """.trimIndent(), - validUnmodeledError = """ - - - UnmodeledError - UnmodeledError - unmodeled error - - - """.trimIndent(), - ) { RestXml(it) }, - - // - // awsQuery - // - TestCase( - protocolShapeId = "aws.protocols#awsQuery", - model = awsQuery(), - requestContentType = "application/x-www-form-urlencoded", - responseContentType = "text/xml", - validTestStruct = """ - - hello - 5 - - """.trimIndent(), - validMessageWithNoHeaderPayloadTraits = """ - - hello - 5 - - """.trimIndent(), - validTestUnion = "hello", - validSomeError = """ - - - SomeError - SomeError - some error - - - """.trimIndent(), - validUnmodeledError = """ - - - UnmodeledError - UnmodeledError - unmodeled error - - - """.trimIndent(), - ) { AwsQueryProtocol(it) }, - - // - // ec2Query - // - TestCase( - protocolShapeId = "aws.protocols#ec2Query", - model = ec2Query(), - requestContentType = "application/x-www-form-urlencoded", - responseContentType = "text/xml", - validTestStruct = """ - - hello - 5 - - """.trimIndent(), - validMessageWithNoHeaderPayloadTraits = """ - - hello - 5 - - """.trimIndent(), - validTestUnion = "hello", - validSomeError = """ - - - - SomeError - SomeError - some error - - - - """.trimIndent(), - validUnmodeledError = """ - - - - UnmodeledError - UnmodeledError - unmodeled error - - - - """.trimIndent(), - ) { Ec2QueryProtocol(it) }, - ) - // TODO(https://github.com/awslabs/smithy-rs/issues/1442) Server tests - // should be run from the server subproject using the - // `serverTestSymbolProvider()`. - // .flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) } - - class UnmarshallTestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - testCases.map { Arguments.of(it) }.stream() - } - - class MarshallTestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - // Don't include awsQuery or ec2Query for now since marshall support for them is unimplemented - testCases - .filter { testCase -> !testCase.protocolShapeId.contains("Query") } - .map { Arguments.of(it) }.stream() - } -} - -data class TestEventStreamProject( - val model: Model, - val serviceShape: ServiceShape, - val operationShape: OperationShape, - val streamShape: UnionShape, - val symbolProvider: RustSymbolProvider, - val project: TestWriterDelegator, -) - -object EventStreamTestTools { - fun generateTestProject(testCase: EventStreamTestModels.TestCase): TestEventStreamProject { - val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model)) - val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape - val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape - - val symbolProvider = when (testCase.target) { - CodegenTarget.CLIENT -> testSymbolProvider(model) - CodegenTarget.SERVER -> serverTestSymbolProvider(model) - } - val project = TestWorkspace.testProject(symbolProvider) - val operationSymbol = symbolProvider.toSymbol(operationShape) - project.withModule(ErrorsModule) { - val errors = model.shapes() - .filter { shape -> shape.isStructureShape && shape.hasTrait() } - .map { it.asStructureShape().get() } - .toList() - when (testCase.target) { - CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this) - CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this) - } - for (shape in model.shapes().filter { shape -> shape.isStructureShape && shape.hasTrait() }) { - StructureGenerator(model, symbolProvider, this, shape as StructureShape).render(testCase.target) - val builderGen = BuilderGenerator(model, symbolProvider, shape) - builderGen.render(this) - implBlock(shape, symbolProvider) { - builderGen.renderConvenienceMethod(this) - } - } - } - project.withModule(ModelsModule) { - val inputOutput = model.lookup("test#TestStreamInputOutput") - recursivelyGenerateModels(model, symbolProvider, inputOutput, this, testCase.target) - } - project.withModule(RustModule.Output) { - operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) - } - return TestEventStreamProject(model, serviceShape, operationShape, unionShape, symbolProvider, project) - } - - private fun recursivelyGenerateModels( - model: Model, - symbolProvider: RustSymbolProvider, - shape: Shape, - writer: RustWriter, - mode: CodegenTarget, - ) { - for (member in shape.members()) { - val target = model.expectShape(member.target) - if (target is StructureShape || target is UnionShape) { - if (target is StructureShape) { - target.renderWithModelBuilder(model, symbolProvider, writer) - } else if (target is UnionShape) { - UnionGenerator(model, symbolProvider, writer, target, renderUnknownVariant = mode.renderUnknownVariant()).render() - } - recursivelyGenerateModels(model, symbolProvider, target, writer, mode) - } - } - } -} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt new file mode 100644 index 0000000000..1e2cdef43f --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings +import java.util.stream.Stream + +data class TestCase( + val eventStreamTestCase: EventStreamTestModels.TestCase, + val publicConstrainedTypes: Boolean, +) { + override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes" +} + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES + .flatMap { testCase -> + listOf( + TestCase(testCase, publicConstrainedTypes = false), + TestCase(testCase, publicConstrainedTypes = true), + ) + }.map { Arguments.of(it) }.stream() +} + +abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements { + abstract val publicConstrainedTypes: Boolean + + override fun createCodegenContext( + model: Model, + serviceShape: ServiceShape, + protocolShapeId: ShapeId, + codegenTarget: CodegenTarget, + ): ServerCodegenContext = serverTestCodegenContext( + model, serviceShape, + serverTestRustSettings( + codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes), + ), + protocolShapeId, + ) + + override fun renderBuilderForShape( + writer: RustWriter, + codegenContext: ServerCodegenContext, + shape: StructureShape, + ) { + if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { + ServerBuilderGenerator(codegenContext, shape).apply { + render(writer) + writer.implBlock(shape, codegenContext.symbolProvider) { + renderConvenienceMethod(writer) + } + } + } else { + ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape).apply { + render(writer) + writer.implBlock(shape, codegenContext.symbolProvider) { + renderConvenienceMethod(writer) + } + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt new file mode 100644 index 0000000000..860b451ee8 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety +import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject +import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +class ServerEventStreamMarshallerGeneratorTest { + @ParameterizedTest + @ArgumentsSource(TestCasesProvider::class) + fun test(testCase: TestCase) { + EventStreamTestTools.runTestCase( + testCase.eventStreamTestCase, + object : ServerEventStreamBaseRequirements() { + override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes + + override fun renderGenerator( + codegenContext: ServerCodegenContext, + project: TestEventStreamProject, + protocol: Protocol, + ): RuntimeType { + return EventStreamMarshallerGenerator( + project.model, + CodegenTarget.SERVER, + TestRuntimeConfig, + project.symbolProvider, + project.streamShape, + protocol.structuredDataSerializer(project.operationShape), + testCase.eventStreamTestCase.requestContentType, + ).render() + } + }, + CodegenTarget.SERVER, + EventStreamTestVariety.Marshall, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt new file mode 100644 index 0000000000..08a5ef5f58 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety +import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol + +class ServerEventStreamUnmarshallerGeneratorTest { + @ParameterizedTest + @ArgumentsSource(TestCasesProvider::class) + fun test(testCase: TestCase) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1442): Enable tests for `publicConstrainedTypes = false` + // by deleting this if/return + if (!testCase.publicConstrainedTypes) { + return + } + + EventStreamTestTools.runTestCase( + testCase.eventStreamTestCase, + object : ServerEventStreamBaseRequirements() { + override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes + + override fun renderGenerator( + codegenContext: ServerCodegenContext, + project: TestEventStreamProject, + protocol: Protocol, + ): RuntimeType { + fun builderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol(codegenContext) + return EventStreamUnmarshallerGenerator( + protocol, + codegenContext, + project.operationShape, + project.streamShape, + ::builderSymbol, + ).render() + } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class + override fun renderBuilderForShape( + writer: RustWriter, + codegenContext: ServerCodegenContext, + shape: StructureShape, + ) { + BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply { + render(writer) + writer.implBlock(shape, codegenContext.symbolProvider) { + renderConvenienceMethod(writer) + } + } + } + }, + CodegenTarget.SERVER, + EventStreamTestVariety.Unmarshall, + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt deleted file mode 100644 index 5094426d75..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt +++ /dev/null @@ -1,306 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols.parse - -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings -import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestModels -import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestTools - -class EventStreamUnmarshallerGeneratorTest { - @ParameterizedTest - @ArgumentsSource(EventStreamTestModels.UnmarshallTestCasesProvider::class) - fun test(testCase: EventStreamTestModels.TestCase) { - val test = EventStreamTestTools.generateTestProject(testCase) - - val codegenContext = CodegenContext( - test.model, - test.symbolProvider, - test.serviceShape, - ShapeId.from(testCase.protocolShapeId), - testRustSettings(), - target = testCase.target, - ) - val protocol = testCase.protocolBuilder(codegenContext) - fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) - val generator = EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - test.operationShape, - test.streamShape, - ::builderSymbol, - ) - - test.project.lib { - rust( - """ - use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage}; - use aws_smithy_types::{Blob, DateTime}; - use crate::error::*; - use crate::model::*; - - fn msg( - message_type: &'static str, - event_type: &'static str, - content_type: &'static str, - payload: &'static [u8], - ) -> Message { - let message = Message::new(payload) - .add_header(Header::new(":message-type", HeaderValue::String(message_type.into()))) - .add_header(Header::new(":content-type", HeaderValue::String(content_type.into()))); - if message_type == "event" { - message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into()))) - } else { - message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into()))) - } - } - fn expect_event(unmarshalled: UnmarshalledMessage) -> T { - match unmarshalled { - UnmarshalledMessage::Event(event) => event, - _ => panic!("expected event, got: {:?}", unmarshalled), - } - } - fn expect_error(unmarshalled: UnmarshalledMessage) -> E { - match unmarshalled { - UnmarshalledMessage::Error(error) => error, - _ => panic!("expected error, got: {:?}", unmarshalled), - } - } - """, - ) - - unitTest( - name = "message_with_blob", - test = """ - let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!"); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() - ), - expect_event(result.unwrap()) - ); - """, - ) - - if (testCase.target == CodegenTarget.CLIENT) { - unitTest( - "unknown_message", - """ - let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::Unknown, - expect_event(result.unwrap()) - ); - """, - ) - } - - unitTest( - "message_with_string", - """ - let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_struct", - """ - let message = msg( - "event", - "MessageWithStruct", - "${testCase.responseContentType}", - br#"${testCase.validTestStruct}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct( - TestStruct::builder() - .some_string("hello") - .some_int(5) - .build() - ).build()), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_union", - """ - let message = msg( - "event", - "MessageWithUnion", - "${testCase.responseContentType}", - br#"${testCase.validTestUnion}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_headers", - """ - let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_header_and_payload", - """ - let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") - .add_header(Header::new("header", HeaderValue::String("header".into()))); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_no_header_payload_traits", - """ - let message = msg( - "event", - "MessageWithNoHeaderPayloadTraits", - "${testCase.responseContentType}", - br#"${testCase.validMessageWithNoHeaderPayloadTraits}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) - - val (someError, kindSuffix) = when (testCase.target) { - CodegenTarget.CLIENT -> listOf("TestStreamErrorKind::SomeError", ".kind") - CodegenTarget.SERVER -> listOf("TestStreamError::SomeError", "") - } - unitTest( - "some_error", - """ - let message = msg( - "exception", - "SomeError", - "${testCase.responseContentType}", - br#"${testCase.validSomeError}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - match expect_error(result.unwrap())$kindSuffix { - $someError(err) => assert_eq!(Some("some error"), err.message()), - kind => panic!("expected SomeError, but got {:?}", kind), - } - """, - ) - - if (testCase.target == CodegenTarget.CLIENT) { - unitTest( - "generic_error", - """ - let message = msg( - "exception", - "UnmodeledError", - "${testCase.responseContentType}", - br#"${testCase.validUnmodeledError}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - match expect_error(result.unwrap())$kindSuffix { - TestStreamErrorKind::Unhandled(err) => { - let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); - let expected = "message: \"unmodeled error\""; - assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); - } - kind => panic!("expected generic error, but got {:?}", kind), - } - """, - ) - } - - unitTest( - "bad_content_type", - """ - let message = msg( - "event", - "MessageWithBlob", - "wrong-content-type", - br#"${testCase.validTestStruct}"# - ); - let result = ${format(generator.render())}().unmarshall(&message); - assert!(result.is_err(), "expected error, got: {:?}", result); - assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be")); - """, - ) - } - test.project.compileAndTest() - } -} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt deleted file mode 100644 index 213318b027..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize - -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings -import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestModels -import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestTools - -class EventStreamMarshallerGeneratorTest { - @ParameterizedTest - @ArgumentsSource(EventStreamTestModels.MarshallTestCasesProvider::class) - fun test(testCase: EventStreamTestModels.TestCase) { - val test = EventStreamTestTools.generateTestProject(testCase) - - val codegenContext = CodegenContext( - test.model, - test.symbolProvider, - test.serviceShape, - ShapeId.from(testCase.protocolShapeId), - testRustSettings(), - target = testCase.target, - ) - val protocol = testCase.protocolBuilder(codegenContext) - val generator = EventStreamMarshallerGenerator( - test.model, - testCase.target, - TestRuntimeConfig, - test.symbolProvider, - test.streamShape, - protocol.structuredDataSerializer(test.operationShape), - testCase.requestContentType, - ) - - test.project.lib { - val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) - .copy(scope = DependencyScope.Compile) - rustTemplate( - """ - use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage}; - use std::collections::HashMap; - use aws_smithy_types::{Blob, DateTime}; - use crate::error::*; - use crate::model::*; - - use #{validate_body}; - use #{MediaType}; - - fn headers_to_map<'a>(headers: &'a [Header]) -> HashMap { - let mut map = HashMap::new(); - for header in headers { - map.insert(header.name().as_str().to_string(), header.value()); - } - map - } - - fn str_header(value: &'static str) -> HeaderValue { - HeaderValue::String(value.into()) - } - """, - "validate_body" to protocolTestHelpers.toType().resolve("validate_body"), - "MediaType" to protocolTestHelpers.toType().resolve("MediaType"), - ) - - unitTest( - "message_with_blob", - """ - let event = TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_string", - """ - let event = TestStream::MessageWithString( - MessageWithString::builder().data("hello, world!").build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_struct", - """ - let event = TestStream::MessageWithStruct( - MessageWithStruct::builder().some_struct( - TestStruct::builder() - .some_string("hello") - .some_int(5) - .build() - ).build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestStruct.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_union", - """ - let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestUnion.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_headers", - """ - let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b""[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_header_and_payload", - """ - let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b"payload"[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) - .add_header(Header::new("header", HeaderValue::String("header".into()))) - .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_no_header_payload_traits", - """ - let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ); - let result = ${format(generator.render())}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - } - test.project.compileAndTest() - } -}