Skip to content

Commit 7525a38

Browse files
committed
Support evaluating dyn Trait methods
1 parent a063f00 commit 7525a38

File tree

4 files changed

+197
-60
lines changed

4 files changed

+197
-60
lines changed

crates/hir-ty/src/consteval/tests.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,57 @@ fn function_traits() {
10081008
);
10091009
}
10101010

1011+
#[test]
1012+
fn dyn_trait() {
1013+
check_number(
1014+
r#"
1015+
//- minicore: coerce_unsized, index, slice
1016+
trait Foo {
1017+
fn foo(&self) -> u8 { 10 }
1018+
}
1019+
struct S1;
1020+
struct S2;
1021+
struct S3;
1022+
impl Foo for S1 {
1023+
fn foo(&self) -> u8 { 1 }
1024+
}
1025+
impl Foo for S2 {
1026+
fn foo(&self) -> u8 { 2 }
1027+
}
1028+
impl Foo for S3 {}
1029+
const GOAL: u8 = {
1030+
let x: &[&dyn Foo] = &[&S1, &S2, &S3];
1031+
x[0].foo() + x[1].foo() + x[2].foo()
1032+
};
1033+
"#,
1034+
13,
1035+
);
1036+
check_number(
1037+
r#"
1038+
//- minicore: coerce_unsized, index, slice
1039+
trait Foo {
1040+
fn foo(&self) -> i32 { 10 }
1041+
}
1042+
trait Bar {
1043+
fn bar(&self) -> i32 { 20 }
1044+
}
1045+
1046+
struct S;
1047+
impl Foo for S {
1048+
fn foo(&self) -> i32 { 200 }
1049+
}
1050+
impl Bar for dyn Foo {
1051+
fn bar(&self) -> i32 { 700 }
1052+
}
1053+
const GOAL: i32 = {
1054+
let x: &dyn Foo = &S;
1055+
x.bar() + x.foo()
1056+
};
1057+
"#,
1058+
900,
1059+
);
1060+
}
1061+
10111062
#[test]
10121063
fn array_and_index() {
10131064
check_number(

crates/hir-ty/src/method_resolution.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
use std::{ops::ControlFlow, sync::Arc};
66

77
use base_db::{CrateId, Edition};
8-
use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex};
8+
use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex, WhereClause};
99
use hir_def::{
1010
data::ImplData, item_scope::ItemScope, lang_item::LangItem, nameres::DefMap, AssocItemId,
1111
BlockId, ConstId, FunctionId, HasModule, ImplId, ItemContainerId, Lookup, ModuleDefId,
@@ -692,6 +692,38 @@ pub fn lookup_impl_const(
692692
.unwrap_or((const_id, subs))
693693
}
694694

695+
/// Checks if the self parameter of `Trait` method is the `dyn Trait` and we should
696+
/// call the method using the vtable.
697+
pub fn is_dyn_method(
698+
db: &dyn HirDatabase,
699+
_env: Arc<TraitEnvironment>,
700+
func: FunctionId,
701+
fn_subst: Substitution,
702+
) -> Option<usize> {
703+
let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else {
704+
return None;
705+
};
706+
let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
707+
let fn_params = fn_subst.len(Interner) - trait_params;
708+
let trait_ref = TraitRef {
709+
trait_id: to_chalk_trait_id(trait_id),
710+
substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)),
711+
};
712+
let self_ty = trait_ref.self_type_parameter(Interner);
713+
if let TyKind::Dyn(d) = self_ty.kind(Interner) {
714+
let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() {
715+
// rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter
716+
// what the generics are, we are sure that the method is come from the vtable.
717+
WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id,
718+
_ => false,
719+
});
720+
if is_my_trait_in_bounds {
721+
return Some(fn_params);
722+
}
723+
}
724+
None
725+
}
726+
695727
/// Looks up the impl method that actually runs for the trait method `func`.
696728
///
697729
/// Returns `func` if it's not a method defined in a trait or the lookup failed.
@@ -701,9 +733,8 @@ pub fn lookup_impl_method(
701733
func: FunctionId,
702734
fn_subst: Substitution,
703735
) -> (FunctionId, Substitution) {
704-
let trait_id = match func.lookup(db.upcast()).container {
705-
ItemContainerId::TraitId(id) => id,
706-
_ => return (func, fn_subst),
736+
let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else {
737+
return (func, fn_subst)
707738
};
708739
let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
709740
let fn_params = fn_subst.len(Interner) - trait_params;

crates/hir-ty/src/mir/eval.rs

Lines changed: 103 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,26 @@ use crate::{
2323
infer::{normalize, PointerCast},
2424
layout::layout_of_ty,
2525
mapping::from_chalk,
26-
method_resolution::lookup_impl_method,
26+
method_resolution::{is_dyn_method, lookup_impl_method},
2727
traits::FnTrait,
2828
CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution,
29-
TraitEnvironment, Ty, TyBuilder, TyExt,
29+
TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData,
3030
};
3131

3232
use super::{
3333
const_as_usize, return_slot, AggregateKind, BinOp, CastKind, LocalId, MirBody, MirLowerError,
3434
Operand, Place, ProjectionElem, Rvalue, StatementKind, Terminator, UnOp,
3535
};
3636

37+
macro_rules! from_bytes {
38+
($ty:tt, $value:expr) => {
39+
($ty::from_le_bytes(match ($value).try_into() {
40+
Ok(x) => x,
41+
Err(_) => return Err(MirEvalError::TypeError("mismatched size")),
42+
}))
43+
};
44+
}
45+
3746
#[derive(Debug, Default)]
3847
struct VTableMap {
3948
ty_to_id: HashMap<Ty, usize>,
@@ -54,6 +63,11 @@ impl VTableMap {
5463
fn ty(&self, id: usize) -> Result<&Ty> {
5564
self.id_to_ty.get(id).ok_or(MirEvalError::InvalidVTableId(id))
5665
}
66+
67+
fn ty_of_bytes(&self, bytes: &[u8]) -> Result<&Ty> {
68+
let id = from_bytes!(usize, bytes);
69+
self.ty(id)
70+
}
5771
}
5872

5973
pub struct Evaluator<'a> {
@@ -110,15 +124,6 @@ impl IntervalOrOwned {
110124
}
111125
}
112126

113-
macro_rules! from_bytes {
114-
($ty:tt, $value:expr) => {
115-
($ty::from_le_bytes(match ($value).try_into() {
116-
Ok(x) => x,
117-
Err(_) => return Err(MirEvalError::TypeError("mismatched size")),
118-
}))
119-
};
120-
}
121-
122127
impl Address {
123128
fn from_bytes(x: &[u8]) -> Result<Self> {
124129
Ok(Address::from_usize(from_bytes!(usize, x)))
@@ -781,7 +786,18 @@ impl Evaluator<'_> {
781786
}
782787
_ => not_supported!("slice unsizing from non pointers"),
783788
},
784-
TyKind::Dyn(_) => not_supported!("dyn pointer unsize cast"),
789+
TyKind::Dyn(_) => match &current_ty.data(Interner).kind {
790+
TyKind::Raw(_, ty) | TyKind::Ref(_, _, ty) => {
791+
let vtable = self.vtable_map.id(ty.clone());
792+
let addr =
793+
self.eval_operand(operand, locals)?.get(&self)?;
794+
let mut r = Vec::with_capacity(16);
795+
r.extend(addr.iter().copied());
796+
r.extend(vtable.to_le_bytes().into_iter());
797+
Owned(r)
798+
}
799+
_ => not_supported!("dyn unsizing from non pointers"),
800+
},
785801
_ => not_supported!("unknown unsized cast"),
786802
}
787803
}
@@ -1227,44 +1243,8 @@ impl Evaluator<'_> {
12271243
let arg_bytes = args
12281244
.iter()
12291245
.map(|x| Ok(self.eval_operand(x, &locals)?.get(&self)?.to_owned()))
1230-
.collect::<Result<Vec<_>>>()?
1231-
.into_iter();
1232-
let function_data = self.db.function_data(def);
1233-
let is_intrinsic = match &function_data.abi {
1234-
Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
1235-
None => match def.lookup(self.db.upcast()).container {
1236-
hir_def::ItemContainerId::ExternBlockId(block) => {
1237-
let id = block.lookup(self.db.upcast()).id;
1238-
id.item_tree(self.db.upcast())[id.value].abi.as_deref()
1239-
== Some("rust-intrinsic")
1240-
}
1241-
_ => false,
1242-
},
1243-
};
1244-
let result = if is_intrinsic {
1245-
self.exec_intrinsic(
1246-
function_data.name.as_text().unwrap_or_default().as_str(),
1247-
arg_bytes,
1248-
generic_args,
1249-
&locals,
1250-
)?
1251-
} else if let Some(x) = self.detect_lang_function(def) {
1252-
self.exec_lang_item(x, arg_bytes)?
1253-
} else {
1254-
let (imp, generic_args) = lookup_impl_method(
1255-
self.db,
1256-
self.trait_env.clone(),
1257-
def,
1258-
generic_args.clone(),
1259-
);
1260-
let generic_args = self.subst_filler(&generic_args, &locals);
1261-
let def = imp.into();
1262-
let mir_body =
1263-
self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?;
1264-
self.interpret_mir(&mir_body, arg_bytes, generic_args)
1265-
.map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))?
1266-
};
1267-
self.write_memory(dest_addr, &result)?;
1246+
.collect::<Result<Vec<_>>>()?;
1247+
self.exec_fn_with_args(def, arg_bytes, generic_args, locals, dest_addr)?;
12681248
}
12691249
CallableDefId::StructId(id) => {
12701250
let (size, variant_layout, tag) =
@@ -1284,6 +1264,77 @@ impl Evaluator<'_> {
12841264
Ok(())
12851265
}
12861266

1267+
fn exec_fn_with_args(
1268+
&mut self,
1269+
def: FunctionId,
1270+
arg_bytes: Vec<Vec<u8>>,
1271+
generic_args: Substitution,
1272+
locals: &Locals<'_>,
1273+
dest_addr: Address,
1274+
) -> Result<()> {
1275+
let function_data = self.db.function_data(def);
1276+
let is_intrinsic = match &function_data.abi {
1277+
Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
1278+
None => match def.lookup(self.db.upcast()).container {
1279+
hir_def::ItemContainerId::ExternBlockId(block) => {
1280+
let id = block.lookup(self.db.upcast()).id;
1281+
id.item_tree(self.db.upcast())[id.value].abi.as_deref()
1282+
== Some("rust-intrinsic")
1283+
}
1284+
_ => false,
1285+
},
1286+
};
1287+
let result = if is_intrinsic {
1288+
self.exec_intrinsic(
1289+
function_data.name.as_text().unwrap_or_default().as_str(),
1290+
arg_bytes.iter().cloned(),
1291+
generic_args,
1292+
&locals,
1293+
)?
1294+
} else if let Some(x) = self.detect_lang_function(def) {
1295+
self.exec_lang_item(x, &arg_bytes)?
1296+
} else {
1297+
if let Some(self_ty_idx) =
1298+
is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
1299+
{
1300+
// In the layout of current possible receiver, which at the moment of writing this code is one of
1301+
// `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible recievers,
1302+
// the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
1303+
// the type.
1304+
let ty = self
1305+
.vtable_map
1306+
.ty_of_bytes(&arg_bytes[0][self.ptr_size()..self.ptr_size() * 2])?;
1307+
let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
1308+
let mut args_for_target = arg_bytes;
1309+
args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec();
1310+
let generics_for_target = Substitution::from_iter(
1311+
Interner,
1312+
generic_args
1313+
.iter(Interner)
1314+
.enumerate()
1315+
.map(|(i, x)| if i == self_ty_idx { &ty } else { x })
1316+
);
1317+
return self.exec_fn_with_args(
1318+
def,
1319+
args_for_target,
1320+
generics_for_target,
1321+
locals,
1322+
dest_addr,
1323+
);
1324+
}
1325+
let (imp, generic_args) =
1326+
lookup_impl_method(self.db, self.trait_env.clone(), def, generic_args.clone());
1327+
let generic_args = self.subst_filler(&generic_args, &locals);
1328+
let def = imp.into();
1329+
let mir_body =
1330+
self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?;
1331+
self.interpret_mir(&mir_body, arg_bytes.iter().cloned(), generic_args)
1332+
.map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))?
1333+
};
1334+
self.write_memory(dest_addr, &result)?;
1335+
Ok(())
1336+
}
1337+
12871338
fn exec_fn_trait(
12881339
&mut self,
12891340
ft: FnTrait,
@@ -1317,12 +1368,9 @@ impl Evaluator<'_> {
13171368
Ok(())
13181369
}
13191370

1320-
fn exec_lang_item(
1321-
&self,
1322-
x: LangItem,
1323-
mut args: std::vec::IntoIter<Vec<u8>>,
1324-
) -> Result<Vec<u8>> {
1371+
fn exec_lang_item(&self, x: LangItem, args: &[Vec<u8>]) -> Result<Vec<u8>> {
13251372
use LangItem::*;
1373+
let mut args = args.iter();
13261374
match x {
13271375
PanicFmt | BeginPanic => Err(MirEvalError::Panic),
13281376
SliceLen => {

crates/hir-ty/src/mir/lower.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,14 @@ impl MirLowerCtx<'_> {
230230
self.lower_const(c, current, place, expr_id.into())?;
231231
return Ok(Some(current))
232232
},
233-
_ => not_supported!("associated functions and types"),
233+
hir_def::AssocItemId::FunctionId(_) => {
234+
// FnDefs are zero sized, no action is needed.
235+
return Ok(Some(current))
236+
}
237+
hir_def::AssocItemId::TypeAliasId(_) => {
238+
// FIXME: If it is unreachable, use proper error instead of `not_supported`.
239+
not_supported!("associated functions and types")
240+
},
234241
}
235242
} else if let Some(variant) = self
236243
.infer

0 commit comments

Comments
 (0)