@@ -9,7 +9,7 @@ use syn::{
9
9
parse:: { Parse , ParseStream } ,
10
10
parse_macro_input,
11
11
spanned:: Spanned ,
12
- Expr , Ident , Item , ItemEnum , Token , Variant ,
12
+ Expr , Ident , DeriveInput , Data , Token , Variant ,
13
13
} ;
14
14
15
15
struct Flag < ' a > {
@@ -58,14 +58,8 @@ pub fn bitflags_internal(
58
58
input : proc_macro:: TokenStream ,
59
59
) -> proc_macro:: TokenStream {
60
60
let Parameters { default } = parse_macro_input ! ( attr as Parameters ) ;
61
- let mut ast = parse_macro_input ! ( input as Item ) ;
62
- let output = match ast {
63
- Item :: Enum ( ref mut item_enum) => gen_enumflags ( item_enum, default) ,
64
- _ => Err ( syn:: Error :: new_spanned (
65
- & ast,
66
- "#[bitflags] requires an enum" ,
67
- ) ) ,
68
- } ;
61
+ let mut ast = parse_macro_input ! ( input as DeriveInput ) ;
62
+ let output = gen_enumflags ( & mut ast, default) ;
69
63
70
64
output
71
65
. unwrap_or_else ( |err| {
@@ -247,17 +241,29 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
247
241
}
248
242
}
249
243
250
- fn gen_enumflags ( ast : & mut ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
244
+ fn gen_enumflags ( ast : & mut DeriveInput , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
251
245
let ident = & ast. ident ;
252
246
253
247
let span = Span :: call_site ( ) ;
254
248
249
+ let ast_variants = match & mut ast. data {
250
+ Data :: Enum ( ref mut data) => & mut data. variants ,
251
+ Data :: Struct ( data) => {
252
+ return Err ( syn:: Error :: new_spanned ( & data. struct_token ,
253
+ "expected enum for #[bitflags], found struct" ) ) ;
254
+ }
255
+ Data :: Union ( data) => {
256
+ return Err ( syn:: Error :: new_spanned ( & data. union_token ,
257
+ "expected enum for #[bitflags], found union" ) ) ;
258
+ }
259
+ } ;
260
+
255
261
let repr = extract_repr ( & ast. attrs ) ?
256
262
. ok_or_else ( || syn:: Error :: new_spanned ( ident,
257
263
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
258
264
let bits = type_bits ( & repr) ?;
259
265
260
- let mut variants = collect_flags ( ast . variants . iter_mut ( ) ) ?;
266
+ let mut variants = collect_flags ( ast_variants . iter_mut ( ) ) ?;
261
267
let deferred = variants
262
268
. iter ( )
263
269
. flat_map ( |variant| check_flag ( ident, variant, bits) . transpose ( ) )
@@ -273,7 +279,12 @@ fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream,
273
279
}
274
280
275
281
let std = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
276
- let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
282
+ let ast_variants = match & ast. data {
283
+ Data :: Enum ( ref data) => & data. variants ,
284
+ _ => unreachable ! ( ) ,
285
+ } ;
286
+
287
+ let variant_names = ast_variants. iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
277
288
278
289
Ok ( quote_spanned ! {
279
290
span =>
0 commit comments