@@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.MemberShape
14
14
import software.amazon.smithy.model.shapes.OperationShape
15
15
import software.amazon.smithy.model.shapes.Shape
16
16
import software.amazon.smithy.model.shapes.StructureShape
17
+ import software.amazon.smithy.model.traits.EnumTrait
17
18
import software.amazon.smithy.model.traits.HttpTrait
18
19
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
19
20
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
@@ -32,6 +33,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
32
33
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
33
34
import software.amazon.smithy.rust.codegen.core.util.dq
34
35
import software.amazon.smithy.rust.codegen.core.util.expectMember
36
+ import software.amazon.smithy.rust.codegen.core.util.hasTrait
35
37
import software.amazon.smithy.rust.codegen.core.util.inputShape
36
38
37
39
fun HttpTrait.uriFormatString (): String {
@@ -54,7 +56,7 @@ fun SmithyPattern.rustFormatString(prefix: String, separator: String): String {
54
56
* Generates methods to serialize and deserialize requests based on the HTTP trait. Specifically:
55
57
* 1. `fn update_http_request(builder: http::request::Builder) -> Builder`
56
58
*
57
- * This method takes a builder (perhaps pre configured with some headers) from the caller and sets the HTTP
59
+ * This method takes a builder (perhaps pre- configured with some headers) from the caller and sets the HTTP
58
60
* headers & URL based on the HTTP trait implementation.
59
61
*/
60
62
class RequestBindingGenerator (
@@ -71,7 +73,7 @@ class RequestBindingGenerator(
71
73
private val httpBindingGenerator =
72
74
HttpBindingGenerator (protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol)
73
75
private val index = HttpBindingIndex .of(model)
74
- private val Encoder = CargoDependency .smithyTypes(runtimeConfig).toType().member(" primitive::Encoder" )
76
+ private val encoder = CargoDependency .smithyTypes(runtimeConfig).toType().member(" primitive::Encoder" )
75
77
76
78
private val codegenScope = arrayOf(
77
79
" BuildError" to runtimeConfig.operationBuildError(),
@@ -207,24 +209,58 @@ class RequestBindingGenerator(
207
209
val memberShape = param.member
208
210
val memberSymbol = symbolProvider.toSymbol(memberShape)
209
211
val memberName = symbolProvider.toMemberName(memberShape)
210
- val outerTarget = model.expectShape(memberShape.target)
211
- ifSet(outerTarget, memberSymbol, " &_input.$memberName " ) { field ->
212
- // if `param` is a list, generate another level of iteration
213
- listForEach(outerTarget, field) { innerField, targetId ->
214
- val target = model.expectShape(targetId)
215
- rust(
216
- " query.push_kv(${param.locationName.dq()} , ${
217
- paramFmtFun(writer, target, memberShape, innerField)
218
- } );" ,
212
+ val target = model.expectShape(memberShape.target)
213
+
214
+ if (memberShape.isRequired) {
215
+ val codegenScope = arrayOf(
216
+ " BuildError" to OperationBuildError (runtimeConfig).missingField(
217
+ memberName,
218
+ " cannot be empty or unset" ,
219
+ ),
220
+ )
221
+ val derefName = safeName(" inner" )
222
+ rust(" let $derefName = &_input.$memberName ;" )
223
+ if (memberSymbol.isOptional()) {
224
+ rustTemplate(
225
+ " let $derefName = $derefName .as_ref().ok_or_else(|| #{BuildError:W})?;" ,
226
+ * codegenScope,
219
227
)
220
228
}
229
+
230
+ // Strings that aren't enums must be checked to see if they're empty
231
+ if (target.isStringShape && ! target.hasTrait<EnumTrait >()) {
232
+ rustBlock(" if $derefName .is_empty()" ) {
233
+ rustTemplate(" return Err(#{BuildError:W});" , * codegenScope)
234
+ }
235
+ }
236
+
237
+ paramList(target, derefName, param, writer, memberShape)
238
+ } else {
239
+ ifSet(target, memberSymbol, " &_input.$memberName " ) { field ->
240
+ // if `param` is a list, generate another level of iteration
241
+ paramList(target, field, param, writer, memberShape)
242
+ }
221
243
}
222
244
}
223
245
writer.rust(" Ok(())" )
224
246
}
225
247
return true
226
248
}
227
249
250
+ private fun RustWriter.paramList (
251
+ outerTarget : Shape ,
252
+ field : String ,
253
+ param : HttpBinding ,
254
+ writer : RustWriter ,
255
+ memberShape : MemberShape ,
256
+ ) {
257
+ listForEach(outerTarget, field) { innerField, targetId ->
258
+ val target = model.expectShape(targetId)
259
+ val value = paramFmtFun(writer, target, memberShape, innerField)
260
+ rust(""" query.push_kv("${param.locationName} ", $value );""" )
261
+ }
262
+ }
263
+
228
264
/* *
229
265
* Format [member] when used as a queryParam
230
266
*/
@@ -234,18 +270,21 @@ class RequestBindingGenerator(
234
270
val func = writer.format(RuntimeType .QueryFormat (runtimeConfig, " fmt_string" ))
235
271
" &$func (&$targetName )"
236
272
}
273
+
237
274
target.isTimestampShape -> {
238
275
val timestampFormat =
239
276
index.determineTimestampFormat(member, HttpBinding .Location .QUERY , protocol.defaultTimestampFormat)
240
277
val timestampFormatType = RuntimeType .TimestampFormat (runtimeConfig, timestampFormat)
241
278
val func = writer.format(RuntimeType .QueryFormat (runtimeConfig, " fmt_timestamp" ))
242
279
" &$func ($targetName , ${writer.format(timestampFormatType)} )?"
243
280
}
281
+
244
282
target.isListShape || target.isMemberShape -> {
245
283
throw IllegalArgumentException (" lists should be handled at a higher level" )
246
284
}
285
+
247
286
else -> {
248
- " ${writer.format(Encoder )} ::from(${autoDeref(targetName)} ).encode()"
287
+ " ${writer.format(encoder )} ::from(${autoDeref(targetName)} ).encode()"
249
288
}
250
289
}
251
290
}
@@ -272,17 +311,19 @@ class RequestBindingGenerator(
272
311
}
273
312
rust(" let $outputVar = $func ($input , #T);" , encodingStrategy)
274
313
}
314
+
275
315
target.isTimestampShape -> {
276
316
val timestampFormat =
277
317
index.determineTimestampFormat(member, HttpBinding .Location .LABEL , protocol.defaultTimestampFormat)
278
318
val timestampFormatType = RuntimeType .TimestampFormat (runtimeConfig, timestampFormat)
279
319
val func = format(RuntimeType .LabelFormat (runtimeConfig, " fmt_timestamp" ))
280
320
rust(" let $outputVar = $func ($input , ${format(timestampFormatType)} )?;" )
281
321
}
322
+
282
323
else -> {
283
324
rust(
284
325
" let mut ${outputVar} _encoder = #T::from(${autoDeref(input)} ); let $outputVar = ${outputVar} _encoder.encode();" ,
285
- Encoder ,
326
+ encoder ,
286
327
)
287
328
}
288
329
}
0 commit comments