From 38c03ab70812c6510cf5a8499763ae893f54ed60 Mon Sep 17 00:00:00 2001 From: Jake Park Date: Mon, 7 Jul 2025 11:06:35 +0900 Subject: [PATCH] Add `enums_other_variant` option --- graphql_client_cli/src/generate.rs | 7 +- graphql_client_cli/src/main.rs | 10 ++- graphql_client_codegen/src/codegen/enums.rs | 26 +++++- graphql_client_codegen/src/codegen_options.rs | 13 +++ graphql_client_codegen/src/tests/mod.rs | 88 +++++++++++++++++++ graphql_query_derive/src/attributes.rs | 68 ++++++++++++++ graphql_query_derive/src/lib.rs | 4 +- 7 files changed, 208 insertions(+), 8 deletions(-) diff --git a/graphql_client_cli/src/generate.rs b/graphql_client_cli/src/generate.rs index 1a36d0cf..0768c434 100644 --- a/graphql_client_cli/src/generate.rs +++ b/graphql_client_cli/src/generate.rs @@ -22,6 +22,7 @@ pub(crate) struct CliCodegenParams { pub output_directory: Option, pub custom_scalars_module: Option, pub fragments_other_variant: bool, + pub enums_other_variant: bool, pub external_enums: Option>, pub custom_variable_types: Option, pub custom_response_type: Option, @@ -42,6 +43,7 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { selected_operation, custom_scalars_module, fragments_other_variant, + enums_other_variant, external_enums, custom_variable_types, custom_response_type, @@ -66,6 +68,7 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { }); options.set_fragments_other_variant(fragments_other_variant); + options.set_enums_other_variant(enums_other_variant); if let Some(selected_operation) = selected_operation { options.set_operation_name(selected_operation); @@ -93,11 +96,11 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { options.set_custom_scalars_module(custom_scalars_module); } - + if let Some(custom_variable_types) = custom_variable_types { options.set_custom_variable_types(custom_variable_types.split(",").map(String::from).collect()); } - + if let Some(custom_response_type) = custom_response_type { options.set_custom_response_type(custom_response_type); } diff --git a/graphql_client_cli/src/main.rs b/graphql_client_cli/src/main.rs index 6721953b..7e8a71a5 100644 --- a/graphql_client_cli/src/main.rs +++ b/graphql_client_cli/src/main.rs @@ -91,6 +91,10 @@ enum Cli { /// --fragments-other-variant #[clap(long = "fragments-other-variant")] fragments_other_variant: bool, + /// A flag indicating if the enums should have an "Other" variant for unknown values + /// --enums-other-variant + #[clap(long = "enums-other-variant")] + enums_other_variant: bool, /// List of externally defined enum types. Type names must match those used in the schema exactly #[clap(long = "external-enums", num_args(0..), action(clap::ArgAction::Append))] external_enums: Option>, @@ -139,8 +143,9 @@ fn main() -> CliResult<()> { selected_operation, custom_scalars_module, fragments_other_variant, - external_enums, - custom_variable_types, + enums_other_variant, + external_enums, + custom_variable_types, custom_response_type, } => generate::generate_code(generate::CliCodegenParams { query_path, @@ -154,6 +159,7 @@ fn main() -> CliResult<()> { output_directory, custom_scalars_module, fragments_other_variant, + enums_other_variant, external_enums, custom_variable_types, custom_response_type, diff --git a/graphql_client_codegen/src/codegen/enums.rs b/graphql_client_codegen/src/codegen/enums.rs index adc59a89..261aa2d2 100644 --- a/graphql_client_codegen/src/codegen/enums.rs +++ b/graphql_client_codegen/src/codegen/enums.rs @@ -60,18 +60,38 @@ pub(super) fn generate_enum_definitions<'a, 'schema: 'a>( let name = name_ident; + let other_variant = if *options.enums_other_variant() { + Some(quote!(Other(String),)) + } else { + None + }; + + let other_serialize = if *options.enums_other_variant() { + Some(quote!(#name::Other(ref s) => &s,)) + } else { + None + }; + + let other_deserialize = if *options.enums_other_variant() { + Some(quote!(_ => Ok(#name::Other(s)),)) + } else { + // If no Other variant, we need to handle unknown values + // Return an error for unknown variants when Other is not enabled + Some(quote!(_ => Err(#serde::de::Error::unknown_variant(&s, &[#(#variant_str),*])),)) + }; + quote! { #derives pub enum #name { #(#variant_names,)* - Other(String), + #other_variant } impl #serde::Serialize for #name { fn serialize(&self, ser: S) -> Result { ser.serialize_str(match *self { #(#constructors => #variant_str,)* - #name::Other(ref s) => &s, + #other_serialize }) } } @@ -82,7 +102,7 @@ pub(super) fn generate_enum_definitions<'a, 'schema: 'a>( match s.as_str() { #(#variant_str => Ok(#constructors),)* - _ => Ok(#name::Other(s)), + #other_deserialize } } } diff --git a/graphql_client_codegen/src/codegen_options.rs b/graphql_client_codegen/src/codegen_options.rs index 7b3d8d73..a5d2dc0f 100644 --- a/graphql_client_codegen/src/codegen_options.rs +++ b/graphql_client_codegen/src/codegen_options.rs @@ -45,6 +45,8 @@ pub struct GraphQLClientCodegenOptions { extern_enums: Vec, /// Flag to trigger generation of Other variant for fragments Enum fragments_other_variant: bool, + /// Flag to trigger generation of Other variant for enums + enums_other_variant: bool, /// Skip Serialization of None values. skip_serializing_none: bool, /// Path to the serde crate. @@ -73,6 +75,7 @@ impl GraphQLClientCodegenOptions { custom_scalars_module: Default::default(), extern_enums: Default::default(), fragments_other_variant: Default::default(), + enums_other_variant: true, // Default to true for backward compatibility skip_serializing_none: Default::default(), serde_path: syn::parse_quote!(::serde), custom_variable_types: Default::default(), @@ -247,6 +250,16 @@ impl GraphQLClientCodegenOptions { &self.fragments_other_variant } + /// Set the graphql client codegen options's enums other variant. + pub fn set_enums_other_variant(&mut self, enums_other_variant: bool) { + self.enums_other_variant = enums_other_variant; + } + + /// Get a reference to the graphql client codegen options's enums other variant. + pub fn enums_other_variant(&self) -> &bool { + &self.enums_other_variant + } + /// Set the graphql client codegen option's skip none value. pub fn set_skip_serializing_none(&mut self, skip_serializing_none: bool) { self.skip_serializing_none = skip_serializing_none diff --git a/graphql_client_codegen/src/tests/mod.rs b/graphql_client_codegen/src/tests/mod.rs index aaed3e5d..d2d91a87 100644 --- a/graphql_client_codegen/src/tests/mod.rs +++ b/graphql_client_codegen/src/tests/mod.rs @@ -154,3 +154,91 @@ fn skip_serializing_none_should_generate_serde_skip_serializing() { } }; } + +#[test] +fn enums_other_variant_true_should_generate_other_variant() { + let query_string = KEYWORDS_QUERY; + let schema_path = build_schema_path(KEYWORDS_SCHEMA_PATH); + + let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); + options.set_enums_other_variant(true); + + let generated_tokens = + generate_module_token_stream_from_string(query_string, &schema_path, options) + .expect("Generate keywords module"); + + let generated_code = generated_tokens.to_string(); + + let r: syn::parse::Result = syn::parse2(generated_tokens); + match r { + Ok(_) => { + // Should contain Other(String) variant in the enum (with spaces in generated code) + assert!(generated_code.contains("Other (String)")); + // Should have Other variant in serialize match + assert!(generated_code.contains("AnEnum :: Other (ref s) => & s")); + // Should have Other variant in deserialize match + assert!(generated_code.contains("_ => Ok (AnEnum :: Other (s))")); + } + Err(e) => { + panic!("Error: {}\n Generated content: {}\n", e, &generated_code); + } + }; +} + +#[test] +fn enums_other_variant_false_should_not_generate_other_variant() { + let query_string = KEYWORDS_QUERY; + let schema_path = build_schema_path(KEYWORDS_SCHEMA_PATH); + + let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); + options.set_enums_other_variant(false); + + let generated_tokens = + generate_module_token_stream_from_string(query_string, &schema_path, options) + .expect("Generate keywords module"); + + let generated_code = generated_tokens.to_string(); + + let r: syn::parse::Result = syn::parse2(generated_tokens); + match r { + Ok(_) => { + // Should NOT contain Other(String) variant in the enum (with spaces in generated code) + assert!(!generated_code.contains("Other (String)")); + // Should NOT have Other variant in serialize match + assert!(!generated_code.contains("AnEnum :: Other (ref s) => & s")); + // Should have error handling for unknown variants instead of Other + assert!(generated_code.contains("unknown_variant")); + } + Err(e) => { + panic!("Error: {}\n Generated content: {}\n", e, &generated_code); + } + }; +} + +#[test] +fn enums_other_variant_default_should_be_true_for_backward_compatibility() { + let query_string = KEYWORDS_QUERY; + let schema_path = build_schema_path(KEYWORDS_SCHEMA_PATH); + + // Use default options without explicitly setting enums_other_variant + let options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); + + let generated_tokens = + generate_module_token_stream_from_string(query_string, &schema_path, options) + .expect("Generate keywords module"); + + let generated_code = generated_tokens.to_string(); + + let r: syn::parse::Result = syn::parse2(generated_tokens); + match r { + Ok(_) => { + // By default, should contain Other(String) variant for backward compatibility (with spaces in generated code) + assert!(generated_code.contains("Other (String)")); + assert!(generated_code.contains("AnEnum :: Other (ref s) => & s")); + assert!(generated_code.contains("_ => Ok (AnEnum :: Other (s))")); + } + Err(e) => { + panic!("Error: {}\n Generated content: {}\n", e, &generated_code); + } + }; +} diff --git a/graphql_query_derive/src/attributes.rs b/graphql_query_derive/src/attributes.rs index 535914fb..bfb3a383 100644 --- a/graphql_query_derive/src/attributes.rs +++ b/graphql_query_derive/src/attributes.rs @@ -126,6 +126,13 @@ pub fn extract_fragments_other_variant(ast: &syn::DeriveInput) -> bool { .unwrap_or(false) } +pub fn extract_enums_other_variant(ast: &syn::DeriveInput) -> bool { + extract_attr(ast, "enums_other_variant") + .ok() + .and_then(|s| FromStr::from_str(s.as_str()).ok()) + .unwrap_or(true) // Default to true for backward compatibility +} + pub fn extract_skip_serializing_none(ast: &syn::DeriveInput) -> bool { ident_exists(ast, "skip_serializing_none").is_ok() } @@ -247,6 +254,67 @@ mod test { assert!(!extract_fragments_other_variant(&parsed)); } + #[test] + fn test_enums_other_variant_set_to_true() { + let input = " + #[derive(GraphQLQuery)] + #[graphql( + schema_path = \"x\", + query_path = \"x\", + enums_other_variant = \"true\", + )] + struct MyQuery; + "; + let parsed = syn::parse_str(input).unwrap(); + assert!(extract_enums_other_variant(&parsed)); + } + + #[test] + fn test_enums_other_variant_set_to_false() { + let input = " + #[derive(GraphQLQuery)] + #[graphql( + schema_path = \"x\", + query_path = \"x\", + enums_other_variant = \"false\", + )] + struct MyQuery; + "; + let parsed = syn::parse_str(input).unwrap(); + assert!(!extract_enums_other_variant(&parsed)); + } + + #[test] + fn test_enums_other_variant_set_to_invalid() { + let input = " + #[derive(GraphQLQuery)] + #[graphql( + schema_path = \"x\", + query_path = \"x\", + enums_other_variant = \"invalid\", + )] + struct MyQuery; + "; + let parsed = syn::parse_str(input).unwrap(); + // Should default to true when invalid value is provided + assert!(extract_enums_other_variant(&parsed)); + } + + #[test] + fn test_enums_other_variant_unset() { + let input = " + #[derive(GraphQLQuery)] + #[graphql( + schema_path = \"x\", + query_path = \"x\", + )] + struct MyQuery; + "; + let parsed = syn::parse_str(input).unwrap(); + // Should default to true for backward compatibility + assert!(extract_enums_other_variant(&parsed)); + } + #[test] fn test_skip_serializing_none_set() { let input = r#" diff --git a/graphql_query_derive/src/lib.rs b/graphql_query_derive/src/lib.rs index c6a7eca3..be10c1c5 100644 --- a/graphql_query_derive/src/lib.rs +++ b/graphql_query_derive/src/lib.rs @@ -64,6 +64,7 @@ fn build_graphql_client_derive_options( let custom_scalars_module = attributes::extract_attr(input, "custom_scalars_module").ok(); let extern_enums = attributes::extract_attr_list(input, "extern_enums").ok(); let fragments_other_variant: bool = attributes::extract_fragments_other_variant(input); + let enums_other_variant: bool = attributes::extract_enums_other_variant(input); let skip_serializing_none: bool = attributes::extract_skip_serializing_none(input); let custom_variable_types = attributes::extract_attr_list(input, "variable_types").ok(); let custom_response_type = attributes::extract_attr(input, "response_type").ok(); @@ -71,6 +72,7 @@ fn build_graphql_client_derive_options( let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Derive); options.set_query_file(query_path); options.set_fragments_other_variant(fragments_other_variant); + options.set_enums_other_variant(enums_other_variant); options.set_skip_serializing_none(skip_serializing_none); if let Some(variables_derives) = variables_derives { @@ -106,7 +108,7 @@ fn build_graphql_client_derive_options( if let Some(custom_variable_types) = custom_variable_types { options.set_custom_variable_types(custom_variable_types); } - + if let Some(custom_response_type) = custom_response_type { options.set_custom_response_type(custom_response_type); }