Skip to content

Commit d711a4a

Browse files
committed
auto-generate constructors for structs, variants
1 parent f064001 commit d711a4a

File tree

10 files changed

+100
-24
lines changed

10 files changed

+100
-24
lines changed

crates/formality-check/src/where_clauses.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl super::Check<'_> {
5151
WhereClauseData::TypeOfConst(ct, ty) => {
5252
match ct.data() {
5353
ConstData::Value(_, t) => {
54-
self.prove_goal(in_env, &assumptions, Relation::eq(ty, t))?
54+
self.prove_goal(in_env, &assumptions, Relation::equals(ty, t))?
5555
}
5656
ConstData::Variable(_) => {}
5757
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use convert_case::{Case, Casing};
2+
use proc_macro2::{Ident, TokenStream};
3+
use quote::quote_spanned;
4+
use synstructure::VariantInfo;
5+
6+
const RUST_KEYWORDS: &[&str] = &[
7+
"mut", "true", "false", "const", "static", "ref", "struct", "enum", "trait", "union", "fn",
8+
"use", "return", "move", "let", "break", "loop", "continue", "await", "if", "for",
9+
];
10+
11+
/// Create methods to build this type.
12+
///
13+
/// For a struct, we create a single `new` method that takes each field.
14+
///
15+
/// For an enum, we create methods named after each variant.
16+
pub(crate) fn constructor_methods(s: synstructure::Structure) -> TokenStream {
17+
match s.ast().data {
18+
syn::Data::Struct(_) => derive_new_for_struct(s),
19+
syn::Data::Enum(_) => derive_new_for_variants(s),
20+
syn::Data::Union(_) => return Default::default(),
21+
}
22+
}
23+
24+
fn derive_new_for_struct(s: synstructure::Structure<'_>) -> TokenStream {
25+
derive_new_for_variant(
26+
&s,
27+
&s.variants()[0],
28+
&Ident::new("new", s.ast().ident.span()),
29+
)
30+
}
31+
32+
fn derive_new_for_variants(s: synstructure::Structure<'_>) -> TokenStream {
33+
s.variants()
34+
.iter()
35+
.map(|v| {
36+
let mut fn_name = v.ast().ident.to_string().to_case(Case::Snake);
37+
if RUST_KEYWORDS.iter().any(|&kw| kw == fn_name) {
38+
fn_name.push('_');
39+
}
40+
let fn_name = Ident::new(&fn_name, v.ast().ident.span());
41+
derive_new_for_variant(&s, v, &fn_name)
42+
})
43+
.collect()
44+
}
45+
46+
fn derive_new_for_variant(
47+
s: &synstructure::Structure<'_>,
48+
v: &VariantInfo<'_>,
49+
fn_name: &Ident,
50+
) -> TokenStream {
51+
let type_name = &s.ast().ident;
52+
let (impl_generics, type_generics, where_clauses) = s.ast().generics.split_for_impl();
53+
54+
let binding_names = v.bindings().iter().map(|b| &b.binding).collect::<Vec<_>>();
55+
let binding_types = v.bindings().iter().map(|b| &b.ast().ty).collect::<Vec<_>>();
56+
let construct = v.construct(|_b, i| {
57+
let name = binding_names[i];
58+
quote_spanned!(
59+
binding_names[i].span() =>
60+
formality_core::Upcast::upcast(#name)
61+
)
62+
});
63+
64+
quote_spanned! { v.ast().ident.span() =>
65+
#[allow(dead_code)]
66+
impl #impl_generics #type_name #type_generics
67+
where #where_clauses
68+
{
69+
pub fn #fn_name(
70+
#(
71+
#binding_names: impl formality_core::Upcast<#binding_types>,
72+
)*
73+
) -> Self {
74+
#construct
75+
}
76+
}
77+
}
78+
}

crates/formality-macros/src/custom.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use proc_macro2::TokenStream;
44
pub(crate) struct Customize {
55
pub parse: bool,
66
pub debug: bool,
7+
pub constructors: bool,
78
}
89

910
impl syn::parse::Parse for Customize {
@@ -28,6 +29,13 @@ impl syn::parse::Parse for Customize {
2829
result.debug = true;
2930
}
3031

32+
proc_macro2::TokenTree::Ident(ident) if ident == "constructors" => {
33+
if result.constructors {
34+
return Err(syn::Error::new(ident.span(), "already customizing debug"));
35+
}
36+
result.constructors = true;
37+
}
38+
3139
_ => {
3240
return Err(syn::Error::new(
3341
token.span(),

crates/formality-macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ extern crate proc_macro;
88
mod as_methods;
99
mod attrs;
1010
mod cast;
11+
mod constructors;
1112
mod custom;
1213
mod debug;
1314
mod fixed_point;

crates/formality-macros/src/term.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use syn::DeriveInput;
55
use crate::{
66
attrs::{self, remove_formality_attributes},
77
cast::{downcast_impls, upcast_impls},
8+
constructors::constructor_methods,
89
debug::derive_debug_with_spec,
910
fold::derive_fold,
1011
parse::derive_parse_with_spec,
@@ -35,6 +36,11 @@ pub fn term(spec: Option<FormalitySpec>, mut input: DeriveInput) -> syn::Result<
3536
let term_impl = derive_term(synstructure::Structure::new(&input));
3637
let downcast_impls = downcast_impls(synstructure::Structure::new(&input));
3738
let upcast_impls = upcast_impls(synstructure::Structure::new(&input));
39+
let constructors = if customize.constructors {
40+
None
41+
} else {
42+
Some(constructor_methods(synstructure::Structure::new(&input)))
43+
};
3844
remove_formality_attributes(&mut input);
3945

4046
Ok(quote! {
@@ -48,6 +54,7 @@ pub fn term(spec: Option<FormalitySpec>, mut input: DeriveInput) -> syn::Result<
4854
#term_impl
4955
#(#downcast_impls)*
5056
#(#upcast_impls)*
57+
#constructors
5158
})
5259
}
5360

crates/formality-prove/src/prove/prove_eq.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use super::{constraints::Constraints, env::Env};
1717

1818
/// Goal(s) to prove `a` and `b` are equal
1919
pub fn eq(a: impl Upcast<Parameter>, b: impl Upcast<Parameter>) -> Relation {
20-
Relation::eq(a, b)
20+
Relation::equals(a, b)
2121
}
2222

2323
judgment_fn! {

crates/formality-types/src/grammar/consts.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub use valtree::*;
77

88
#[term]
99
#[cast]
10+
#[customize(constructors)] // FIXME: figure out upcasts with arc or special-case
1011
pub struct Const {
1112
data: Arc<ConstData>,
1213
}

crates/formality-types/src/grammar/formulas.rs

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -171,18 +171,6 @@ pub enum Relation {
171171
}
172172

173173
impl Relation {
174-
pub fn eq(p1: impl Upcast<Parameter>, p2: impl Upcast<Parameter>) -> Self {
175-
Self::Equals(p1.upcast(), p2.upcast())
176-
}
177-
178-
pub fn outlives(p1: impl Upcast<Parameter>, p2: impl Upcast<Parameter>) -> Self {
179-
Self::Outlives(p1.upcast(), p2.upcast())
180-
}
181-
182-
pub fn sub(p1: impl Upcast<Parameter>, p2: impl Upcast<Parameter>) -> Self {
183-
Self::Sub(p1.upcast(), p2.upcast())
184-
}
185-
186174
#[tracing::instrument(level = "trace", ret)]
187175
pub fn debone(&self) -> (Skeleton, Vec<Parameter>) {
188176
match self {
@@ -200,15 +188,6 @@ pub struct TraitRef {
200188
pub parameters: Parameters,
201189
}
202190

203-
impl TraitRef {
204-
pub fn new(id: &TraitId, parameters: impl Upcast<Vec<Parameter>>) -> Self {
205-
Self {
206-
trait_id: id.clone(),
207-
parameters: parameters.upcast(),
208-
}
209-
}
210-
}
211-
212191
impl TraitId {
213192
pub fn with(
214193
&self,

crates/formality-types/src/grammar/ty.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use super::{
1313

1414
#[term]
1515
#[cast]
16+
#[customize(constructors)] // FIXME: figure out upcasts with arc or special-case
1617
pub struct Ty {
1718
data: Arc<TyData>,
1819
}
@@ -294,6 +295,7 @@ pub enum Variance {
294295

295296
#[term]
296297
#[cast]
298+
#[customize(constructors)] // FIXME: figure out upcasts with arc or special-case
297299
pub struct Lt {
298300
data: Arc<LtData>,
299301
}

crates/formality-types/src/grammar/wc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl Wcs {
2626
assert_eq!(a.len(), b.len());
2727
a.into_iter()
2828
.zip(b)
29-
.map(|(a, b)| Relation::eq(a, b))
29+
.map(|(a, b)| Relation::equals(a, b))
3030
.upcasted()
3131
.collect()
3232
}

0 commit comments

Comments
 (0)