@@ -12,7 +12,8 @@ use crate::{FromLitStr, MetaList};
12
12
#[ derive( Copy , Clone , Debug ) ]
13
13
pub enum TraceDeriveKind {
14
14
NullTrace ,
15
- Regular
15
+ Regular ,
16
+ Deserialize
16
17
}
17
18
18
19
trait PossiblyIgnoredParam {
@@ -134,7 +135,7 @@ pub struct TraceTypeParam {
134
135
#[ darling( skip) ]
135
136
ignore : bool ,
136
137
#[ darling( skip) ]
137
- collector_id : bool
138
+ collector_id : bool ,
138
139
}
139
140
impl TraceTypeParam {
140
141
fn normalize ( & mut self ) -> Result < ( ) , Error > {
@@ -188,6 +189,9 @@ struct TraceField {
188
189
/// to be traced.
189
190
#[ darling( default ) ]
190
191
unsafe_skip_trace : bool ,
192
+
193
+ #[ darling( forward_attrs( serde) ) ]
194
+ attrs : Vec < syn:: Attribute >
191
195
}
192
196
impl TraceField {
193
197
fn expand_trace ( & self , idx : usize , access : & FieldAccess , immutable : bool ) -> TokenStream {
@@ -211,9 +215,12 @@ impl TraceField {
211
215
}
212
216
213
217
#[ derive( Debug , FromVariant ) ]
218
+ #[ darling( attributes( zerogc) ) ]
214
219
struct TraceVariant {
215
220
ident : Ident ,
216
- fields : darling:: ast:: Fields < TraceField >
221
+ fields : darling:: ast:: Fields < TraceField > ,
222
+ #[ darling( forward_attrs( serde) ) ]
223
+ attrs : Vec < syn:: Attribute >
217
224
}
218
225
impl TraceVariant {
219
226
fn fields ( & self ) -> impl Iterator < Item =& ' _ TraceField > + ' _ {
@@ -267,27 +274,26 @@ pub struct TraceDeriveInput {
267
274
is_copy : bool ,
268
275
/// If the type should implement `TraceImmutable` in addition to `Trace
269
276
#[ darling( default , rename = "immutable" ) ]
270
- wants_immutable_trace : bool
277
+ wants_immutable_trace : bool ,
278
+ #[ darling( forward_attrs( serde) ) ]
279
+ attrs : Vec < syn:: Attribute >
271
280
}
272
281
impl TraceDeriveInput {
273
- pub fn determine_field_types ( & self , include_ignored : bool ) -> HashSet < Type > {
282
+ fn all_fields ( & self ) -> Vec < & TraceField > {
274
283
match self . data {
275
284
Data :: Enum ( ref variants) => {
276
- variants. iter ( )
277
- . flat_map ( |var| var. fields ( ) )
278
- . filter ( |f| !f. unsafe_skip_trace || include_ignored)
279
- . map ( |fd| & fd. ty )
280
- . cloned ( )
281
- . collect ( )
282
- }
283
- Data :: Struct ( ref s) => {
284
- s. fields . iter ( )
285
- . filter ( |f| !f. unsafe_skip_trace || include_ignored)
286
- . map ( |f| & f. ty ) . cloned ( )
287
- . collect ( )
288
- }
285
+ variants. iter ( ) . flat_map ( |var| var. fields ( ) ) . collect ( )
286
+ } ,
287
+ Data :: Struct ( ref fields) => fields. iter ( ) . collect ( )
289
288
}
290
289
}
290
+ pub fn determine_field_types ( & self , include_ignored : bool ) -> HashSet < Type > {
291
+ self . all_fields ( ) . iter ( )
292
+ . filter ( |f| !f. unsafe_skip_trace || include_ignored)
293
+ . map ( |fd| & fd. ty )
294
+ . cloned ( )
295
+ . collect ( )
296
+ }
291
297
pub fn normalize ( & mut self , kind : TraceDeriveKind ) -> Result < ( ) , Error > {
292
298
if * self . nop_trace {
293
299
crate :: emit_warning ( "#[zerogc(nop_trace)] is deprecated (use #[derive(NullTrace)] instead)" , self . nop_trace . span ( ) )
@@ -330,6 +336,17 @@ impl TraceDeriveInput {
330
336
fn gc_lifetime ( & self ) -> Option < & ' _ Lifetime > {
331
337
self . generics . gc_lifetime . as_ref ( )
332
338
}
339
+ fn generics_with_gc_lifetime ( & self , lt : Lifetime ) -> ( syn:: Lifetime , Generics ) {
340
+ let mut generics = self . generics . original . clone ( ) ;
341
+ let gc_lifetime: syn:: Lifetime = match self . gc_lifetime ( ) {
342
+ Some ( lt) => lt. clone ( ) ,
343
+ None => {
344
+ generics. params . push ( GenericParam :: Lifetime ( LifetimeDef :: new ( lt. clone ( ) ) ) ) ;
345
+ lt
346
+ }
347
+ } ;
348
+ ( gc_lifetime, generics)
349
+ }
333
350
/// Expand a `GcSafe` for a specific combination of `Id` & 'gc
334
351
///
335
352
/// Implicitly modifies the specified generics
@@ -351,7 +368,8 @@ impl TraceDeriveInput {
351
368
} ,
352
369
TraceDeriveKind :: NullTrace => {
353
370
quote ! ( zerogc:: NullTrace )
354
- }
371
+ } ,
372
+ TraceDeriveKind :: Deserialize => unreachable ! ( )
355
373
} ;
356
374
for tp in self . generics . regular_type_params ( ) {
357
375
let tp = & tp. ident ;
@@ -363,12 +381,14 @@ impl TraceDeriveInput {
363
381
generics. make_where_clause ( ) . predicates . push (
364
382
parse_quote ! ( #tp: #requirement)
365
383
)
366
- }
384
+ } ,
385
+ TraceDeriveKind :: Deserialize => unreachable ! ( )
367
386
}
368
387
}
369
388
let assertion: Ident = match kind {
370
389
TraceDeriveKind :: NullTrace => parse_quote ! ( verify_null_trace) ,
371
- TraceDeriveKind :: Regular => parse_quote ! ( assert_gc_safe)
390
+ TraceDeriveKind :: Regular => parse_quote ! ( assert_gc_safe) ,
391
+ TraceDeriveKind :: Deserialize => unreachable ! ( )
372
392
} ;
373
393
let ty_generics = self . generics . original . split_for_impl ( ) . 1 ;
374
394
let ( impl_generics, _, where_clause) = generics. split_for_impl ( ) ;
@@ -383,14 +403,7 @@ impl TraceDeriveInput {
383
403
} )
384
404
}
385
405
fn expand_gcsafe ( & self , kind : TraceDeriveKind ) -> Result < TokenStream , Error > {
386
- let mut generics = self . generics . original . clone ( ) ;
387
- let gc_lifetime: syn:: Lifetime = match self . gc_lifetime ( ) {
388
- Some ( lt) => lt. clone ( ) ,
389
- None => {
390
- generics. params . push ( parse_quote ! ( ' gc) ) ;
391
- parse_quote ! ( ' gc)
392
- }
393
- } ;
406
+ let ( gc_lifetime, mut generics) = self . generics_with_gc_lifetime ( parse_quote ! ( ' gc) ) ;
394
407
match kind {
395
408
TraceDeriveKind :: NullTrace => {
396
409
// Verify we don't have any explicit collector id
@@ -410,9 +423,108 @@ impl TraceDeriveInput {
410
423
self . expand_gcsafe_sepcific ( kind, initial, id, gc_lt)
411
424
}
412
425
)
413
- }
426
+ } ,
427
+ TraceDeriveKind :: Deserialize => unreachable ! ( )
414
428
}
415
429
}
430
+
431
+ fn expand_deserialize ( & self ) -> Result < TokenStream , Error > {
432
+ if !crate :: DESERIALIZE_ENABLED {
433
+ return Err ( Error :: custom ( "The `zerogc/serde1` feature is disabled (please enable it)" ) ) ;
434
+ }
435
+ let ( gc_lifetime, generics) = self . generics_with_gc_lifetime ( parse_quote ! ( ' gc) ) ;
436
+ self . expand_for_each_regular_id (
437
+ generics, TraceDeriveKind :: Deserialize , gc_lifetime,
438
+ & mut |kind, initial, id, gc_lt| {
439
+ assert ! ( matches!( kind, TraceDeriveKind :: Deserialize ) ) ;
440
+ let id_is_generic = self . generics . original . type_params ( )
441
+ . any ( |param| id. is_ident ( & param. ident ) ) ;
442
+ let mut generics = initial. unwrap ( ) ;
443
+ generics. params . push ( parse_quote ! ( ' deserialize) ) ;
444
+ let requirement = quote ! ( for <' deser2> zerogc:: serde:: GcDeserialize :: <#gc_lt, ' deser2, #id>) ;
445
+ for target in self . generics . regular_type_params ( ) {
446
+ let target = & target. ident ;
447
+ generics. make_where_clause ( ) . predicates . push ( parse_quote ! ( #target: #requirement) ) ;
448
+ }
449
+ let ty_generics = self . generics . original . split_for_impl ( ) . 1 ;
450
+ let ( impl_generics, _, where_clause) = generics. split_for_impl ( ) ;
451
+ let target_type = & self . ident ;
452
+ let forward_attrs = & self . attrs ;
453
+ let deserialize_field = |f : & TraceField | {
454
+ let named = f. ident . as_ref ( ) . map ( |name| quote ! ( #name: ) ) ;
455
+ let ty = & f. ty ;
456
+ let forwarded_attrs = & f. attrs ;
457
+ let bound = format ! (
458
+ "{}: for<'deserialize> zerogc::serde::GcDeserialize<{}, 'deserialize, {}>" , ty. to_token_stream( ) ,
459
+ gc_lt. to_token_stream( ) , id. to_token_stream( )
460
+ ) ;
461
+ quote ! {
462
+ #( #forwarded_attrs) *
463
+ # [ serde( deserialize_with = "deserialize_hack" , bound( deserialize = #bound) ) ]
464
+ #named #ty
465
+ }
466
+ } ;
467
+ let handle_fields = |fields : & darling:: ast:: Fields < TraceField > | {
468
+ let handled_fields = fields. fields . iter ( ) . map ( deserialize_field) ;
469
+ match fields. style {
470
+ Style :: Tuple => {
471
+ quote ! { ( #( #handled_fields) , * ) }
472
+ }
473
+ Style :: Struct => {
474
+ quote ! ( { #( #handled_fields) , * } )
475
+ }
476
+ Style :: Unit => quote ! ( )
477
+ }
478
+ } ;
479
+ let original_generics = & self . generics . original ;
480
+ let inner = match self . data {
481
+ Data :: Enum ( ref variants) => {
482
+ let variants = variants. iter ( ) . map ( |v| {
483
+ let forward_attrs = & v. attrs ;
484
+ let name = & v. ident ;
485
+ let inner = handle_fields ( & v. fields ) ;
486
+ quote ! {
487
+ #( #forward_attrs) *
488
+ #name #inner
489
+ }
490
+ } ) ;
491
+ quote ! ( enum HackRemoteDeserialize #original_generics { #( #variants) , * } )
492
+ }
493
+ Data :: Struct ( ref f) => {
494
+ let fields = handle_fields ( f) ;
495
+ quote ! ( struct HackRemoteDeserialize #original_generics # fields)
496
+ }
497
+ } ;
498
+ let remote_name = target_type. to_token_stream ( ) . to_string ( ) ;
499
+ let id_decl = if id_is_generic {
500
+ Some ( quote ! ( #id: zerogc:: CollectorId , ) )
501
+ } else { None } ;
502
+ Ok ( quote ! {
503
+ impl #impl_generics zerogc:: serde:: GcDeserialize <#gc_lt, ' deserialize, #id> for #target_type #ty_generics #where_clause {
504
+ fn deserialize_gc<D : serde:: Deserializer <' deserialize>>( ctx: & #gc_lt <<#id as zerogc:: CollectorId >:: System as zerogc:: GcSystem >:: Context , deserializer: D ) -> Result <Self , D :: Error > {
505
+ use serde:: Deserializer ;
506
+ let _guard = unsafe { zerogc:: serde:: hack:: set_context( ctx) } ;
507
+ unsafe {
508
+ debug_assert_eq!( _guard. get_unchecked( ) as * const _, ctx as * const _) ;
509
+ }
510
+ /// Hack function to deserialize via `serde::hack`, with the appropriate `Id` type
511
+ ///
512
+ /// Needed because the actual function is unsafe
513
+ #[ track_caller]
514
+ fn deserialize_hack<' gc, ' de, #id_decl D : serde:: de:: Deserializer <' de>, T : zerogc:: serde:: GcDeserialize <#gc_lt, ' de, #id>>( deser: D ) -> Result <T , D :: Error > {
515
+ unsafe { zerogc:: serde:: hack:: unchecked_deserialize_hack:: <' gc, ' de, D , #id, T >( deser) }
516
+ }
517
+ # [ derive( serde:: Deserialize ) ]
518
+ # [ serde( remote = #remote_name) ]
519
+ #( #forward_attrs) *
520
+ #inner ;
521
+ HackRemoteDeserialize :: deserialize( deserializer)
522
+ }
523
+ }
524
+ } )
525
+ }
526
+ )
527
+ }
416
528
fn expand_for_each_regular_id (
417
529
& self , generics : Generics ,
418
530
kind : TraceDeriveKind ,
@@ -669,6 +781,11 @@ impl TraceDeriveInput {
669
781
} )
670
782
}
671
783
fn expand_trusted_drop ( & self , kind : TraceDeriveKind ) -> TokenStream {
784
+ let mut generics = self . generics . original . clone ( ) ;
785
+ for param in self . generics . regular_type_params ( ) {
786
+ let name = & param. ident ;
787
+ generics. make_where_clause ( ) . predicates . push ( parse_quote ! ( #name: zerogc:: TrustedDrop ) ) ;
788
+ }
672
789
#[ allow( clippy:: if_same_then_else) ] // Only necessary because of detailed comment
673
790
let protective_drop = if self . is_copy {
674
791
/*
@@ -703,7 +820,7 @@ impl TraceDeriveInput {
703
820
} ) )
704
821
} ;
705
822
let target_type = & self . ident ;
706
- let ( impl_generics, ty_generics, where_clause) = self . generics . original . split_for_impl ( ) ;
823
+ let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
707
824
quote ! {
708
825
#protective_drop
709
826
unsafe impl #impl_generics zerogc:: TrustedDrop for #target_type #ty_generics #where_clause { }
@@ -832,6 +949,9 @@ impl TraceDeriveInput {
832
949
} )
833
950
}
834
951
pub fn expand ( & self , kind : TraceDeriveKind ) -> Result < TokenStream , Error > {
952
+ if matches ! ( kind, TraceDeriveKind :: Deserialize ) {
953
+ return self . expand_deserialize ( ) ;
954
+ }
835
955
let gcsafe = self . expand_gcsafe ( kind) ?;
836
956
let trace_immutable = if self . wants_immutable_trace {
837
957
Some ( self . expand_trace ( kind, true ) ?)
0 commit comments