Skip to content

Commit a1ec8b8

Browse files
authored
Allow injecting methods with generic type parameters in the config object (#3274)
This is a follow-up to #3111. Currently, the injected methods are limited to taking in concrete types. This PR allows for these methods to take in generic type parameters as well. ```rust impl<L, H, M> SimpleServiceConfigBuilder<L, H, M> { pub fn aws_auth<C>(config: C) { ... } } ``` ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent 0d6cf72 commit a1ec8b8

File tree

2 files changed

+166
-71
lines changed

2 files changed

+166
-71
lines changed

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt

Lines changed: 91 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package software.amazon.smithy.rust.codegen.server.smithy.generators
77

8+
import software.amazon.smithy.codegen.core.CodegenException
89
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
910
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
1011
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
@@ -13,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join
1314
import software.amazon.smithy.rust.codegen.core.rustlang.rust
1415
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
1516
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
17+
import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters
1618
import software.amazon.smithy.rust.codegen.core.rustlang.writable
1719
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
1820
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
@@ -33,7 +35,7 @@ data class ConfigMethod(
3335
val docs: String,
3436
/** The parameters of the method. **/
3537
val params: List<Binding>,
36-
/** In case the method is fallible, the error type it returns. **/
38+
/** In case the method is fallible, the concrete error type it returns. **/
3739
val errorType: RuntimeType?,
3840
/** The code block inside the method. **/
3941
val initializer: Initializer,
@@ -104,15 +106,45 @@ data class Initializer(
104106
* }
105107
*
106108
* has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a
107-
* `u64` variable.
109+
* `u64` variable. Both are bindings that use concrete types. Types can also be generic:
110+
*
111+
* ```rust
112+
* fn foo<T>(bar: T) { }
108113
* ```
109114
*/
110-
data class Binding(
111-
/** The name of the variable. */
112-
val name: String,
113-
/** The type of the variable. */
114-
val ty: RuntimeType,
115-
)
115+
sealed class Binding {
116+
data class Generic(
117+
/** The name of the variable. The name of the type parameter will be the PascalCased variable name. */
118+
val name: String,
119+
/** The type of the variable. */
120+
val ty: RuntimeType,
121+
/**
122+
* The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec<T>` with `T` being a
123+
* generic type parameter, then `genericTys` should be a singleton set containing `"T"`.
124+
* You can't use `L`, `H`, or `M` as the names to refer to any generic types.
125+
* */
126+
val genericTys: Set<String>,
127+
) : Binding()
128+
129+
data class Concrete(
130+
/** The name of the variable. */
131+
val name: String,
132+
/** The type of the variable. */
133+
val ty: RuntimeType,
134+
) : Binding()
135+
136+
fun name() =
137+
when (this) {
138+
is Concrete -> this.name
139+
is Generic -> this.name
140+
}
141+
142+
fun ty() =
143+
when (this) {
144+
is Concrete -> this.ty
145+
is Generic -> this.ty
146+
}
147+
}
116148

117149
class ServiceConfigGenerator(
118150
codegenContext: ServerCodegenContext,
@@ -271,10 +303,10 @@ class ServiceConfigGenerator(
271303
writable {
272304
rustTemplate(
273305
"""
274-
if !self.${it.requiredBuilderFlagName()} {
275-
return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()});
276-
}
277-
""",
306+
if !self.${it.requiredBuilderFlagName()} {
307+
return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()});
308+
}
309+
""",
278310
*codegenScope,
279311
)
280312
}
@@ -303,19 +335,19 @@ class ServiceConfigGenerator(
303335
writable {
304336
rust(
305337
"""
306-
##[error("service is not fully configured; invoke `${it.name}` on the config builder")]
307-
${it.requiredErrorVariant()},
308-
""",
338+
##[error("service is not fully configured; invoke `${it.name}` on the config builder")]
339+
${it.requiredErrorVariant()},
340+
""",
309341
)
310342
}
311343
}
312344
rustTemplate(
313345
"""
314-
##[derive(Debug, #{ThisError}::Error)]
315-
pub enum ${serviceName}ConfigError {
316-
#{Variants:W}
317-
}
318-
""",
346+
##[derive(Debug, #{ThisError}::Error)]
347+
pub enum ${serviceName}ConfigError {
348+
#{Variants:W}
349+
}
350+
""",
319351
"ThisError" to ServerCargoDependency.ThisError.toType(),
320352
"Variants" to variants.join("\n"),
321353
)
@@ -327,8 +359,20 @@ class ServiceConfigGenerator(
327359
writable {
328360
val paramBindings =
329361
it.params.map { binding ->
330-
writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) }
362+
writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) }
331363
}.join("\n")
364+
val genericBindings = it.params.filterIsInstance<Binding.Generic>()
365+
val lhmBindings =
366+
genericBindings.filter {
367+
it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M")
368+
}
369+
if (lhmBindings.isNotEmpty()) {
370+
throw CodegenException(
371+
"Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings",
372+
)
373+
}
374+
val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet()
375+
val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray())
332376

333377
// This produces a nested type like: "S<B, S<A, T>>", where
334378
// - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack
@@ -345,7 +389,7 @@ class ServiceConfigGenerator(
345389
rustTemplate(
346390
"#{StackType}<#{Ty}, #{Acc:W}>",
347391
"StackType" to stackType,
348-
"Ty" to next.ty,
392+
"Ty" to next.ty(),
349393
"Acc" to acc,
350394
)
351395
}
@@ -362,12 +406,12 @@ class ServiceConfigGenerator(
362406
writable {
363407
rustTemplate(
364408
"""
365-
${serviceName}ConfigBuilder<
366-
#{LayersReturnTy:W},
367-
#{HttpPluginsReturnTy:W},
368-
#{ModelPluginsReturnTy:W},
369-
>
370-
""",
409+
${serviceName}ConfigBuilder<
410+
#{LayersReturnTy:W},
411+
#{HttpPluginsReturnTy:W},
412+
#{ModelPluginsReturnTy:W},
413+
>
414+
""",
371415
"LayersReturnTy" to layersReturnTy,
372416
"HttpPluginsReturnTy" to httpPluginsReturnTy,
373417
"ModelPluginsReturnTy" to modelPluginsReturnTy,
@@ -391,14 +435,15 @@ class ServiceConfigGenerator(
391435
docs(it.docs)
392436
rustBlockTemplate(
393437
"""
394-
pub fn ${it.name}(
395-
##[allow(unused_mut)]
396-
mut self,
397-
#{ParamBindings:W}
398-
) -> #{ReturnTy:W}
399-
""",
438+
pub fn ${it.name}#{ParamBindingsGenericsWritable}(
439+
##[allow(unused_mut)]
440+
mut self,
441+
#{ParamBindings:W}
442+
) -> #{ReturnTy:W}
443+
""",
400444
"ReturnTy" to returnTy,
401445
"ParamBindings" to paramBindings,
446+
"ParamBindingsGenericsWritable" to paramBindingsGenericsWritable,
402447
) {
403448
rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code)
404449

@@ -412,9 +457,9 @@ class ServiceConfigGenerator(
412457
conditionalBlock("Ok(", ")", conditional = it.errorType != null) {
413458
val registrations =
414459
(
415-
it.initializer.layerBindings.map { ".layer(${it.name})" } +
416-
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } +
417-
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" }
460+
it.initializer.layerBindings.map { ".layer(${it.name()})" } +
461+
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } +
462+
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" }
418463
).joinToString("")
419464
rust("self$registrations")
420465
}
@@ -437,9 +482,9 @@ class ServiceConfigGenerator(
437482
writable {
438483
rustBlockTemplate(
439484
"""
440-
/// Build the configuration.
441-
pub fn build(self) -> #{BuilderBuildReturnTy:W}
442-
""",
485+
/// Build the configuration.
486+
pub fn build(self) -> #{BuilderBuildReturnTy:W}
487+
""",
443488
"BuilderBuildReturnTy" to builderBuildReturnType(),
444489
) {
445490
rustTemplate(
@@ -450,12 +495,12 @@ class ServiceConfigGenerator(
450495
conditionalBlock("Ok(", ")", isBuilderFallible) {
451496
rust(
452497
"""
453-
super::${serviceName}Config {
454-
layers: self.layers,
455-
http_plugins: self.http_plugins,
456-
model_plugins: self.model_plugins,
457-
}
458-
""",
498+
super::${serviceName}Config {
499+
layers: self.layers,
500+
http_plugins: self.http_plugins,
501+
model_plugins: self.model_plugins,
502+
}
503+
""",
459504
)
460505
}
461506
}

0 commit comments

Comments
 (0)