Skip to content

Commit eac52eb

Browse files
authored
Fix event stream :content-type for struct messages (#3603)
Event stream operations with struct shaped messages were using the wrong `:content-type` message header value, which I think wasn't caught before since the supported AWS S3/Transcribe event stream operations don't serialize struct messages. This PR fixes the message content type serialization. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent 5461e4f commit eac52eb

File tree

9 files changed

+79
-18
lines changed

9 files changed

+79
-18
lines changed

CHANGELOG.next.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,9 @@ message = "SDK crates now set the `rust-version` property in their Cargo.toml fi
5858
references = ["smithy-rs#3601"]
5959
meta = { "breaking" = false, "tada" = true, "bug" = false }
6060
author = "jdisanti"
61+
62+
[[smithy-rs]]
63+
message = "Fix event stream `:content-type` message headers for struct messages. Note: this was the `:content-type` header on individual event message frames that was incorrect, not the HTTP `content-type` header for the initial request."
64+
references = ["smithy-rs#3603"]
65+
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" }
66+
author = "jdisanti"

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ class AwsJsonHttpBindingResolver(
8383
"application/x-amz-json-${awsJsonVersion.value}"
8484

8585
override fun responseContentType(operationShape: OperationShape): String = requestContentType(operationShape)
86+
87+
override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
88+
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/json")
8689
}
8790

8891
/**

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package software.amazon.smithy.rust.codegen.core.smithy.protocols
77

8+
import software.amazon.smithy.model.shapes.MemberShape
89
import software.amazon.smithy.model.shapes.OperationShape
910
import software.amazon.smithy.model.shapes.ToShapeId
1011
import software.amazon.smithy.model.traits.HttpTrait
@@ -38,6 +39,9 @@ class AwsQueryCompatibleHttpBindingResolver(
3839

3940
override fun responseContentType(operationShape: OperationShape): String =
4041
awsJsonHttpBindingResolver.requestContentType(operationShape)
42+
43+
override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
44+
awsJsonHttpBindingResolver.eventStreamMessageContentType(memberShape)
4145
}
4246

4347
class AwsQueryCompatible(

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols
88
import software.amazon.smithy.model.Model
99
import software.amazon.smithy.model.knowledge.HttpBinding
1010
import software.amazon.smithy.model.knowledge.HttpBindingIndex
11+
import software.amazon.smithy.model.shapes.BlobShape
1112
import software.amazon.smithy.model.shapes.MemberShape
1213
import software.amazon.smithy.model.shapes.OperationShape
14+
import software.amazon.smithy.model.shapes.StringShape
1315
import software.amazon.smithy.model.shapes.ToShapeId
1416
import software.amazon.smithy.model.traits.HttpTrait
1517
import software.amazon.smithy.model.traits.TimestampFormatTrait
@@ -98,6 +100,11 @@ interface HttpBindingResolver {
98100
* Determines the response content type for given [operationShape].
99101
*/
100102
fun responseContentType(operationShape: OperationShape): String?
103+
104+
/**
105+
* Determines the value of the event stream `:content-type` header based on union member
106+
*/
107+
fun eventStreamMessageContentType(memberShape: MemberShape): String?
101108
}
102109

103110
/**
@@ -108,20 +115,38 @@ data class ProtocolContentTypes(
108115
val requestDocument: String? = null,
109116
/** Response content type override for when the shape is a Document */
110117
val responseDocument: String? = null,
111-
/** EventStream content type */
118+
/** EventStream content type initial request/response content-type */
112119
val eventStreamContentType: String? = null,
120+
/** EventStream content type for struct message shapes (for `:content-type`) */
121+
val eventStreamMessageContentType: String? = null,
113122
) {
114123
companion object {
115124
/** Create an instance of [ProtocolContentTypes] where all content types are the same */
116-
fun consistent(type: String) = ProtocolContentTypes(type, type, type)
125+
fun consistent(type: String) = ProtocolContentTypes(type, type, type, type)
126+
127+
/**
128+
* Returns the event stream message `:content-type` for the given event stream union member shape.
129+
*
130+
* The `protocolContentType` is the content-type to use for non-string/non-blob shapes.
131+
*/
132+
fun eventStreamMemberContentType(
133+
model: Model,
134+
memberShape: MemberShape,
135+
protocolContentType: String?,
136+
): String? =
137+
when (model.expectShape(memberShape.target)) {
138+
is StringShape -> "text/plain"
139+
is BlobShape -> "application/octet-stream"
140+
else -> protocolContentType
141+
}
117142
}
118143
}
119144

120145
/**
121146
* An [HttpBindingResolver] that relies on the HttpTrait data in the Smithy models.
122147
*/
123148
open class HttpTraitHttpBindingResolver(
124-
model: Model,
149+
private val model: Model,
125150
private val contentTypes: ProtocolContentTypes,
126151
) : HttpBindingResolver {
127152
private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model)
@@ -158,6 +183,9 @@ open class HttpTraitHttpBindingResolver(
158183
contentTypes.eventStreamContentType,
159184
).orNull()
160185

186+
override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
187+
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, contentTypes.eventStreamMessageContentType)
188+
161189
// Sort the members after extracting them from the map to have a consistent order
162190
private fun mappedBindings(bindings: Map<String, HttpBinding>): List<HttpBindingDescriptor> =
163191
bindings.values.map(::HttpBindingDescriptor).sortedBy { it.memberName }
@@ -172,6 +200,7 @@ open class StaticHttpBindingResolver(
172200
private val httpTrait: HttpTrait,
173201
private val requestContentType: String,
174202
private val responseContentType: String,
203+
private val eventStreamMessageContentType: String? = null,
175204
) : HttpBindingResolver {
176205
private fun bindings(shape: ToShapeId?) =
177206
shape?.let { model.expectShape(it.toShapeId()) }?.members()
@@ -192,4 +221,7 @@ open class StaticHttpBindingResolver(
192221
override fun requestContentType(operationShape: OperationShape): String = requestContentType
193222

194223
override fun responseContentType(operationShape: OperationShape): String = responseContentType
224+
225+
override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
226+
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, eventStreamMessageContentType)
195227
}

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ class HttpBoundProtocolPayloadGenerator(
197197
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
198198
val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
199199
writer.serializeViaEventStream(
200-
operationShape,
201200
payloadMember,
202201
serializerGenerator,
203202
shapeName,
@@ -206,7 +205,6 @@ class HttpBoundProtocolPayloadGenerator(
206205
} else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
207206
val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
208207
writer.serializeViaEventStream(
209-
operationShape,
210208
payloadMember,
211209
serializerGenerator,
212210
"output",
@@ -239,7 +237,6 @@ class HttpBoundProtocolPayloadGenerator(
239237
}
240238

241239
private fun RustWriter.serializeViaEventStream(
242-
operationShape: OperationShape,
243240
memberShape: MemberShape,
244241
serializerGenerator: StructuredDataSerializerGenerator,
245242
outerName: String,
@@ -248,11 +245,10 @@ class HttpBoundProtocolPayloadGenerator(
248245
val memberName = symbolProvider.toMemberName(memberShape)
249246
val unionShape = model.expectShape(memberShape.target, UnionShape::class.java)
250247

251-
val contentType =
252-
when (target) {
253-
CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape)
254-
CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape)
255-
}
248+
val payloadContentType =
249+
httpBindingResolver.eventStreamMessageContentType(memberShape)
250+
?: throw CodegenException("event streams must set a content type")
251+
256252
val errorMarshallerConstructorFn =
257253
EventStreamErrorMarshallerGenerator(
258254
model,
@@ -261,7 +257,7 @@ class HttpBoundProtocolPayloadGenerator(
261257
symbolProvider,
262258
unionShape,
263259
serializerGenerator,
264-
contentType ?: throw CodegenException("event streams must set a content type"),
260+
payloadContentType,
265261
).render()
266262
val marshallerConstructorFn =
267263
EventStreamMarshallerGenerator(
@@ -271,7 +267,7 @@ class HttpBoundProtocolPayloadGenerator(
271267
symbolProvider,
272268
unionShape,
273269
serializerGenerator,
274-
contentType,
270+
payloadContentType,
275271
).render()
276272

277273
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,15 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
7474
)
7575

7676
override val httpBindingResolver: HttpBindingResolver =
77-
RestJsonHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/json", "application/json", "application/vnd.amazon.eventstream"))
77+
RestJsonHttpBindingResolver(
78+
codegenContext.model,
79+
ProtocolContentTypes(
80+
requestDocument = "application/json",
81+
responseDocument = "application/json",
82+
eventStreamContentType = "application/vnd.amazon.eventstream",
83+
eventStreamMessageContentType = "application/json",
84+
),
85+
)
7886

7987
override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS
8088

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,15 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol {
3636
}
3737

3838
override val httpBindingResolver: HttpBindingResolver =
39-
HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/xml", "application/xml", "application/vnd.amazon.eventstream"))
39+
HttpTraitHttpBindingResolver(
40+
codegenContext.model,
41+
ProtocolContentTypes(
42+
requestDocument = "application/xml",
43+
responseDocument = "application/xml",
44+
eventStreamContentType = "application/vnd.amazon.eventstream",
45+
eventStreamMessageContentType = "application/xml",
46+
),
47+
)
4048

4149
override val defaultTimestampFormat: TimestampFormatTrait.Format =
4250
TimestampFormatTrait.Format.DATE_TIME

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ object EventStreamMarshallTestCases {
111111
let headers = headers_to_map(message.headers());
112112
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
113113
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
114-
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
114+
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
115115
116116
validate_body(
117117
message.payload(),
@@ -146,7 +146,7 @@ object EventStreamMarshallTestCases {
146146
let headers = headers_to_map(message.headers());
147147
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
148148
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
149-
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
149+
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
150150
151151
validate_body(
152152
message.payload(),
@@ -236,7 +236,7 @@ object EventStreamMarshallTestCases {
236236
let headers = headers_to_map(message.headers());
237237
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
238238
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
239-
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
239+
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
240240
241241
validate_body(
242242
message.payload(),

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ object EventStreamTestModels {
109109
val mediaType: String,
110110
val requestContentType: String,
111111
val responseContentType: String,
112+
val eventStreamMessageContentType: String,
112113
val validTestStruct: String,
113114
val validMessageWithNoHeaderPayloadTraits: String,
114115
val validTestUnion: String,
@@ -130,6 +131,7 @@ object EventStreamTestModels {
130131
mediaType = "application/json",
131132
requestContentType = "application/vnd.amazon.eventstream",
132133
responseContentType = "application/json",
134+
eventStreamMessageContentType = "application/json",
133135
validTestStruct = """{"someString":"hello","someInt":5}""",
134136
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
135137
validTestUnion = """{"Foo":"hello"}""",
@@ -145,6 +147,7 @@ object EventStreamTestModels {
145147
mediaType = "application/x-amz-json-1.1",
146148
requestContentType = "application/x-amz-json-1.1",
147149
responseContentType = "application/x-amz-json-1.1",
150+
eventStreamMessageContentType = "application/json",
148151
validTestStruct = """{"someString":"hello","someInt":5}""",
149152
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
150153
validTestUnion = """{"Foo":"hello"}""",
@@ -160,6 +163,7 @@ object EventStreamTestModels {
160163
mediaType = "application/xml",
161164
requestContentType = "application/vnd.amazon.eventstream",
162165
responseContentType = "application/xml",
166+
eventStreamMessageContentType = "application/xml",
163167
validTestStruct =
164168
"""
165169
<TestStruct>

0 commit comments

Comments
 (0)