Skip to content

Commit 4c30f00

Browse files
authored
Refactor determining server error type when deserializing an @httpPayload (#3752)
Determining the error type when deserializing an `@httpPayload` is a protocol-specific concern, and as such should not live in `ServerHttpBoundProtocolGenerator`, which should remain protocol-agnostic. This commits makes that determination part of the `ServerProtocol` interface. As a drive-by improvement, the companion object in `ServerHttpBoundProtocolGenerator` has also been removed, since its members have been unused for a long time. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent dc66ae4 commit 4c30f00

File tree

3 files changed

+75
-56
lines changed

3 files changed

+75
-56
lines changed

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
package software.amazon.smithy.rust.codegen.server.smithy.generators.http
77

8-
import software.amazon.smithy.codegen.core.Symbol
98
import software.amazon.smithy.model.shapes.OperationShape
109
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
1110
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
@@ -20,12 +19,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi
2019
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
2120
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
2221
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
23-
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
2422
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
23+
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
2524
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
2625

2726
class ServerRequestBindingGenerator(
28-
protocol: Protocol,
27+
val protocol: ServerProtocol,
2928
codegenContext: ServerCodegenContext,
3029
operationShape: OperationShape,
3130
additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
@@ -50,12 +49,11 @@ class ServerRequestBindingGenerator(
5049

5150
fun generateDeserializePayloadFn(
5251
binding: HttpBindingDescriptor,
53-
errorSymbol: Symbol,
5452
structuredHandler: RustWriter.(String) -> Unit,
5553
): RuntimeType =
5654
httpBindingGenerator.generateDeserializePayloadFn(
5755
binding,
58-
errorSymbol,
56+
protocol.deserializePayloadErrorType(binding).toSymbol(),
5957
structuredHandler,
6058
HttpMessageType.REQUEST,
6159
)

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
88
import software.amazon.smithy.model.shapes.MemberShape
99
import software.amazon.smithy.model.shapes.OperationShape
1010
import software.amazon.smithy.model.shapes.Shape
11+
import software.amazon.smithy.model.shapes.StringShape
1112
import software.amazon.smithy.model.shapes.StructureShape
1213
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
1314
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@@ -17,7 +18,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
1718
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
1819
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
1920
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
21+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
2022
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
23+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
2124
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
2225
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
2326
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
@@ -70,8 +73,8 @@ interface ServerProtocol : Protocol {
7073
fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType
7174

7275
/**
73-
* In some protocols, such as restJson1,
74-
* when there is no modeled body input, content type must not be set and the body must be empty.
76+
* In some protocols, such as `restJson1` and `rpcv2Cbor`,
77+
* when there is no modeled body input, `content-type` must not be set and the body must be empty.
7578
* Returns a boolean indicating whether to perform this check.
7679
*/
7780
fun serverContentTypeCheckNoModeledInput(): Boolean = false
@@ -90,6 +93,19 @@ interface ServerProtocol : Protocol {
9093
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
9194
ServerCargoDependency.smithyHttpServer(runtimeConfig)
9295
.toType().resolve("protocol::$protocolModulePath::runtime_error::RuntimeError")
96+
97+
/**
98+
* The function that deserializes a payload-bound shape takes as input a byte slab and returns a `Result` holding
99+
* the deserialized shape if successful. What error type should we use in case of failure?
100+
*
101+
* The shape could be payload-bound either because of the `@httpPayload` trait, or because it's part of an event
102+
* stream.
103+
*
104+
* Note that despite the trait (https://smithy.io/2.0/spec/http-bindings.html#httppayload-trait) being able to
105+
* target any structure member shape, AWS Protocols only support binding the following shape types to the payload
106+
* (and Smithy does indeed enforce this at model build-time): string, blob, structure, union, and document
107+
*/
108+
fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType
93109
}
94110

95111
fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse {
@@ -185,6 +201,18 @@ class ServerAwsJsonProtocol(
185201
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
186202
ServerCargoDependency.smithyHttpServer(runtimeConfig)
187203
.toType().resolve("protocol::aws_json::runtime_error::RuntimeError")
204+
205+
/*
206+
* Note that despite the AWS JSON 1.x protocols not supporting the `@httpPayload` trait, event streams are bound
207+
* to the payload.
208+
*/
209+
override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
210+
deserializePayloadErrorType(
211+
codegenContext,
212+
binding,
213+
requestRejection(runtimeConfig),
214+
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
215+
)
188216
}
189217

190218
private fun restRouterType(runtimeConfig: RuntimeConfig) =
@@ -227,6 +255,14 @@ class ServerRestJsonProtocol(
227255
override fun serverRouterRuntimeConstructor() = "new_rest_json_router"
228256

229257
override fun serverContentTypeCheckNoModeledInput() = true
258+
259+
override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
260+
deserializePayloadErrorType(
261+
codegenContext,
262+
binding,
263+
requestRejection(runtimeConfig),
264+
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
265+
)
230266
}
231267

232268
class ServerRestXmlProtocol(
@@ -252,6 +288,32 @@ class ServerRestXmlProtocol(
252288
override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"
253289

254290
override fun serverContentTypeCheckNoModeledInput() = true
291+
292+
override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
293+
deserializePayloadErrorType(
294+
codegenContext,
295+
binding,
296+
requestRejection(runtimeConfig),
297+
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"),
298+
)
299+
}
300+
301+
/** Just a common function to keep things DRY. **/
302+
fun deserializePayloadErrorType(
303+
codegenContext: CodegenContext,
304+
binding: HttpBindingDescriptor,
305+
requestRejection: RuntimeType,
306+
protocolSerializationFormatError: RuntimeType,
307+
): RuntimeType {
308+
check(binding.location == HttpLocation.PAYLOAD)
309+
310+
if (codegenContext.model.expectShape(binding.member.target) is StringShape) {
311+
// The only way deserializing a string can fail is if the HTTP body does not contain valid UTF-8.
312+
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3750): we're returning an incorrect `RequestRejection` variant here.
313+
return requestRejection
314+
}
315+
316+
return protocolSerializationFormatError
255317
}
256318

257319
/**

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt

Lines changed: 8 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
package software.amazon.smithy.rust.codegen.server.smithy.protocols
77

8-
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
9-
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
10-
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
11-
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
128
import software.amazon.smithy.codegen.core.Symbol
139
import software.amazon.smithy.model.knowledge.HttpBindingIndex
1410
import software.amazon.smithy.model.node.ExpectationNotMetException
@@ -20,7 +16,6 @@ import software.amazon.smithy.model.shapes.MemberShape
2016
import software.amazon.smithy.model.shapes.NumberShape
2117
import software.amazon.smithy.model.shapes.OperationShape
2218
import software.amazon.smithy.model.shapes.Shape
23-
import software.amazon.smithy.model.shapes.StringShape
2419
import software.amazon.smithy.model.shapes.StructureShape
2520
import software.amazon.smithy.model.traits.ErrorTrait
2621
import software.amazon.smithy.model.traits.HttpErrorTrait
@@ -124,13 +119,7 @@ class ServerHttpBoundProtocolGenerator(
124119
) : ServerProtocolGenerator(
125120
protocol,
126121
ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations),
127-
) {
128-
// Define suffixes for operation input / output / error wrappers
129-
companion object {
130-
const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper"
131-
const val OPERATION_OUTPUT_WRAPPER_SUFFIX = "OperationOutputWrapper"
132-
}
133-
}
122+
)
134123

135124
class ServerHttpBoundProtocolPayloadGenerator(
136125
codegenContext: CodegenContext,
@@ -697,8 +686,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
697686
inputShape: StructureShape,
698687
bindings: List<HttpBindingDescriptor>,
699688
) {
700-
val httpBindingGenerator =
701-
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
702689
val structuredDataParser = protocol.structuredDataParser()
703690
Attribute.AllowUnusedMut.render(this)
704691
rust(
@@ -740,7 +727,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
740727
for (binding in bindings) {
741728
val member = binding.member
742729
val parsedValue =
743-
serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
730+
serverRenderBindingParser(binding, operationShape, httpBindingGenerator(operationShape), structuredDataParser)
744731
val valueToSet =
745732
if (symbolProvider.toSymbol(binding.member).isOptional()) {
746733
"Some(value)"
@@ -801,13 +788,8 @@ class ServerHttpBoundProtocolTraitImplGenerator(
801788
val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
802789
rust("#T($body)", structuredDataParser.payloadParser(binding.member))
803790
}
804-
val errorSymbol = getDeserializePayloadErrorSymbol(binding)
805791
val deserializer =
806-
httpBindingGenerator.generateDeserializePayloadFn(
807-
binding,
808-
errorSymbol,
809-
structuredHandler = structureShapeHandler,
810-
)
792+
httpBindingGenerator.generateDeserializePayloadFn(binding, structuredHandler = structureShapeHandler)
811793
return writable {
812794
if (binding.member.isStreaming(model)) {
813795
rustTemplate(
@@ -1196,9 +1178,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
11961178
binding: HttpBindingDescriptor,
11971179
operationShape: OperationShape,
11981180
) {
1199-
val httpBindingGenerator =
1200-
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
1201-
val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
1181+
val deserializer = httpBindingGenerator(operationShape).generateDeserializeHeaderFn(binding)
12021182
writer.rustTemplate(
12031183
"""
12041184
#{deserializer}(&headers)?
@@ -1215,8 +1195,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
12151195
) {
12161196
check(binding.location == HttpLocation.PREFIX_HEADERS)
12171197

1218-
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
1219-
val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding)
1198+
val deserializer = httpBindingGenerator(operationShape).generateDeserializePrefixHeadersFn(binding)
12201199
writer.rustTemplate(
12211200
"""
12221201
#{deserializer}(&headers)?
@@ -1300,33 +1279,13 @@ class ServerHttpBoundProtocolTraitImplGenerator(
13001279
}
13011280
}
13021281

1303-
/**
1304-
* Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the
1305-
* shape targeted by the `httpPayload` trait.
1306-
*/
1307-
private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol {
1308-
check(binding.location == HttpLocation.PAYLOAD)
1309-
1310-
if (model.expectShape(binding.member.target) is StringShape) {
1311-
return protocol.requestRejection(runtimeConfig).toSymbol()
1312-
}
1313-
return when (codegenContext.protocol) {
1314-
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
1315-
RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol()
1316-
}
1317-
RestXmlTrait.ID -> {
1318-
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol()
1319-
}
1320-
else -> {
1321-
TODO("Protocol ${codegenContext.protocol} not supported yet")
1322-
}
1323-
}
1324-
}
1325-
13261282
private fun streamingBodyTraitBounds(operationShape: OperationShape) =
13271283
if (operationShape.inputShape(model).hasStreamingMember(model)) {
13281284
"\n B: Into<#{SmithyTypes}::byte_stream::ByteStream>,"
13291285
} else {
13301286
""
13311287
}
1288+
1289+
private fun httpBindingGenerator(operationShape: OperationShape) =
1290+
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
13321291
}

0 commit comments

Comments
 (0)