Skip to content

Commit 93e7e87

Browse files
committed
Implement const generics
1 parent 09e479d commit 93e7e87

File tree

14 files changed

+319
-6
lines changed

14 files changed

+319
-6
lines changed

crates/formality-check/src/where_clauses.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ impl super::Check<'_> {
2222
Ok(())
2323
}
2424

25-
#[context("prove_where_clauses_well_formed({where_clause:?})")]
25+
#[context("prove_where_clause_well_formed({where_clause:?})")]
2626
fn prove_where_clause_well_formed(
2727
&self,
2828
in_env: &Env,
@@ -45,6 +45,10 @@ impl super::Check<'_> {
4545
let wc = e.instantiate_universally(binder);
4646
self.prove_where_clause_well_formed(&e, assumptions, &wc)
4747
}
48+
WhereClauseData::TypeOfConst(ct, ty) => {
49+
self.prove_parameter_well_formed(in_env, &assumptions, ct.clone())?;
50+
self.prove_parameter_well_formed(in_env, assumptions, ty.clone())
51+
}
4852
}
4953
}
5054

crates/formality-prove/src/prove/prove_wf.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,11 @@ judgment_fn! {
3232
--- ("tuples")
3333
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::Tuple(_), parameters }) => c)
3434
)
35+
36+
(
37+
(for_all(&decls, &env, &assumptions, &parameters, &prove_wf) => c)
38+
--- ("integers and booleans")
39+
(prove_wf(decls, env, assumptions, RigidTy { name: RigidName::ScalarId(_), parameters }) => c)
40+
)
3541
}
3642
}

crates/formality-rust/src/grammar.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use formality_macros::term;
44
use formality_types::{
55
cast::Upcast,
66
grammar::{
7-
AdtId, AssociatedItemId, Binder, CrateId, Fallible, FieldId, FnId, Lt, Parameter, TraitId,
8-
TraitRef, Ty, Wc,
7+
AdtId, AssociatedItemId, Binder, Const, CrateId, Fallible, FieldId, FnId, Lt, Parameter,
8+
TraitId, TraitRef, Ty, Wc,
99
},
1010
term::Term,
1111
};
@@ -331,6 +331,7 @@ impl WhereClause {
331331
let wc = where_clause.invert()?;
332332
Some(Wc::for_all(&vars, wc))
333333
}
334+
WhereClauseData::TypeOfConst(_, _) => None,
334335
}
335336
}
336337
}
@@ -345,6 +346,9 @@ pub enum WhereClauseData {
345346

346347
#[grammar(for $v0)]
347348
ForAll(Binder<WhereClause>),
349+
350+
#[grammar(type_of_const $v0 is $v1)]
351+
TypeOfConst(Const, Ty),
348352
}
349353

350354
#[term($data)]

crates/formality-rust/src/prove.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,9 @@ impl ToWcs for WhereClause {
386386
.map(|wc| Wc::for_all(&vars, wc))
387387
.collect()
388388
}
389+
WhereClauseData::TypeOfConst(ct, ty) => {
390+
Predicate::ConstHasType(ct.clone(), ty.clone()).upcast()
391+
}
389392
}
390393
}
391394
}

crates/formality-types/src/fold.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33
use crate::{
44
cast::Upcast,
55
collections::Set,
6-
grammar::{Lt, LtData, Parameter, Ty, TyData, Variable},
6+
grammar::{Const, ConstData, Lt, LtData, Parameter, Ty, TyData, ValTree, Variable},
77
visit::Visit,
88
};
99

@@ -73,6 +73,25 @@ impl Fold for Ty {
7373
}
7474
}
7575

76+
impl Fold for Const {
77+
fn substitute(&self, substitution_fn: SubstitutionFn<'_>) -> Self {
78+
match self.data() {
79+
ConstData::Value(v) => Self::new(v.substitute(substitution_fn)),
80+
ConstData::Variable(v) => match substitution_fn(v.clone()) {
81+
None => self.clone(),
82+
Some(Parameter::Const(c)) => c,
83+
Some(param) => panic!("ill-kinded substitute: expected const, got {param:?}"),
84+
},
85+
}
86+
}
87+
}
88+
89+
impl Fold for ValTree {
90+
fn substitute(&self, _substitution_fn: SubstitutionFn<'_>) -> Self {
91+
self.clone()
92+
}
93+
}
94+
7695
impl Fold for Lt {
7796
fn substitute(&self, substitution_fn: SubstitutionFn<'_>) -> Self {
7897
match self.data() {

crates/formality-types/src/grammar.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
mod binder;
2+
mod consts;
23
mod formulas;
34
mod ids;
45
mod kinded;
56
mod ty;
67
mod wc;
78

89
pub use binder::*;
10+
pub use consts::*;
911
pub use formulas::*;
1012
pub use ids::*;
1113
pub use kinded::*;
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
mod valtree;
2+
3+
use crate::cast::{Upcast, UpcastFrom};
4+
5+
use super::{Parameter, Variable};
6+
use formality_macros::{term, Visit};
7+
use std::sync::Arc;
8+
pub use valtree::*;
9+
10+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
11+
pub struct Const {
12+
data: Arc<ConstData>,
13+
}
14+
impl Const {
15+
pub fn data(&self) -> &ConstData {
16+
&self.data
17+
}
18+
19+
pub fn new(data: impl Upcast<ConstData>) -> Self {
20+
Self {
21+
data: Arc::new(data.upcast()),
22+
}
23+
}
24+
pub fn as_variable(&self) -> Option<Variable> {
25+
match self.data() {
26+
ConstData::Value(_) => None,
27+
ConstData::Variable(var) => Some(var.clone()),
28+
}
29+
}
30+
}
31+
32+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
33+
pub enum ConstData {
34+
Value(ValTree),
35+
Variable(Variable),
36+
}
37+
38+
#[term]
39+
pub enum Bool {
40+
#[grammar(true)]
41+
True,
42+
#[grammar(false)]
43+
False,
44+
}
45+
46+
impl UpcastFrom<Bool> for Const {
47+
fn upcast_from(term: Bool) -> Self {
48+
Self::new(ValTree::upcast_from(term))
49+
}
50+
}
51+
52+
impl UpcastFrom<ValTree> for ConstData {
53+
fn upcast_from(v: ValTree) -> Self {
54+
Self::Value(v)
55+
}
56+
}
57+
58+
impl UpcastFrom<Variable> for ConstData {
59+
fn upcast_from(v: Variable) -> Self {
60+
Self::Variable(v)
61+
}
62+
}
63+
64+
impl UpcastFrom<Const> for Parameter {
65+
fn upcast_from(term: Const) -> Self {
66+
Self::Const(term)
67+
}
68+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use formality_macros::Visit;
2+
3+
use crate::cast::{Upcast, UpcastFrom};
4+
5+
use super::{Bool, ConstData};
6+
7+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
8+
pub enum ValTree {
9+
Leaf(Scalar),
10+
Branches(Vec<ValTree>),
11+
}
12+
13+
impl std::fmt::Debug for ValTree {
14+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15+
match self {
16+
Self::Leaf(s) => s.fmt(f),
17+
Self::Branches(branches) => branches.fmt(f),
18+
}
19+
}
20+
}
21+
22+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Visit)]
23+
pub struct Scalar {
24+
bits: u128,
25+
}
26+
27+
impl std::fmt::Debug for Scalar {
28+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29+
self.bits.fmt(f)
30+
}
31+
}
32+
33+
impl Scalar {
34+
pub fn new(bits: u128) -> Self {
35+
Self { bits }
36+
}
37+
}
38+
39+
impl UpcastFrom<Bool> for ValTree {
40+
fn upcast_from(term: Bool) -> Self {
41+
Scalar::upcast_from(term).upcast()
42+
}
43+
}
44+
45+
impl UpcastFrom<Scalar> for ValTree {
46+
fn upcast_from(s: Scalar) -> Self {
47+
Self::Leaf(s)
48+
}
49+
}
50+
51+
impl UpcastFrom<Scalar> for ConstData {
52+
fn upcast_from(s: Scalar) -> Self {
53+
ValTree::upcast_from(s).upcast()
54+
}
55+
}
56+
57+
impl UpcastFrom<Bool> for Scalar {
58+
fn upcast_from(term: Bool) -> Self {
59+
Scalar { bits: term as u128 }
60+
}
61+
}

crates/formality-types/src/grammar/formulas.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::cast::To;
44
use crate::cast::Upcast;
55
use crate::cast_impl;
66

7+
use super::Const;
78
use super::Parameter;
89
use super::Parameters;
910
use super::TraitId;
@@ -27,6 +28,9 @@ pub enum Predicate {
2728

2829
#[grammar(@IsLocal($v0))]
2930
IsLocal(TraitRef),
31+
32+
#[grammar(@ConstHasType($v0, $v1))]
33+
ConstHasType(Const, Ty),
3034
}
3135

3236
/// A coinductive predicate is one that can be proven via a cycle.
@@ -79,6 +83,7 @@ pub enum Skeleton {
7983
WellFormed,
8084
WellFormedTraitRef(TraitId),
8185
IsLocal(TraitId),
86+
ConstHasType,
8287

8388
Equals,
8489
Sub,
@@ -115,6 +120,10 @@ impl Predicate {
115120
trait_id,
116121
parameters,
117122
}) => (Skeleton::IsLocal(trait_id.clone()), parameters.clone()),
123+
Predicate::ConstHasType(ct, ty) => (
124+
Skeleton::ConstHasType,
125+
vec![ct.clone().upcast(), ty.clone().upcast()],
126+
),
118127
}
119128
}
120129
}

crates/formality-types/src/grammar/ty.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ use crate::{
1313
fold::Fold,
1414
};
1515

16-
use super::{AdtId, AssociatedItemId, Binder, FnId, TraitId};
16+
use super::{
17+
consts::{Const, ConstData},
18+
AdtId, AssociatedItemId, Binder, FnId, TraitId,
19+
};
1720

1821
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
1922
pub struct Ty {
@@ -263,13 +266,16 @@ pub enum Parameter {
263266
Ty(Ty),
264267
#[cast]
265268
Lt(Lt),
269+
#[grammar(const $v0)]
270+
Const(Const),
266271
}
267272

268273
impl Parameter {
269274
pub fn kind(&self) -> ParameterKind {
270275
match self {
271276
Parameter::Ty(_) => ParameterKind::Ty,
272277
Parameter::Lt(_) => ParameterKind::Lt,
278+
Parameter::Const(_) => ParameterKind::Const,
273279
}
274280
}
275281

@@ -281,13 +287,15 @@ impl Parameter {
281287
match self {
282288
Parameter::Ty(v) => v.as_variable(),
283289
Parameter::Lt(v) => v.as_variable(),
290+
Parameter::Const(v) => v.as_variable(),
284291
}
285292
}
286293

287294
pub fn data(&self) -> ParameterData<'_> {
288295
match self {
289296
Parameter::Ty(v) => ParameterData::Ty(v.data()),
290297
Parameter::Lt(v) => ParameterData::Lt(v.data()),
298+
Parameter::Const(v) => ParameterData::Const(v.data()),
291299
}
292300
}
293301
}
@@ -297,13 +305,15 @@ pub type Parameters = Vec<Parameter>;
297305
pub enum ParameterData<'me> {
298306
Ty(&'me TyData),
299307
Lt(&'me LtData),
308+
Const(&'me ConstData),
300309
}
301310

302311
#[term]
303312
#[derive(Copy)]
304313
pub enum ParameterKind {
305314
Ty,
306315
Lt,
316+
Const,
307317
}
308318

309319
#[term]
@@ -517,6 +527,7 @@ impl UpcastFrom<Variable> for Parameter {
517527
match v.kind() {
518528
ParameterKind::Lt => Lt::new(v).upcast(),
519529
ParameterKind::Ty => Ty::new(v).upcast(),
530+
ParameterKind::Const => Const::new(v).upcast(),
520531
}
521532
}
522533
}

0 commit comments

Comments
 (0)