Skip to content

Commit 7234f73

Browse files
Construct body for by-move coroutine closure output
1 parent 875b806 commit 7234f73

File tree

23 files changed

+229
-15
lines changed

23 files changed

+229
-15
lines changed

compiler/rustc_const_eval/src/interpret/terminator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
542542
| ty::InstanceDef::ReifyShim(..)
543543
| ty::InstanceDef::ClosureOnceShim { .. }
544544
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
545+
| ty::InstanceDef::CoroutineByMoveShim { .. }
545546
| ty::InstanceDef::FnPtrShim(..)
546547
| ty::InstanceDef::DropGlue(..)
547548
| ty::InstanceDef::CloneShim(..)

compiler/rustc_hir_typeck/src/callee.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
182182
coroutine_closure_sig.to_coroutine(
183183
self.tcx,
184184
closure_args.parent_args(),
185+
closure_args.kind_ty(),
185186
self.tcx.coroutine_for_closure(def_id),
186187
tupled_upvars_ty,
187188
),

compiler/rustc_hir_typeck/src/closure.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
175175
interior,
176176
));
177177

178+
let kind_ty = match kind {
179+
hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self
180+
.next_ty_var(TypeVariableOrigin {
181+
kind: TypeVariableOriginKind::ClosureSynthetic,
182+
span: expr_span,
183+
}),
184+
_ => tcx.types.unit,
185+
};
186+
178187
let coroutine_args = ty::CoroutineArgs::new(
179188
tcx,
180189
ty::CoroutineArgsParts {
181190
parent_args,
191+
kind_ty,
182192
resume_ty,
183193
yield_ty,
184194
return_ty: liberated_sig.output(),
@@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
256266
sig.to_coroutine(
257267
tcx,
258268
parent_args,
269+
closure_kind_ty,
259270
tcx.coroutine_for_closure(expr_def_id),
260271
coroutine_upvars_ty,
261272
)

compiler/rustc_hir_typeck/src/upvar.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
393393
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
394394
coroutine_captures_by_ref_ty,
395395
);
396+
397+
let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind()
398+
else {
399+
bug!();
400+
};
401+
self.demand_eqtype(
402+
span,
403+
args.as_coroutine().kind_ty(),
404+
Ty::from_closure_kind(self.tcx, closure_kind),
405+
);
396406
}
397407

398408
self.log_closure_min_capture_info(closure_def_id, span);

compiler/rustc_middle/src/mir/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ pub struct CoroutineInfo<'tcx> {
260260
/// Coroutine drop glue. This field is populated after the state transform pass.
261261
pub coroutine_drop: Option<Body<'tcx>>,
262262

263+
/// The body of the coroutine, modified to take its upvars by move.
264+
/// TODO:
265+
pub by_move_body: Option<Body<'tcx>>,
266+
263267
/// The layout of a coroutine. This field is populated after the state transform pass.
264268
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
265269

@@ -279,6 +283,7 @@ impl<'tcx> CoroutineInfo<'tcx> {
279283
coroutine_kind,
280284
yield_ty: Some(yield_ty),
281285
resume_ty: Some(resume_ty),
286+
by_move_body: None,
282287
coroutine_drop: None,
283288
coroutine_layout: None,
284289
}

compiler/rustc_middle/src/mir/mono.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> {
403403
| InstanceDef::Virtual(..)
404404
| InstanceDef::ClosureOnceShim { .. }
405405
| InstanceDef::ConstructCoroutineInClosureShim { .. }
406+
| InstanceDef::CoroutineByMoveShim { .. }
406407
| InstanceDef::DropGlue(..)
407408
| InstanceDef::CloneShim(..)
408409
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/mir/visit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ macro_rules! make_mir_visitor {
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348348
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
349+
ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } |
349350
ty::InstanceDef::DropGlue(_def_id, None) => {}
350351

351352
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> {
101101
target_kind: ty::ClosureKind,
102102
},
103103

104+
/// TODO:
105+
CoroutineByMoveShim { coroutine_def_id: DefId },
106+
104107
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
105108
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
106109
/// native support.
@@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> {
186189
coroutine_closure_def_id: def_id,
187190
target_kind: _,
188191
}
192+
| ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id }
189193
| InstanceDef::DropGlue(def_id, _)
190194
| InstanceDef::CloneShim(def_id, _)
191195
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> {
206210
| InstanceDef::Intrinsic(..)
207211
| InstanceDef::ClosureOnceShim { .. }
208212
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
213+
| ty::InstanceDef::CoroutineByMoveShim { .. }
209214
| InstanceDef::DropGlue(..)
210215
| InstanceDef::CloneShim(..)
211216
| InstanceDef::FnPtrAddrShim(..) => None,
@@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> {
302307
| InstanceDef::DropGlue(_, Some(_)) => false,
303308
InstanceDef::ClosureOnceShim { .. }
304309
| InstanceDef::ConstructCoroutineInClosureShim { .. }
310+
| InstanceDef::CoroutineByMoveShim { .. }
305311
| InstanceDef::DropGlue(..)
306312
| InstanceDef::Item(_)
307313
| InstanceDef::Intrinsic(..)
@@ -340,6 +346,7 @@ fn fmt_instance(
340346
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
341347
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
342348
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
349+
InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"),
343350
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
344351
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
345352
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
@@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> {
631638
};
632639

633640
if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
634-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
641+
let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
642+
else {
643+
bug!()
644+
};
645+
646+
if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
647+
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
648+
} else {
649+
Some(Instance {
650+
def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
651+
args,
652+
})
653+
}
635654
} else {
636655
// All other methods should be defaulted methods of the built-in trait.
637656
// This is important for `Iterator`'s combinators, but also useful for

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,6 +2355,7 @@ impl<'tcx> TyCtxt<'tcx> {
23552355
| ty::InstanceDef::Virtual(..)
23562356
| ty::InstanceDef::ClosureOnceShim { .. }
23572357
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
2358+
| ty::InstanceDef::CoroutineByMoveShim { .. }
23582359
| ty::InstanceDef::DropGlue(..)
23592360
| ty::InstanceDef::CloneShim(..)
23602361
| ty::InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/ty/sty.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -475,13 +475,15 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
475475
self,
476476
tcx: TyCtxt<'tcx>,
477477
parent_args: &'tcx [GenericArg<'tcx>],
478+
kind_ty: Ty<'tcx>,
478479
coroutine_def_id: DefId,
479480
tupled_upvars_ty: Ty<'tcx>,
480481
) -> Ty<'tcx> {
481482
let coroutine_args = ty::CoroutineArgs::new(
482483
tcx,
483484
ty::CoroutineArgsParts {
484485
parent_args,
486+
kind_ty,
485487
resume_ty: self.resume_ty,
486488
yield_ty: self.yield_ty,
487489
return_ty: self.return_ty,
@@ -512,7 +514,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
512514
env_region,
513515
);
514516

515-
self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty)
517+
self.to_coroutine(
518+
tcx,
519+
parent_args,
520+
Ty::from_closure_kind(tcx, closure_kind),
521+
coroutine_def_id,
522+
tupled_upvars_ty,
523+
)
516524
}
517525

518526
/// Given a closure kind, compute the tupled upvars that the given coroutine would return.
@@ -564,6 +572,8 @@ pub struct CoroutineArgs<'tcx> {
564572
pub struct CoroutineArgsParts<'tcx> {
565573
/// This is the args of the typeck root.
566574
pub parent_args: &'tcx [GenericArg<'tcx>],
575+
// TODO: why
576+
pub kind_ty: Ty<'tcx>,
567577
pub resume_ty: Ty<'tcx>,
568578
pub yield_ty: Ty<'tcx>,
569579
pub return_ty: Ty<'tcx>,
@@ -582,6 +592,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
582592
pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> {
583593
CoroutineArgs {
584594
args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([
595+
parts.kind_ty.into(),
585596
parts.resume_ty.into(),
586597
parts.yield_ty.into(),
587598
parts.return_ty.into(),
@@ -595,16 +606,23 @@ impl<'tcx> CoroutineArgs<'tcx> {
595606
/// The ordering assumed here must match that used by `CoroutineArgs::new` above.
596607
fn split(self) -> CoroutineArgsParts<'tcx> {
597608
match self.args[..] {
598-
[ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => {
599-
CoroutineArgsParts {
600-
parent_args,
601-
resume_ty: resume_ty.expect_ty(),
602-
yield_ty: yield_ty.expect_ty(),
603-
return_ty: return_ty.expect_ty(),
604-
witness: witness.expect_ty(),
605-
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
606-
}
607-
}
609+
[
610+
ref parent_args @ ..,
611+
kind_ty,
612+
resume_ty,
613+
yield_ty,
614+
return_ty,
615+
witness,
616+
tupled_upvars_ty,
617+
] => CoroutineArgsParts {
618+
parent_args,
619+
kind_ty: kind_ty.expect_ty(),
620+
resume_ty: resume_ty.expect_ty(),
621+
yield_ty: yield_ty.expect_ty(),
622+
return_ty: return_ty.expect_ty(),
623+
witness: witness.expect_ty(),
624+
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
625+
},
608626
_ => bug!("coroutine args missing synthetics"),
609627
}
610628
}
@@ -614,6 +632,11 @@ impl<'tcx> CoroutineArgs<'tcx> {
614632
self.split().parent_args
615633
}
616634

635+
// TODO:
636+
pub fn kind_ty(self) -> Ty<'tcx> {
637+
self.split().kind_ty
638+
}
639+
617640
/// This describes the types that can be contained in a coroutine.
618641
/// It will be a type variable initially and unified in the last stages of typeck of a body.
619642
/// It contains a tuple of all the types that could end up on a coroutine frame.
@@ -2381,7 +2404,7 @@ impl<'tcx> Ty<'tcx> {
23812404
) -> Ty<'tcx> {
23822405
debug_assert_eq!(
23832406
coroutine_args.len(),
2384-
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5,
2407+
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6,
23852408
"coroutine constructed with incorrect number of substitutions"
23862409
);
23872410
Ty::new(tcx, Coroutine(def_id, coroutine_args))

0 commit comments

Comments
 (0)