@@ -14,6 +14,9 @@ import software.amazon.smithy.model.shapes.Shape
14
14
import software.amazon.smithy.model.shapes.ShapeId
15
15
import software.amazon.smithy.model.shapes.StructureShape
16
16
import software.amazon.smithy.model.transform.ModelTransformer
17
+ import software.amazon.smithy.rulesengine.traits.EndpointTestCase
18
+ import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput
19
+ import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait
17
20
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
18
21
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
19
22
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
@@ -31,7 +34,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
31
34
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
32
35
import software.amazon.smithy.rust.codegen.core.smithy.traits.AllowInvalidXmlRoot
33
36
import software.amazon.smithy.rust.codegen.core.util.letIf
34
- import software.amazon.smithy.rustsdk.endpoints.stripEndpointTrait
35
37
import software.amazon.smithy.rustsdk.getBuiltIn
36
38
import software.amazon.smithy.rustsdk.toWritable
37
39
import java.util.logging.Logger
@@ -63,23 +65,39 @@ class S3Decorator : ClientCodegenDecorator {
63
65
logger.info(" Adding AllowInvalidXmlRoot trait to $it " )
64
66
(it as StructureShape ).toBuilder().addTrait(AllowInvalidXmlRoot ()).build()
65
67
}
66
- }.let (StripBucketFromHttpPath ()::transform).let (stripEndpointTrait(" RequestRoute" ))
68
+ }
69
+ // the model has the bucket in the path
70
+ .let (StripBucketFromHttpPath ()::transform)
71
+ // the tests in EP2 are incorrect and are missing request route
72
+ .let (
73
+ FilterEndpointTests (
74
+ operationInputFilter = { input ->
75
+ when (input.operationName) {
76
+ // it's impossible to express HostPrefix behavior in the current EP2 rules schema :-/
77
+ // A handwritten test was written to cover this behavior
78
+ " WriteGetObjectResponse" -> null
79
+ else -> input
80
+ }
81
+ },
82
+ )::transform,
83
+ )
67
84
68
85
override fun endpointCustomizations (codegenContext : ClientCodegenContext ): List <EndpointCustomization > {
69
- return listOf (object : EndpointCustomization {
70
- override fun setBuiltInOnServiceConfig (name : String , value : Node , configBuilderRef : String ): Writable ? {
71
- if (! name.startsWith(" AWS::S3" )) {
72
- return null
73
- }
74
- val builtIn = codegenContext.getBuiltIn(name) ? : return null
75
- return writable {
76
- rustTemplate(
77
- " let $configBuilderRef = $configBuilderRef .${builtIn.name.rustName()} (#{value});" ,
78
- " value" to value.toWritable(),
79
- )
86
+ return listOf (
87
+ object : EndpointCustomization {
88
+ override fun setBuiltInOnServiceConfig (name : String , value : Node , configBuilderRef : String ): Writable ? {
89
+ if (! name.startsWith(" AWS::S3" )) {
90
+ return null
91
+ }
92
+ val builtIn = codegenContext.getBuiltIn(name) ? : return null
93
+ return writable {
94
+ rustTemplate(
95
+ " let $configBuilderRef = $configBuilderRef .${builtIn.name.rustName()} (#{value});" ,
96
+ " value" to value.toWritable(),
97
+ )
98
+ }
80
99
}
81
- }
82
- },
100
+ },
83
101
)
84
102
}
85
103
@@ -88,6 +106,28 @@ class S3Decorator : ClientCodegenDecorator {
88
106
}
89
107
}
90
108
109
+ class FilterEndpointTests (
110
+ private val testFilter : (EndpointTestCase ) -> EndpointTestCase ? = { a -> a },
111
+ private val operationInputFilter : (EndpointTestOperationInput ) -> EndpointTestOperationInput ? = { a -> a },
112
+ ) {
113
+ fun updateEndpointTests (endpointTests : List <EndpointTestCase >): List <EndpointTestCase > {
114
+ val filteredTests = endpointTests.mapNotNull { test -> testFilter(test) }
115
+ return filteredTests.map { test ->
116
+ val operationInputs = test.operationInputs
117
+ test.toBuilder().operationInputs(operationInputs.mapNotNull { operationInputFilter(it) }).build()
118
+ }
119
+ }
120
+
121
+ fun transform (model : Model ) = ModelTransformer .create().mapTraits(model) { _, trait ->
122
+ when (trait) {
123
+ is EndpointTestsTrait -> EndpointTestsTrait .builder().testCases(updateEndpointTests(trait.testCases))
124
+ .version(trait.version).build()
125
+
126
+ else -> trait
127
+ }
128
+ }
129
+ }
130
+
91
131
class S3ProtocolOverride (codegenContext : CodegenContext ) : RestXml(codegenContext) {
92
132
private val runtimeConfig = codegenContext.runtimeConfig
93
133
private val errorScope = arrayOf(
0 commit comments