Skip to content

Commit 1f52c05

Browse files
committed
feat: propagate generics to generated function
1 parent 2ef7858 commit 1f52c05

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ mod llvm_enzyme {
7373
}
7474

7575
// Get information about the function the macro is applied to
76-
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
76+
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
7777
match &iitem.kind {
78-
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
79-
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
78+
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79+
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
8080
}
8181
_ => None,
8282
}
@@ -210,17 +210,19 @@ mod llvm_enzyme {
210210
}
211211
let dcx = ecx.sess.dcx();
212212

213-
// first get information about the annotable item:
214-
let Some((vis, sig, primal)) = (match &item {
213+
// first get information about the annotable item: visibility, signature, name and generic
214+
// parameters.
215+
// these will be used to generate the differentiated version of the function
216+
let Some((vis, sig, primal, generics)) = (match &item {
215217
Annotatable::Item(iitem) => extract_item_info(iitem),
216218
Annotatable::Stmt(stmt) => match &stmt.kind {
217219
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
218220
_ => None,
219221
},
220222
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
221223
match &assoc_item.kind {
222-
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
223-
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
224+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
225+
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
224226
}
225227
_ => None,
226228
}
@@ -312,7 +314,7 @@ mod llvm_enzyme {
312314
defaultness: ast::Defaultness::Final,
313315
sig: d_sig,
314316
ident: first_ident(&meta_item_vec[0]),
315-
generics: Generics::default(),
317+
generics,
316318
contract: None,
317319
body: Some(d_body),
318320
define_opaque: None,

0 commit comments

Comments
 (0)