Skip to content

Commit 87c5746

Browse files
committed
Attach types to value constants
1 parent 93e7e87 commit 87c5746

File tree

7 files changed

+45
-27
lines changed

7 files changed

+45
-27
lines changed

crates/formality-types/src/fold.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ impl Fold for Ty {
7676
impl Fold for Const {
7777
fn substitute(&self, substitution_fn: SubstitutionFn<'_>) -> Self {
7878
match self.data() {
79-
ConstData::Value(v) => Self::new(v.substitute(substitution_fn)),
79+
ConstData::Value(v, ty) => Self::valtree(
80+
v.substitute(substitution_fn),
81+
ty.substitute(substitution_fn),
82+
),
8083
ConstData::Variable(v) => match substitution_fn(v.clone()) {
8184
None => self.clone(),
8285
Some(Parameter::Const(c)) => c,

crates/formality-types/src/grammar/consts.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod valtree;
22

33
use crate::cast::{Upcast, UpcastFrom};
44

5-
use super::{Parameter, Variable};
5+
use super::{Parameter, Ty, Variable};
66
use formality_macros::{term, Visit};
77
use std::sync::Arc;
88
pub use valtree::*;
@@ -21,20 +21,31 @@ impl Const {
2121
data: Arc::new(data.upcast()),
2222
}
2323
}
24+
25+
pub fn valtree(vt: impl Upcast<ValTree>, ty: impl Upcast<Ty>) -> Self {
26+
Self::new(ConstData::Value(vt.upcast(), ty.upcast()))
27+
}
28+
2429
pub fn as_variable(&self) -> Option<Variable> {
2530
match self.data() {
26-
ConstData::Value(_) => None,
31+
ConstData::Value(_, _) => None,
2732
ConstData::Variable(var) => Some(var.clone()),
2833
}
2934
}
3035
}
3136

3237
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
3338
pub enum ConstData {
34-
Value(ValTree),
39+
Value(ValTree, Ty),
3540
Variable(Variable),
3641
}
3742

43+
impl UpcastFrom<Self> for ConstData {
44+
fn upcast_from(term: Self) -> Self {
45+
term
46+
}
47+
}
48+
3849
#[term]
3950
pub enum Bool {
4051
#[grammar(true)]
@@ -45,13 +56,7 @@ pub enum Bool {
4556

4657
impl UpcastFrom<Bool> for Const {
4758
fn upcast_from(term: Bool) -> Self {
48-
Self::new(ValTree::upcast_from(term))
49-
}
50-
}
51-
52-
impl UpcastFrom<ValTree> for ConstData {
53-
fn upcast_from(v: ValTree) -> Self {
54-
Self::Value(v)
59+
Self::new(ConstData::Value(term.upcast(), Ty::bool()))
5560
}
5661
}
5762

crates/formality-types/src/grammar/consts/valtree.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use formality_macros::Visit;
22

33
use crate::cast::{Upcast, UpcastFrom};
44

5-
use super::{Bool, ConstData};
5+
use super::Bool;
66

77
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
88
pub enum ValTree {
@@ -48,9 +48,9 @@ impl UpcastFrom<Scalar> for ValTree {
4848
}
4949
}
5050

51-
impl UpcastFrom<Scalar> for ConstData {
52-
fn upcast_from(s: Scalar) -> Self {
53-
ValTree::upcast_from(s).upcast()
51+
impl UpcastFrom<Self> for ValTree {
52+
fn upcast_from(s: Self) -> Self {
53+
s
5454
}
5555
}
5656

crates/formality-types/src/grammar/ty.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ impl Ty {
8080
vec![l.to::<Parameter>(), self.to::<Parameter>()],
8181
)
8282
}
83+
84+
pub fn bool() -> Ty {
85+
RigidTy {
86+
name: RigidName::ScalarId(ScalarId::Bool),
87+
parameters: vec![],
88+
}
89+
.upcast()
90+
}
8391
}
8492

8593
impl UpcastFrom<TyData> for Ty {

crates/formality-types/src/grammar/ty/debug_impls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ impl std::fmt::Debug for super::Ty {
1414
impl std::fmt::Debug for Const {
1515
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1616
match self.data() {
17-
crate::grammar::ConstData::Value(valtree) => write!(f, "{valtree:?}"),
17+
crate::grammar::ConstData::Value(valtree, ty) => write!(f, "{valtree:?}_{ty:?}"),
1818
crate::grammar::ConstData::Variable(r) => write!(f, "{r:?}"),
1919
}
2020
}

crates/formality-types/src/grammar/ty/parse_impls.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,15 @@ impl Parse for Const {
215215

216216
#[tracing::instrument(level = "trace", ret)]
217217
fn parse_int<'t>(scope: &crate::parse::Scope, text: &'t str) -> ParseResult<'t, Const> {
218-
let (pos, _) = text
219-
.match_indices(|c: char| !c.is_numeric())
220-
.next()
221-
.unwrap_or((text.len(), text));
222-
let (num, text) = text.split_at(pos);
218+
let (num, text) = text.split_once('_').ok_or_else(|| {
219+
ParseError::at(
220+
text,
221+
format!("numeric constants must be followed by an `_` and their type"),
222+
)
223+
})?;
223224
let n: u128 = num
224225
.parse()
225226
.map_err(|err| ParseError::at(num, format!("could not parse number: {err}")))?;
226-
Ok((Const::new(Scalar::new(n)), text))
227+
let (ty, text) = Ty::parse(scope, text)?;
228+
Ok((Const::valtree(Scalar::new(n), ty), text))
227229
}

tests/consts.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ fn test_holds() {
2525
expect_test::expect![[r#"
2626
Err(
2727
Error {
28-
context: "check_trait_impl(impl <> Foo < const 0 > for (rigid (scalar u32)) where [] { })",
29-
source: "failed to prove {Foo((rigid (scalar u32)), const 0)} given {}, got {}",
28+
context: "check_trait_impl(impl <> Foo < const 0_(rigid (scalar bool)) > for (rigid (scalar u32)) where [] { })",
29+
source: "failed to prove {Foo((rigid (scalar u32)), const 0_(rigid (scalar bool)))} given {}, got {}",
3030
},
3131
)
3232
"#]]
@@ -46,8 +46,8 @@ fn test_mismatch() {
4646
expect_test::expect![[r#"
4747
Err(
4848
Error {
49-
context: "check_trait_impl(impl <> Foo < const 42 > for (rigid (scalar u32)) where [] { })",
50-
source: "failed to prove {Foo((rigid (scalar u32)), const 42)} given {}, got {}",
49+
context: "check_trait_impl(impl <> Foo < const 42_(rigid (scalar u32)) > for (rigid (scalar u32)) where [] { })",
50+
source: "failed to prove {Foo((rigid (scalar u32)), const 42_(rigid (scalar u32)))} given {}, got {}",
5151
},
5252
)
5353
"#]]
@@ -56,7 +56,7 @@ fn test_mismatch() {
5656
crate Foo {
5757
trait Foo<const C> where [type_of_const C is bool] {}
5858
59-
impl<> Foo<const 42> for u32 where [] {}
59+
impl<> Foo<const 42_u32> for u32 where [] {}
6060
}
6161
]",
6262
));

0 commit comments

Comments
 (0)