diff --git a/gen/src/cfg.rs b/gen/src/cfg.rs index da589085b..74774a0ed 100644 --- a/gen/src/cfg.rs +++ b/gen/src/cfg.rs @@ -113,6 +113,7 @@ impl Api { match self { Api::Include(include) => &include.cfg, Api::Struct(strct) => &strct.cfg, + Api::TupleStruct(_) => &CfgExpr::Unconditional, Api::Enum(enm) => &enm.cfg, Api::CxxType(ety) | Api::RustType(ety) => &ety.cfg, Api::CxxFunction(efn) | Api::RustFunction(efn) => &efn.cfg, diff --git a/gen/src/namespace.rs b/gen/src/namespace.rs index b79c38f90..269197d18 100644 --- a/gen/src/namespace.rs +++ b/gen/src/namespace.rs @@ -8,7 +8,7 @@ impl Api { Api::CxxType(ety) | Api::RustType(ety) => &ety.name.namespace, Api::Enum(enm) => &enm.name.namespace, Api::Struct(strct) => &strct.name.namespace, - Api::Impl(_) | Api::Include(_) | Api::TypeAlias(_) => Default::default(), + Api::TupleStruct(_) | Api::Impl(_) | Api::Include(_) | Api::TypeAlias(_) => Default::default(), } } } diff --git a/gen/src/write.rs b/gen/src/write.rs index dddec7a55..cfb079649 100644 --- a/gen/src/write.rs +++ b/gen/src/write.rs @@ -3,16 +3,17 @@ use crate::gen::nested::NamespaceEntries; use crate::gen::out::OutFile; use crate::gen::{builtin, include, Opt}; use crate::syntax::atom::Atom::{self, *}; -use crate::syntax::instantiate::{ImplKey, NamedImplKey}; +use crate::syntax::instantiate::{DoubleNamedImplKey, ImplKey, NamedImplKey}; use crate::syntax::map::UnorderedMap as Map; use crate::syntax::set::UnorderedSet; -use crate::syntax::symbol::{self, Symbol}; +use crate::syntax::symbol::{self, Segment, Symbol}; use crate::syntax::trivial::{self, TrivialReason}; use crate::syntax::{ - derive, mangle, Api, Doc, Enum, EnumRepr, ExternFn, ExternType, Pair, Signature, Struct, Trait, - Type, TypeAlias, Types, Var, + derive, mangle, Api, Doc, Enum, EnumRepr, ExternFn, ExternType, ForeignName, Pair, Signature, Struct, Trait, + Type, TypeAlias, Types, }; use proc_macro2::Ident; +use std::iter::once; pub(super) fn gen(apis: &[Api], types: &Types, opt: &Opt, header: bool) -> Vec { let mut out_file = OutFile::new(header, opt, types); @@ -219,6 +220,7 @@ fn pick_includes_and_builtins(out: &mut OutFile, apis: &[Api]) { Type::SharedPtr(_) | Type::WeakPtr(_) => out.include.memory = true, Type::Str(_) => out.builtin.rust_str = true, Type::CxxVector(_) => out.include.vector = true, + Type::CxxFunction(_) => out.include.functional = true, Type::Fn(_) => out.builtin.rust_fn = true, Type::SliceRef(_) => out.builtin.rust_slice = true, Type::Array(_) => out.include.array = true, @@ -695,6 +697,41 @@ fn begin_function_definition(out: &mut OutFile) { } } +fn write_cxx_function_argument<'a>(out: &mut OutFile<'a>, name: &ForeignName, ty: &Type) { + if let Type::RustBox(_) = ty { + write_type(out, ty); + write!(out, "::from_raw({})", name); + } else if let Type::UniquePtr(_) = ty { + write_type(out, ty); + write!(out, "({})", name); + } else if ty == RustString { + out.builtin.unsafe_bitcopy = true; + write!( + out, + "::rust::String(::rust::unsafe_bitcopy, *{})", + name, + ); + } else if let Type::RustVec(_) = ty { + out.builtin.unsafe_bitcopy = true; + write_type(out, ty); + write!(out, "(::rust::unsafe_bitcopy, *{})", name); + } else if out.types.needs_indirect_abi(ty) { + out.include.utility = true; + write!(out, "::std::move(*{})", name); + } else { + write!(out, "{}", name); + } +} + +fn write_cxx_function_parameter<'a>(out: &mut OutFile<'a>, name: &ForeignName, ty: &Type) { + if ty == RustString { + write!(out, "const "); + } else if let Type::RustVec(_) = ty { + write!(out, "const "); + } + write_extern_arg(out, name, ty); +} + fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { out.next_section(); out.set_namespace(&efn.name.namespace); @@ -722,12 +759,7 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { if i > 0 || efn.receiver.is_some() { write!(out, ", "); } - if arg.ty == RustString { - write!(out, "const "); - } else if let Type::RustVec(_) = arg.ty { - write!(out, "const "); - } - write_extern_arg(out, arg); + write_cxx_function_parameter(out, &arg.name.cxx, &arg.ty); } let indirect_return = indirect_return(efn, out.types); if indirect_return { @@ -811,29 +843,7 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { if i > 0 { write!(out, ", "); } - if let Type::RustBox(_) = &arg.ty { - write_type(out, &arg.ty); - write!(out, "::from_raw({})", arg.name.cxx); - } else if let Type::UniquePtr(_) = &arg.ty { - write_type(out, &arg.ty); - write!(out, "({})", arg.name.cxx); - } else if arg.ty == RustString { - out.builtin.unsafe_bitcopy = true; - write!( - out, - "::rust::String(::rust::unsafe_bitcopy, *{})", - arg.name.cxx, - ); - } else if let Type::RustVec(_) = arg.ty { - out.builtin.unsafe_bitcopy = true; - write_type(out, &arg.ty); - write!(out, "(::rust::unsafe_bitcopy, *{})", arg.name.cxx); - } else if out.types.needs_indirect_abi(&arg.ty) { - out.include.utility = true; - write!(out, "::std::move(*{})", arg.name.cxx); - } else { - write!(out, "{}", arg.name.cxx); - } + write_cxx_function_argument(out, &arg.name.cxx, &arg.ty); } write!(out, ")"); match &efn.ret { @@ -920,7 +930,7 @@ fn write_rust_function_decl_impl( if needs_comma { write!(out, ", "); } - write_extern_arg(out, arg); + write_extern_arg(out, &arg.name.cxx, &arg.ty); needs_comma = true; } if indirect_return(sig, out.types) { @@ -1196,29 +1206,40 @@ fn write_extern_return_type_space(out: &mut OutFile, ty: &Option) { } } -fn write_extern_arg(out: &mut OutFile, arg: &Var) { - match &arg.ty { +fn write_extern_arg(out: &mut OutFile, name: &ForeignName, ty: &Type) { + match ty { Type::RustBox(ty) | Type::UniquePtr(ty) | Type::CxxVector(ty) => { write_type_space(out, &ty.inner); write!(out, "*"); } - _ => write_type_space(out, &arg.ty), + _ => write_type_space(out, ty), } - if out.types.needs_indirect_abi(&arg.ty) { + if out.types.needs_indirect_abi(ty) { write!(out, "*"); } - write!(out, "{}", arg.name.cxx); + write!(out, "{}", name); } fn write_type(out: &mut OutFile, ty: &Type) { match ty { Type::Ident(ident) => match Atom::from(&ident.rust) { Some(atom) => write_atom(out, atom), - None => write!( - out, - "{}", - out.types.resolve(ident).name.to_fully_qualified(), - ), + None => { + if let Some(tuple_struct) = out.types.resolve_tuple_struct(ident) { + for (i, ty) in tuple_struct.types.iter().enumerate() { + if i > 0 { + write!(out, ", "); + } + write_type(out, &ty); + } + } else { + write!( + out, + "{}", + out.types.resolve(ident).name.to_fully_qualified(), + ) + } + }, }, Type::RustBox(ty) => { write!(out, "::rust::Box<"); @@ -1250,6 +1271,13 @@ fn write_type(out: &mut OutFile, ty: &Type) { write_type(out, &ty.inner); write!(out, ">"); } + Type::CxxFunction(ty) => { + write!(out, "::std::function<"); + write_type(out, &ty.second); + write!(out, "("); + write_type(out, &ty.first); + write!(out, ")>"); + } Type::Ref(r) => { write_pointee_type(out, &r.inner, r.mutable); write!(out, " &"); @@ -1289,7 +1317,9 @@ fn write_type(out: &mut OutFile, ty: &Type) { write_type(out, &a.inner); write!(out, ", {}>", &a.len); } - Type::Void(_) => unreachable!(), + Type::Void(_) => { + write!(out, "void"); + }, } } @@ -1308,27 +1338,6 @@ fn write_pointee_type(out: &mut OutFile, inner: &Type, mutable: bool) { } } -fn write_atom(out: &mut OutFile, atom: Atom) { - match atom { - Bool => write!(out, "bool"), - Char => write!(out, "char"), - U8 => write!(out, "::std::uint8_t"), - U16 => write!(out, "::std::uint16_t"), - U32 => write!(out, "::std::uint32_t"), - U64 => write!(out, "::std::uint64_t"), - Usize => write!(out, "::std::size_t"), - I8 => write!(out, "::std::int8_t"), - I16 => write!(out, "::std::int16_t"), - I32 => write!(out, "::std::int32_t"), - I64 => write!(out, "::std::int64_t"), - Isize => write!(out, "::rust::isize"), - F32 => write!(out, "float"), - F64 => write!(out, "double"), - CxxString => write!(out, "::std::string"), - RustString => write!(out, "::rust::String"), - } -} - fn write_type_space(out: &mut OutFile, ty: &Type) { write_type(out, ty); write_space_after_type(out, ty); @@ -1343,6 +1352,7 @@ fn write_space_after_type(out: &mut OutFile, ty: &Type) { | Type::WeakPtr(_) | Type::Str(_) | Type::CxxVector(_) + | Type::CxxFunction(_) | Type::RustVec(_) | Type::SliceRef(_) | Type::Fn(_) @@ -1356,6 +1366,7 @@ fn write_space_after_type(out: &mut OutFile, ty: &Type) { enum UniquePtr<'a> { Ident(&'a Ident), CxxVector(&'a Ident), + CxxFunction(&'a str, &'a Symbol), } trait ToTypename { @@ -1364,7 +1375,54 @@ trait ToTypename { impl ToTypename for Ident { fn to_typename(&self, types: &Types) -> String { - types.resolve(self).name.to_fully_qualified() + if let Some(atom) = Atom::from(&self) { + atom.to_typename(types) + } else { + types.resolve(self).name.to_fully_qualified() + } + } +} + +impl ToTypename for Option<&Ident> { + fn to_typename(&self, types: &Types) -> String { + if let Some(some) = self { + some.to_typename(types) + } else { + "void".to_owned() + } + } +} + +impl ToTypename for Atom { + fn to_typename(&self, _: &Types) -> String { + match &self { + Bool => "bool", + Char => "char", + U8 => "::std::uint8_t", + U16 => "::std::uint16_t", + U32 => "::std::uint32_t", + U64 => "::std::uint64_t", + Usize => "::std::size_t", + I8 => "::std::int8_t", + I16 => "::std::int16_t", + I32 => "::std::int32_t", + I64 => "::std::int64_t", + Isize => "::rust::isize", + F32 => "float", + F64 => "double", + CxxString => "::std::string", + RustString => "::rust::String", + }.to_owned() + } +} + +impl ToTypename for Type { + fn to_typename(&self, types: &Types) -> String { + // TODO: invert logic, so that `write_type()` delegates to `to_typename()` + let opt = Opt::default(); + let mut out = OutFile::new(false, &opt, types); + write_type(&mut out, self); + String::from_utf8(out.content()).unwrap() } } @@ -1375,6 +1433,9 @@ impl<'a> ToTypename for UniquePtr<'a> { UniquePtr::CxxVector(element) => { format!("::std::vector<{}>", element.to_typename(types)) } + UniquePtr::CxxFunction(typename, _) => { + (*typename).to_owned() + }, } } } @@ -1385,7 +1446,11 @@ trait ToMangled { impl ToMangled for Ident { fn to_mangled(&self, types: &Types) -> Symbol { - types.resolve(self).name.to_symbol() + if let Some(_) = Atom::from(&self) { + Symbol::from_idents(once(&self as &dyn Segment)) + } else { + types.resolve(self).name.to_symbol() + } } } @@ -1396,8 +1461,30 @@ impl<'a> ToMangled for UniquePtr<'a> { UniquePtr::CxxVector(element) => { symbol::join(&[&"std", &"vector", &element.to_mangled(types)]) } + UniquePtr::CxxFunction(_, mangled) => { + symbol::join(&[&"std", &"function", *mangled]) + } + } + } +} + +fn std_function_typename(args_types: &[&Type], ret: &str, types: &Types) -> String { + let mut typename = String::new(); + typename.push_str("std::function<"); + typename.push_str(ret); + typename.push_str("("); + for (index, ty) in args_types.iter().enumerate() { + if index != 0 { + typename.push_str( ", "); } + typename.push_str(&ty.to_typename(types)[..]); } + typename.push_str( ")>"); + typename +} + +fn write_atom(out: &mut OutFile, atom: Atom) { + write!(out, "{}", atom.to_typename(out.types)); } fn write_generic_instantiations(out: &mut OutFile) { @@ -1417,6 +1504,7 @@ fn write_generic_instantiations(out: &mut OutFile) { ImplKey::SharedPtr(ident) => write_shared_ptr(out, ident), ImplKey::WeakPtr(ident) => write_weak_ptr(out, ident), ImplKey::CxxVector(ident) => write_cxx_vector(out, ident), + ImplKey::CxxFunction(ident) => write_cxx_std_function(out, ident), } } out.end_block(Block::ExternC); @@ -1644,6 +1732,7 @@ fn write_unique_ptr_common(out: &mut OutFile, ty: UniquePtr) { // for Opaque types because the 'new' method is not implemented. UniquePtr::Ident(ident) => out.types.is_maybe_trivial(ident), UniquePtr::CxxVector(_) => false, + UniquePtr::CxxFunction(_, _) => false, }; let conditional_delete = match ty { @@ -1651,6 +1740,7 @@ fn write_unique_ptr_common(out: &mut OutFile, ty: UniquePtr) { !out.types.structs.contains_key(ident) && !out.types.enums.contains_key(ident) } UniquePtr::CxxVector(_) => false, + UniquePtr::CxxFunction(_, _) => false, }; if conditional_delete { @@ -1658,6 +1748,7 @@ fn write_unique_ptr_common(out: &mut OutFile, ty: UniquePtr) { let definition = match ty { UniquePtr::Ident(ty) => &out.types.resolve(ty).name.cxx, UniquePtr::CxxVector(_) => unreachable!(), + UniquePtr::CxxFunction(_, _) => unreachable!(), }; writeln!( out, @@ -1933,3 +2024,45 @@ fn write_cxx_vector(out: &mut OutFile, key: NamedImplKey) { out.include.memory = true; write_unique_ptr_common(out, UniquePtr::CxxVector(element)); } + +fn write_cxx_std_function(out: &mut OutFile, key: DoubleNamedImplKey) { + let ret_typename = key.id2.to_typename(out.types); + let args_mangled = key.id1.to_mangled(out.types); + let ret_mangled = mangle::mangle_ident(key.id2, out.types); + let func_mangled = symbol::join(&[&args_mangled, &ret_mangled]); + + let args_types = if let Some(tuple) = out.types.resolve_tuple_struct(key.id1) { + tuple.types.iter().collect::>() + } else if let (Some(ty), _) = out.types.resolve_cxx_arg_type(&key) { + vec![ty] + } else { + panic!("No eligible arg types."); + }; + + let func_typename = std_function_typename(&args_types[..], &ret_typename[..], out.types); + + write!( + out, + "{} cxxbridge1$std$function$call${}({}* f", + ret_typename, func_mangled, func_typename, + ); + for (index, ty) in args_types.iter().enumerate() { + write!(out, ", "); + let name = ForeignName::parse(&format!("v{}", index), key.begin_span).unwrap(); + write_cxx_function_parameter(out, &name, ty); + } + + writeln!(out, ") {{"); + write!(out, " return (*f)("); + for (index, ty) in args_types.iter().enumerate() { + if index != 0 { + write!(out, ", "); + } + let name = ForeignName::parse(&format!("v{}", index), key.begin_span).unwrap(); + write_cxx_function_argument(out, &name, ty); + } + writeln!(out, ");"); + writeln!(out, "}}"); + + write_unique_ptr_common(out, UniquePtr::CxxFunction(&func_typename[..], &func_mangled)); +} diff --git a/macro/src/expand.rs b/macro/src/expand.rs index 8f5836a6b..551924ac1 100644 --- a/macro/src/expand.rs +++ b/macro/src/expand.rs @@ -2,20 +2,21 @@ use crate::syntax::atom::Atom::*; use crate::syntax::attrs::{self, OtherAttrs}; use crate::syntax::cfg::CfgExpr; use crate::syntax::file::Module; -use crate::syntax::instantiate::{ImplKey, NamedImplKey}; +use crate::syntax::instantiate::{DoubleNamedImplKey, ImplKey, NamedImplKey}; use crate::syntax::qualified::QualifiedName; use crate::syntax::report::Errors; use crate::syntax::symbol::Symbol; use crate::syntax::{ self, check, mangle, Api, Doc, Enum, ExternFn, ExternType, Impl, Lifetimes, Pair, Signature, - Struct, Trait, Type, TypeAlias, Types, + Struct, Trait, Type, TypeAlias, Types, TupleStruct }; use crate::type_id::Crate; use crate::{derive, generics}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned, ToTokens}; use std::mem; -use syn::{parse_quote, punctuated, Generics, Lifetime, Result, Token}; +use syn::{parse_quote, punctuated, Generics, Index, Lifetime, Result, Token}; +use syn::punctuated::Punctuated; pub fn bridge(mut ffi: Module) -> Result { let ref mut errors = Errors::new(); @@ -68,6 +69,9 @@ fn expand(ffi: Module, doc: Doc, attrs: OtherAttrs, apis: &[Api], types: &Types) hidden.extend(expand_struct_operators(strct)); forbid.extend(expand_struct_forbid_drop(strct)); } + Api::TupleStruct(tstrct) => { + expanded.extend(expand_tuple_struct(tstrct)); + } Api::Enum(enm) => expanded.extend(expand_enum(enm)), Api::CxxType(ety) => { let ident = &ety.name.rust; @@ -111,6 +115,9 @@ fn expand(ffi: Module, doc: Doc, attrs: OtherAttrs, apis: &[Api], types: &Types) ImplKey::CxxVector(ident) => { expanded.extend(expand_cxx_vector(ident, explicit_impl, types)); } + ImplKey::CxxFunction(ident) => { + expanded.extend(expand_cxx_std_function(ident, types)); + } } } @@ -188,6 +195,24 @@ fn expand_struct(strct: &Struct) -> TokenStream { } } +fn expand_tuple_struct(strct: &TupleStruct) -> TokenStream { + let ident = &strct.name.rust; + let fields = strct.types.iter().map(|ty| { + quote!(pub #ty) + }); + let generics = &strct.generics; + let span = ident.span(); + let struct_def = quote_spanned! {span=> + pub struct #ident #generics ( + #(#fields,)* + ) + }; + + quote! { + #struct_def; + } +} + fn expand_struct_operators(strct: &Struct) -> TokenStream { let ident = &strct.name.rust; let generics = &strct.generics; @@ -439,6 +464,21 @@ fn expand_cxx_type_assert_pinned(ety: &ExternType, types: &Types) -> TokenStream } } +fn expand_cxx_function_parameter(var: &Ident, colon: Token![:], ty: &Type, types: &Types) -> TokenStream { + let ext_ty = expand_extern_type(ty, types, true); + if ty == RustString { + quote!(#var #colon *const #ext_ty) + } else if let Type::RustVec(_) = ty { + quote!(#var #colon *const #ext_ty) + } else if let Type::Fn(_) = ty { + quote!(#var #colon ::cxx::private::FatFunction) + } else if types.needs_indirect_abi(&ty) { + quote!(#var #colon *mut #ext_ty) + } else { + quote!(#var #colon #ext_ty) + } +} + fn expand_cxx_function_decl(efn: &ExternFn, types: &Types) -> TokenStream { let generics = &efn.generics; let receiver = efn.receiver.iter().map(|receiver| { @@ -448,18 +488,7 @@ fn expand_cxx_function_decl(efn: &ExternFn, types: &Types) -> TokenStream { let args = efn.args.iter().map(|arg| { let var = &arg.name.rust; let colon = arg.colon_token; - let ty = expand_extern_type(&arg.ty, types, true); - if arg.ty == RustString { - quote!(#var #colon *const #ty) - } else if let Type::RustVec(_) = arg.ty { - quote!(#var #colon *const #ty) - } else if let Type::Fn(_) = arg.ty { - quote!(#var #colon ::cxx::private::FatFunction) - } else if types.needs_indirect_abi(&arg.ty) { - quote!(#var #colon *mut #ty) - } else { - quote!(#var #colon #ty) - } + expand_cxx_function_parameter(var, colon, &arg.ty, types) }); let all_args = receiver.chain(args); let ret = if efn.throws { @@ -480,6 +509,70 @@ fn expand_cxx_function_decl(efn: &ExternFn, types: &Types) -> TokenStream { } } +fn expand_cxx_function_argument(var: &TokenStream, span: Span, ty: &Type, types: &Types) -> TokenStream { + match ty { + Type::Ident(ident) if ident.rust == RustString => { + quote_spanned!(span=> #var.as_mut_ptr() as *const ::cxx::private::RustString) + } + Type::RustBox(ty) => { + if types.is_considered_improper_ctype(&ty.inner) { + quote_spanned!(span=> ::cxx::alloc::boxed::Box::into_raw(#var).cast()) + } else { + quote_spanned!(span=> ::cxx::alloc::boxed::Box::into_raw(#var)) + } + } + Type::UniquePtr(ty) => { + if types.is_considered_improper_ctype(&ty.inner) { + quote_spanned!(span=> ::cxx::UniquePtr::into_raw(#var).cast()) + } else { + quote_spanned!(span=> ::cxx::UniquePtr::into_raw(#var)) + } + } + Type::RustVec(_) => quote_spanned!(span=> #var.as_mut_ptr() as *const ::cxx::private::RustVec<_>), + Type::Ref(ty) => match &ty.inner { + Type::Ident(ident) if ident.rust == RustString => match ty.mutable { + false => quote_spanned!(span=> ::cxx::private::RustString::from_ref(#var)), + true => quote_spanned!(span=> ::cxx::private::RustString::from_mut(#var)), + }, + Type::RustVec(vec) if vec.inner == RustString => match ty.mutable { + false => quote_spanned!(span=> ::cxx::private::RustVec::from_ref_vec_string(#var)), + true => quote_spanned!(span=> ::cxx::private::RustVec::from_mut_vec_string(#var)), + }, + Type::RustVec(_) => match ty.mutable { + false => quote_spanned!(span=> ::cxx::private::RustVec::from_ref(#var)), + true => quote_spanned!(span=> ::cxx::private::RustVec::from_mut(#var)), + }, + inner if types.is_considered_improper_ctype(inner) => { + let var = match ty.pinned { + false => quote!(#var), + true => quote_spanned!(span=> ::cxx::core::pin::Pin::into_inner_unchecked(#var)), + }; + match ty.mutable { + false => { + quote_spanned!(span=> #var as *const #inner as *const ::cxx::core::ffi::c_void) + } + true => quote_spanned!(span=> #var as *mut #inner as *mut ::cxx::core::ffi::c_void), + } + } + _ => quote!(#var), + }, + Type::Ptr(ty) => { + if types.is_considered_improper_ctype(&ty.inner) { + quote_spanned!(span=> #var.cast()) + } else { + quote!(#var) + } + } + Type::Str(_) => quote_spanned!(span=> ::cxx::private::RustStr::from(#var)), + Type::SliceRef(ty) => match ty.mutable { + false => quote_spanned!(span=> ::cxx::private::RustSlice::from_ref(#var)), + true => quote_spanned!(span=> ::cxx::private::RustSlice::from_mut(#var)), + }, + ty if types.needs_indirect_abi(ty) => quote_spanned!(span=> #var.as_mut_ptr()), + _ => quote!(#var), + } +} + fn expand_cxx_function_shim(efn: &ExternFn, types: &Types) -> TokenStream { let doc = &efn.doc; let attrs = &efn.attrs; @@ -514,69 +607,8 @@ fn expand_cxx_function_shim(efn: &ExternFn, types: &Types) -> TokenStream { .iter() .map(|receiver| receiver.var.to_token_stream()); let arg_vars = efn.args.iter().map(|arg| { - let var = &arg.name.rust; - let span = var.span(); - match &arg.ty { - Type::Ident(ident) if ident.rust == RustString => { - quote_spanned!(span=> #var.as_mut_ptr() as *const ::cxx::private::RustString) - } - Type::RustBox(ty) => { - if types.is_considered_improper_ctype(&ty.inner) { - quote_spanned!(span=> ::cxx::alloc::boxed::Box::into_raw(#var).cast()) - } else { - quote_spanned!(span=> ::cxx::alloc::boxed::Box::into_raw(#var)) - } - } - Type::UniquePtr(ty) => { - if types.is_considered_improper_ctype(&ty.inner) { - quote_spanned!(span=> ::cxx::UniquePtr::into_raw(#var).cast()) - } else { - quote_spanned!(span=> ::cxx::UniquePtr::into_raw(#var)) - } - } - Type::RustVec(_) => quote_spanned!(span=> #var.as_mut_ptr() as *const ::cxx::private::RustVec<_>), - Type::Ref(ty) => match &ty.inner { - Type::Ident(ident) if ident.rust == RustString => match ty.mutable { - false => quote_spanned!(span=> ::cxx::private::RustString::from_ref(#var)), - true => quote_spanned!(span=> ::cxx::private::RustString::from_mut(#var)), - }, - Type::RustVec(vec) if vec.inner == RustString => match ty.mutable { - false => quote_spanned!(span=> ::cxx::private::RustVec::from_ref_vec_string(#var)), - true => quote_spanned!(span=> ::cxx::private::RustVec::from_mut_vec_string(#var)), - }, - Type::RustVec(_) => match ty.mutable { - false => quote_spanned!(span=> ::cxx::private::RustVec::from_ref(#var)), - true => quote_spanned!(span=> ::cxx::private::RustVec::from_mut(#var)), - }, - inner if types.is_considered_improper_ctype(inner) => { - let var = match ty.pinned { - false => quote!(#var), - true => quote_spanned!(span=> ::cxx::core::pin::Pin::into_inner_unchecked(#var)), - }; - match ty.mutable { - false => { - quote_spanned!(span=> #var as *const #inner as *const ::cxx::core::ffi::c_void) - } - true => quote_spanned!(span=> #var as *mut #inner as *mut ::cxx::core::ffi::c_void), - } - } - _ => quote!(#var), - }, - Type::Ptr(ty) => { - if types.is_considered_improper_ctype(&ty.inner) { - quote_spanned!(span=> #var.cast()) - } else { - quote!(#var) - } - } - Type::Str(_) => quote_spanned!(span=> ::cxx::private::RustStr::from(#var)), - Type::SliceRef(ty) => match ty.mutable { - false => quote_spanned!(span=> ::cxx::private::RustSlice::from_ref(#var)), - true => quote_spanned!(span=> ::cxx::private::RustSlice::from_mut(#var)), - }, - ty if types.needs_indirect_abi(ty) => quote_spanned!(span=> #var.as_mut_ptr()), - _ => quote!(#var), - } + let span = arg.name.rust.span(); + expand_cxx_function_argument(&arg.name.rust.to_token_stream(), span, &arg.ty, types) }); let vars = receiver_var.chain(arg_vars); let trampolines = efn @@ -1724,6 +1756,133 @@ fn expand_cxx_vector( } } +fn expand_cxx_std_function( + key: DoubleNamedImplKey, + types: &Types, +) -> TokenStream { + let resolve = types.resolve(key.id1); + let rettype = if let Some(ident) = key.id2 { + quote! { #ident } + } else { + quote! { () } + }; + + let (name, arg_types, is_tuple_arg, impl_generics, args_generics, ref_lifetime) = if let Some(tuple) = types.resolve_tuple_struct(key.id1) { + (&tuple.name.rust, tuple.types.iter().collect::>(), true, resolve.generics.clone(), Some(resolve.generics.clone()), None) + } else if let (Some(ty), ref_lifetime) = types.resolve_cxx_arg_type(&key) { + let ref_lifetime = if ref_lifetime.is_none() && key.id1_ampersand.is_some() { + Some(Lifetime::new("'a", key.begin_span)) + } else { + ref_lifetime.cloned() + }; + + let mut lifetimes = Punctuated::new(); + if let Some(lifetime) = &ref_lifetime { + lifetimes.push_value(lifetime.clone()); + } + let generics = Lifetimes { + lt_token: Some(Token![<](key.begin_span)), + lifetimes, + gt_token: Some(Token![>](key.begin_span)), + }; + (key.id1, vec![ty], false, generics, None, ref_lifetime) + } else { + panic!("No eligible arg types."); + }; + let begin_span = key.begin_span; + let unsafe_token = format_ident!("unsafe", span = begin_span); + + let func_params = arg_types.iter().map(|ty| { + let name = Ident::new("_", begin_span); + expand_cxx_function_parameter(&name, Token![:](begin_span), ty, types) + }); + + let func_args = arg_types.iter().enumerate().map(|(index, ty)| { + let name = Ident::new("_a", begin_span); + let index = Index::from(index); + let var = if is_tuple_arg { + quote!(#name.#index) + } else { + quote!(#name) + }; + let indexing = if types.needs_indirect_abi(ty) { + quote!(::cxx::core::mem::MaybeUninit::new(#var)) + } else { + var + }; + expand_cxx_function_argument(&indexing, begin_span, ty, types) + }); + + let link_name = format!( + "cxxbridge1$std$function$call${}${}", + resolve.name.to_symbol(), + mangle::mangle_ident(key.id2, types), + ); + + let unique_ptr_prefix = format!("cxxbridge1$unique_ptr$std$function${}${}", resolve.name.to_symbol(), mangle::mangle_ident(key.id2, types)); + let link_unique_ptr_null = format!("{}$null", unique_ptr_prefix); + let link_unique_ptr_raw = format!("{}$raw", unique_ptr_prefix); + let link_unique_ptr_get = format!("{}$get", unique_ptr_prefix); + let link_unique_ptr_release = format!("{}$release", unique_ptr_prefix); + let link_unique_ptr_drop = format!("{}$drop", unique_ptr_prefix); + let typename = name.to_string(); + let ampersand = key.id1_ampersand; + + quote! { + #unsafe_token impl #impl_generics ::cxx::private::CxxFunctionArguments<#rettype> for #ampersand #ref_lifetime #name #args_generics { + unsafe fn __call(f: &::cxx::CxxFunction<#ampersand #ref_lifetime #name #args_generics, #rettype>, _a: #ampersand #ref_lifetime #name #args_generics) -> #rettype { + extern "C" { + #[link_name = #link_name] + fn __call #impl_generics(_: &::cxx::CxxFunction<#ampersand #ref_lifetime #name #args_generics, #rettype>, #(#func_params,)*) -> #rettype; + } + __call(f, #(#func_args,)*) + } + fn __typename(f: &mut ::cxx::core::fmt::Formatter<'_>) -> ::cxx::core::fmt::Result { + f.write_str(#typename) + } + fn __unique_ptr_null() -> ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void> { + extern "C" { + #[link_name = #link_unique_ptr_null] + fn __unique_ptr_null(this: *mut ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>); + } + let mut repr = ::cxx::core::mem::MaybeUninit::uninit(); + unsafe { __unique_ptr_null(&mut repr) } + repr + } + unsafe fn __unique_ptr_raw(raw: *mut ::cxx::CxxFunction) -> ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void> { + extern "C" { + #[link_name = #link_unique_ptr_raw] + fn __unique_ptr_raw #impl_generics(this: *mut ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>, raw: *mut ::cxx::CxxFunction<#ampersand #ref_lifetime #name #args_generics, #rettype>); + } + let mut repr = ::cxx::core::mem::MaybeUninit::uninit(); + __unique_ptr_raw(&mut repr, raw); + repr + } + unsafe fn __unique_ptr_get(repr: ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>) -> *const ::cxx::CxxFunction { + extern "C" { + #[link_name = #link_unique_ptr_get] + fn __unique_ptr_get #impl_generics(this: *const ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>) -> *const ::cxx::CxxFunction<#ampersand #ref_lifetime #name #args_generics, #rettype>; + } + __unique_ptr_get(&repr) + } + unsafe fn __unique_ptr_release(mut repr: ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>) -> *mut ::cxx::CxxFunction { + extern "C" { + #[link_name = #link_unique_ptr_release] + fn __unique_ptr_release #impl_generics(this: *mut ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>) -> *mut ::cxx::CxxFunction<#ampersand #ref_lifetime #name #args_generics, #rettype>; + } + __unique_ptr_release(&mut repr) + } + unsafe fn __unique_ptr_drop(mut repr: ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>) { + extern "C" { + #[link_name = #link_unique_ptr_drop] + fn __unique_ptr_drop(this: *mut ::cxx::core::mem::MaybeUninit<*mut ::cxx::core::ffi::c_void>); + } + __unique_ptr_drop(&mut repr); + } + } + } +} + fn expand_return_type(ret: &Option) -> TokenStream { match ret { Some(ret) => quote!(-> #ret), diff --git a/src/cxx_function.rs b/src/cxx_function.rs new file mode 100644 index 000000000..cf2cd2821 --- /dev/null +++ b/src/cxx_function.rs @@ -0,0 +1,73 @@ +use core::marker::{PhantomData}; +use std::mem::MaybeUninit; +use core::ffi::c_void; +use crate::std::fmt; + +/// Binding to C++ `std::function` or `UniquePtr>`. +#[repr(C, packed)] +pub struct CxxFunction { + // A thing, because repr(C) structs are not allowed to consist exclusively + // of PhantomData fields. + _void: [c_void; 0], + _func: PhantomData A>, +} + +impl, R> CxxFunction { + + /// Sends the callback and the arguments to C++ land, and calls it there + pub fn call(&self, arguments: A) -> R { + unsafe { A::__call(self, arguments) } + } +} + +/// Trait bound for types which may be used as the `A` inside of a +/// `CxxFunction` in generic code. +/// +/// This trait has no publicly callable or implementable methods. Implementing +/// it outside of the CXX codebase is not supported. +/// +/// # Example +/// +/// A bound `T: CxxFunctionArguments` may be necessary when manipulating +/// [`CxxFunction`] in generic code. +/// +/// ``` +/// use cxx::CxxFunction; +/// use cxx::private::CxxFunctionArguments; +/// use std::fmt::Display; +/// +/// pub fn take_generic_function(ptr: &CxxFunction, arguments: A) +/// where +/// A: CxxFunctionArguments, +/// R: Display, +/// { +/// let result = ptr.call(arguments); +/// println!("the callback returned: {}", result); +/// } +/// ``` +/// +/// Writing the same generic function without a `CxxFunctionArguments` trait bound +/// would not compile. +pub unsafe trait CxxFunctionArguments: Sized { + #[doc(hidden)] + unsafe fn __call(f: &CxxFunction, arg: Self) -> R; + #[doc(hidden)] + fn __typename(f: &mut fmt::Formatter) -> fmt::Result; + #[doc(hidden)] + fn __unique_ptr_null() -> MaybeUninit<*mut c_void>; + #[doc(hidden)] + unsafe fn __unique_ptr_raw(raw: *mut CxxFunction) -> MaybeUninit<*mut c_void>; + #[doc(hidden)] + unsafe fn __unique_ptr_get(repr: MaybeUninit<*mut c_void>) -> *const CxxFunction; + #[doc(hidden)] + unsafe fn __unique_ptr_release(repr: MaybeUninit<*mut c_void>) -> *mut CxxFunction; + #[doc(hidden)] + unsafe fn __unique_ptr_drop(repr: MaybeUninit<*mut c_void>); +} diff --git a/src/lib.rs b/src/lib.rs index 8f425a145..0884ceb67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -437,6 +437,7 @@ compile_error! { mod macros; mod c_char; +mod cxx_function; mod cxx_vector; mod exception; mod extern_type; @@ -464,6 +465,7 @@ pub mod vector; mod weak_ptr; pub use crate::cxx_vector::CxxVector; +pub use crate::cxx_function::CxxFunction; #[cfg(feature = "alloc")] pub use crate::exception::Exception; pub use crate::extern_type::{kind, ExternType}; @@ -491,6 +493,7 @@ pub type Vector = CxxVector; #[doc(hidden)] pub mod private { pub use crate::cxx_vector::VectorElement; + pub use crate::cxx_function::CxxFunctionArguments; pub use crate::extern_type::{verify_extern_kind, verify_extern_type}; pub use crate::function::FatFunction; pub use crate::hash::hash; diff --git a/src/unique_ptr.rs b/src/unique_ptr.rs index 33992059e..d20069b7f 100644 --- a/src/unique_ptr.rs +++ b/src/unique_ptr.rs @@ -2,13 +2,14 @@ use crate::cxx_vector::{CxxVector, VectorElement}; use crate::fmt::display; use crate::kind::Trivial; use crate::string::CxxString; -use crate::ExternType; +use crate::{CxxFunction, ExternType}; use core::ffi::c_void; use core::fmt::{self, Debug, Display}; use core::marker::PhantomData; use core::mem::{self, MaybeUninit}; use core::ops::{Deref, DerefMut}; use core::pin::Pin; +use cxx::private::CxxFunctionArguments; /// Binding to C++ `std::unique_ptr>`. #[repr(C)] @@ -294,3 +295,27 @@ where unsafe { T::__unique_ptr_drop(repr) } } } + +unsafe impl UniquePtrTarget for CxxFunction + where + T: CxxFunctionArguments, +{ + fn __typename(f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CxxFunction<{}>", display(T::__typename)) + } + fn __null() -> MaybeUninit<*mut c_void> { + T::__unique_ptr_null() + } + unsafe fn __raw(raw: *mut Self) -> MaybeUninit<*mut c_void> { + unsafe { T::__unique_ptr_raw(raw) } + } + unsafe fn __get(repr: MaybeUninit<*mut c_void>) -> *const Self { + unsafe { T::__unique_ptr_get(repr) } + } + unsafe fn __release(repr: MaybeUninit<*mut c_void>) -> *mut Self { + unsafe { T::__unique_ptr_release(repr) } + } + unsafe fn __drop(repr: MaybeUninit<*mut c_void>) { + unsafe { T::__unique_ptr_drop(repr) } + } +} diff --git a/syntax/check.rs b/syntax/check.rs index 698782f96..109e1e8a3 100644 --- a/syntax/check.rs +++ b/syntax/check.rs @@ -1,10 +1,7 @@ use crate::syntax::atom::Atom::{self, *}; use crate::syntax::report::Errors; use crate::syntax::visit::{self, Visit}; -use crate::syntax::{ - error, ident, trivial, Api, Array, Enum, ExternFn, ExternType, Impl, Lang, Lifetimes, - NamedType, Ptr, Receiver, Ref, Signature, SliceRef, Struct, Trait, Ty1, Type, TypeAlias, Types, -}; +use crate::syntax::{error, ident, trivial, Api, Array, Enum, ExternFn, ExternType, Impl, Lang, Lifetimes, NamedType, Ptr, Receiver, Ref, Signature, SliceRef, Struct, Trait, Ty1, Ty2, Type, TypeAlias, Types, TupleStruct}; use proc_macro2::{Delimiter, Group, Ident, TokenStream}; use quote::{quote, ToTokens}; use std::fmt::Display; @@ -51,6 +48,7 @@ fn do_typecheck(cx: &mut Check) { Type::SharedPtr(ptr) => check_type_shared_ptr(cx, ptr), Type::WeakPtr(ptr) => check_type_weak_ptr(cx, ptr), Type::CxxVector(ptr) => check_type_cxx_vector(cx, ptr), + Type::CxxFunction(ptr) => check_type_cxx_func(cx, ptr), Type::Ref(ty) => check_type_ref(cx, ty), Type::Ptr(ty) => check_type_ptr(cx, ty), Type::Array(array) => check_type_array(cx, array), @@ -64,6 +62,7 @@ fn do_typecheck(cx: &mut Check) { match api { Api::Include(_) => {} Api::Struct(strct) => check_api_struct(cx, strct), + Api::TupleStruct(tstrct) => check_api_tuple_struct(cx, tstrct), Api::Enum(enm) => check_api_enum(cx, enm), Api::CxxType(ety) | Api::RustType(ety) => check_api_type(cx, ety), Api::CxxFunction(efn) | Api::RustFunction(efn) => check_api_fn(cx, efn), @@ -83,6 +82,7 @@ fn check_type_ident(cx: &mut Check, name: &NamedType) { let ident = &name.rust; if Atom::from(ident).is_none() && !cx.types.structs.contains_key(ident) + && !cx.types.tuple_structs.contains_key(ident) && !cx.types.enums.contains_key(ident) && !cx.types.cxx.contains(ident) && !cx.types.rust.contains(ident) @@ -150,6 +150,8 @@ fn check_type_unique_ptr(cx: &mut Check, ptr: &Ty1) { } } else if let Type::CxxVector(_) = &ptr.inner { return; + } else if let Type::CxxFunction(_) = &ptr.inner { + return; } cx.error(ptr, "unsupported unique_ptr target type"); @@ -219,6 +221,10 @@ fn check_type_cxx_vector(cx: &mut Check, ptr: &Ty1) { cx.error(ptr, "unsupported vector element type"); } +fn check_type_cxx_func(_: &mut Check, _: &Ty2) { + // TODO: write type checks +} + fn check_type_ref(cx: &mut Check, ty: &Ref) { if ty.mutable && !ty.pinned { if let Some(requires_pin) = match &ty.inner { @@ -347,6 +353,10 @@ fn check_api_struct(cx: &mut Check, strct: &Struct) { } } +fn check_api_tuple_struct(_: &mut Check, _: &TupleStruct) { + // TODO: implement chech logic for TupleStruct +} + fn check_api_enum(cx: &mut Check, enm: &Enum) { check_reserved_name(cx, &enm.name.rust); check_lifetimes(cx, &enm.generics); @@ -527,6 +537,15 @@ fn check_api_impl(cx: &mut Check, imp: &Impl) { } } } + Type::CxxFunction(ty) => { + if let Type::Ident(first) = &ty.first { + if let Type::Ident(second) = &ty.second { + if Atom::from(&first.rust).is_none() && Atom::from(&second.rust).is_none() { + return; + } + } + } + } _ => {} } @@ -603,6 +622,7 @@ fn check_reserved_name(cx: &mut Check, ident: &Ident) { || ident == "WeakPtr" || ident == "Vec" || ident == "CxxVector" + || ident == "CxxFunction" || ident == "str" || Atom::from(ident).is_some() { @@ -648,6 +668,7 @@ fn is_unsized(cx: &mut Check, ty: &Type) -> bool { | Type::UniquePtr(_) | Type::SharedPtr(_) | Type::WeakPtr(_) + | Type::CxxFunction(_) | Type::Ref(_) | Type::Ptr(_) | Type::Str(_) @@ -726,6 +747,7 @@ fn describe(cx: &mut Check, ty: &Type) -> String { Type::Ptr(_) => "raw pointer".to_owned(), Type::Str(_) => "&str".to_owned(), Type::CxxVector(_) => "C++ vector".to_owned(), + Type::CxxFunction(_) => "C++ function".to_owned(), Type::SliceRef(_) => "slice".to_owned(), Type::Fn(_) => "function pointer".to_owned(), Type::Void(_) => "()".to_owned(), diff --git a/syntax/ident.rs b/syntax/ident.rs index bb2281e72..e45d17081 100644 --- a/syntax/ident.rs +++ b/syntax/ident.rs @@ -33,7 +33,10 @@ pub(crate) fn check_all(cx: &mut Check, apis: &[Api]) { for field in &strct.fields { check(cx, &field.name); } - } + }, + Api::TupleStruct(tstrct) => { + check(cx, &tstrct.name); + }, Api::Enum(enm) => { check(cx, &enm.name); for variant in &enm.variants { diff --git a/syntax/impls.rs b/syntax/impls.rs index 36e1f322a..53b1699a1 100644 --- a/syntax/impls.rs +++ b/syntax/impls.rs @@ -1,5 +1,5 @@ use crate::syntax::{ - Array, ExternFn, Include, Lifetimes, Ptr, Receiver, Ref, Signature, SliceRef, Ty1, Type, Var, + Array, ExternFn, Include, Lifetimes, Ptr, Receiver, Ref, Signature, TupleStruct, SliceRef, Ty1, Ty2, Type, Var, }; use std::hash::{Hash, Hasher}; use std::mem; @@ -53,6 +53,7 @@ impl Hash for Type { Type::Str(t) => t.hash(state), Type::RustVec(t) => t.hash(state), Type::CxxVector(t) => t.hash(state), + Type::CxxFunction(t) => t.hash(state), Type::Fn(t) => t.hash(state), Type::SliceRef(t) => t.hash(state), Type::Array(t) => t.hash(state), @@ -75,6 +76,7 @@ impl PartialEq for Type { (Type::Str(lhs), Type::Str(rhs)) => lhs == rhs, (Type::RustVec(lhs), Type::RustVec(rhs)) => lhs == rhs, (Type::CxxVector(lhs), Type::CxxVector(rhs)) => lhs == rhs, + (Type::CxxFunction(lhs), Type::CxxFunction(rhs)) => lhs == rhs, (Type::Fn(lhs), Type::Fn(rhs)) => lhs == rhs, (Type::SliceRef(lhs), Type::SliceRef(rhs)) => lhs == rhs, (Type::Void(_), Type::Void(_)) => true, @@ -148,6 +150,46 @@ impl Hash for Ty1 { } } +impl Eq for Ty2 {} + +impl PartialEq for Ty2 { + fn eq(&self, other: &Self) -> bool { + let Ty2 { + name, + langle: _, + first, + comma: _, + second, + rangle: _, + } = self; + let Ty2 { + name: name2, + langle: _, + first: first2, + comma: _, + second: second2, + rangle: _, + } = other; + name == name2 && first == first2 && second == second2 + } +} + +impl Hash for Ty2 { + fn hash(&self, state: &mut H) { + let Ty2 { + name, + langle: _, + first, + comma: _, + second, + rangle: _, + } = self; + name.hash(state); + first.hash(state); + second.hash(state); + } +} + impl Eq for Ref {} impl PartialEq for Ref { @@ -396,6 +438,43 @@ impl Hash for Signature { } } +impl Eq for TupleStruct {} + +impl PartialEq for TupleStruct { + fn eq(&self, other: &Self) -> bool { + let TupleStruct { + types, + name: _, + paren_token: _, + generics: _, + } = self; + let TupleStruct { + types: types2, + name: _, + paren_token: _, + generics: _, + } = other; + types.len() == types2.len() + && types.iter().zip(types2).all(|(ty, ty2)| { + ty == ty2 + }) + } +} + +impl Hash for TupleStruct { + fn hash(&self, state: &mut H) { + let TupleStruct { + types, + name: _, + paren_token: _, + generics: _, + } = self; + for ty in types { + ty.hash(state); + } + } +} + impl Eq for Receiver {} impl PartialEq for Receiver { diff --git a/syntax/improper.rs b/syntax/improper.rs index f19eb86a7..a7ffea856 100644 --- a/syntax/improper.rs +++ b/syntax/improper.rs @@ -28,7 +28,7 @@ impl<'a> Types<'a> { | Type::Fn(_) | Type::Void(_) | Type::SliceRef(_) => Definite(true), - Type::UniquePtr(_) | Type::SharedPtr(_) | Type::WeakPtr(_) | Type::CxxVector(_) => { + Type::UniquePtr(_) | Type::SharedPtr(_) | Type::WeakPtr(_) | Type::CxxVector(_) | Type::CxxFunction(_) => { Definite(false) } Type::Ref(ty) => self.determine_improper_ctype(&ty.inner), diff --git a/syntax/instantiate.rs b/syntax/instantiate.rs index b6cbf24b5..a14eb63b5 100644 --- a/syntax/instantiate.rs +++ b/syntax/instantiate.rs @@ -1,4 +1,4 @@ -use crate::syntax::{NamedType, Ty1, Type}; +use crate::syntax::{NamedType, Ty1, Ty2, Type}; use proc_macro2::{Ident, Span}; use std::hash::{Hash, Hasher}; use syn::Token; @@ -11,6 +11,7 @@ pub enum ImplKey<'a> { SharedPtr(NamedImplKey<'a>), WeakPtr(NamedImplKey<'a>), CxxVector(NamedImplKey<'a>), + CxxFunction(DoubleNamedImplKey<'a>), } #[derive(Copy, Clone)] @@ -22,6 +23,19 @@ pub struct NamedImplKey<'a> { pub end_span: Span, } +#[derive(Copy, Clone)] +pub struct DoubleNamedImplKey<'a> { + pub begin_span: Span, + pub id1_ampersand: Option, + pub id1: &'a Ident, + pub id1_lt_token: Option, + pub id1_gt_token: Option]>, + pub id2: Option<&'a Ident>, + pub id2_lt_token: Option, + pub id2_gt_token: Option]>, + pub end_span: Span, +} + impl Type { pub(crate) fn impl_key(&self) -> Option { if let Type::RustBox(ty) = self { @@ -48,6 +62,22 @@ impl Type { if let Type::Ident(ident) = &ty.inner { return Some(ImplKey::CxxVector(NamedImplKey::new(ty, ident))); } + } else if let Type::CxxFunction(ty) = self { + let ret: Option<&NamedType> = if let Type::Ident(ret) = &ty.second { + Some(ret) + } else if let Type::Void(_) = &ty.second { + None + } else { + return None; + }; + + if let Type::Ident(args) = &ty.first { + return Some(ImplKey::CxxFunction(DoubleNamedImplKey::new(ty, None, args, ret))); + } else if let Type::Ref(rf) = &ty.first { + if let Type::Ident(args) = &rf.inner { + return Some(ImplKey::CxxFunction(DoubleNamedImplKey::new(ty, Some(rf.ampersand), args, ret))); + } + } } None } @@ -78,3 +108,37 @@ impl<'a> NamedImplKey<'a> { } } } + +impl<'a> PartialEq for DoubleNamedImplKey<'a> { + fn eq(&self, other: &Self) -> bool { + self.id1_ampersand.is_some() == other.id1_ampersand.is_some() + && PartialEq::eq(&self.id1, &other.id1) + && PartialEq::eq(&self.id2, &other.id2) + } +} + +impl<'a> Eq for DoubleNamedImplKey<'a> {} + +impl<'a> Hash for DoubleNamedImplKey<'a> { + fn hash(&self, hasher: &mut H) { + self.id1_ampersand.is_some().hash(hasher); + self.id1.hash(hasher); + self.id2.hash(hasher); + } +} + +impl<'a> DoubleNamedImplKey<'a> { + fn new(outer: &Ty2, ampersand: Option, first: &'a NamedType, second: Option<&'a NamedType>) -> Self { + DoubleNamedImplKey { + begin_span: outer.name.span(), + id1_ampersand: ampersand, + id1: &first.rust, + id1_lt_token: first.generics.lt_token, + id1_gt_token: first.generics.gt_token, + id2: second.map(|s| &s.rust), + id2_lt_token: second.map(|s| s.generics.lt_token).flatten(), + id2_gt_token: second.map(|s| s.generics.gt_token).flatten(), + end_span: outer.rangle.span, + } + } +} diff --git a/syntax/mangle.rs b/syntax/mangle.rs index 287b44341..c7d1fe6db 100644 --- a/syntax/mangle.rs +++ b/syntax/mangle.rs @@ -73,8 +73,9 @@ // - CXXBRIDGE1_STRUCT_org$rust$Struct // - CXXBRIDGE1_ENUM_Enabled +use proc_macro2::Ident; use crate::syntax::symbol::{self, Symbol}; -use crate::syntax::{ExternFn, Pair, Types}; +use crate::syntax::{Atom, ExternFn, Pair, Types}; const CXXBRIDGE: &str = "cxxbridge1"; @@ -118,3 +119,15 @@ pub fn c_trampoline(efn: &ExternFn, var: &Pair, types: &Types) -> Symbol { pub fn r_trampoline(efn: &ExternFn, var: &Pair, types: &Types) -> Symbol { join!(extern_fn(efn, types), var.rust, 1) } + +pub fn mangle_ident(ident: Option<&Ident>, types: &Types) -> Symbol { + if let Some(some) = ident { + if let Some(_) = Atom::from(&some) { + symbol::join(&[&&some.to_string()[..]]) + } else { + types.resolve(some).name.to_symbol() + } + } else { + symbol::join(&[&"void"]) + } +} diff --git a/syntax/mod.rs b/syntax/mod.rs index 4f19d9641..4596b01d6 100644 --- a/syntax/mod.rs +++ b/syntax/mod.rs @@ -51,6 +51,7 @@ pub use self::types::Types; pub enum Api { Include(Include), Struct(Struct), + TupleStruct(TupleStruct), Enum(Enum), CxxType(ExternType), CxxFunction(ExternFn), @@ -191,6 +192,13 @@ pub struct Signature { pub throws_tokens: Option<(kw::Result, Token![<], Token![>])>, } +pub struct TupleStruct { + pub name: Pair, + pub types: Punctuated, + pub paren_token: Paren, + pub generics: Lifetimes, +} + pub struct Var { pub cfg: CfgExpr, pub doc: Doc, @@ -234,6 +242,7 @@ pub enum Type { Ptr(Box), Str(Box), CxxVector(Box), + CxxFunction(Box), Fn(Box), Void(Span), SliceRef(Box), @@ -247,6 +256,15 @@ pub struct Ty1 { pub rangle: Token![>], } +pub struct Ty2 { + pub name: Ident, + pub langle: Token![<], + pub first: Type, + pub comma: Token![,], + pub second: Type, + pub rangle: Token![>], +} + pub struct Ref { pub pinned: bool, pub ampersand: Token![&], diff --git a/syntax/parse.rs b/syntax/parse.rs index 1754c6006..0e18818d0 100644 --- a/syntax/parse.rs +++ b/syntax/parse.rs @@ -7,7 +7,7 @@ use crate::syntax::Atom::*; use crate::syntax::{ attrs, error, Api, Array, Derive, Doc, Enum, EnumRepr, ExternFn, ExternType, ForeignName, Impl, Include, IncludeKind, Lang, Lifetimes, NamedType, Namespace, Pair, Ptr, Receiver, Ref, - Signature, SliceRef, Struct, Ty1, Type, TypeAlias, Var, Variant, + Signature, SliceRef, Struct, TupleStruct, Ty1, Ty2, Type, TypeAlias, Var, Variant, }; use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; use quote::{format_ident, quote, quote_spanned}; @@ -19,7 +19,7 @@ use syn::{ GenericArgument, GenericParam, Generics, Ident, ItemEnum, ItemImpl, ItemStruct, Lit, LitStr, Pat, PathArguments, Result, ReturnType, Signature as RustSignature, Token, TraitBound, TraitBoundModifier, Type as RustType, TypeArray, TypeBareFn, TypeParamBound, TypePath, TypePtr, - TypeReference, Variant as RustVariant, Visibility, + TypeReference, Variant as RustVariant, Visibility, FieldsUnnamed, }; pub mod kw { @@ -76,14 +76,6 @@ fn parse_struct(cx: &mut Errors, mut item: ItemStruct, namespace: &Namespace) -> }, ); - let named_fields = match item.fields { - Fields::Named(fields) => fields, - Fields::Unit => return Err(Error::new_spanned(item, "unit structs are not supported")), - Fields::Unnamed(_) => { - return Err(Error::new_spanned(item, "tuple structs are not supported")); - } - }; - let mut lifetimes = Punctuated::new(); let mut has_unsupported_generic_param = false; for pair in item.generics.params.into_pairs() { @@ -124,6 +116,24 @@ fn parse_struct(cx: &mut Errors, mut item: ItemStruct, namespace: &Namespace) -> ); } + let struct_token = item.struct_token; + let visibility = visibility_pub(&item.vis, struct_token.span); + let name = pair(namespace, &item.ident, cxx_name, rust_name); + + let generics = Lifetimes { + lt_token: item.generics.lt_token, + lifetimes, + gt_token: item.generics.gt_token, + }; + + let named_fields = match item.fields { + Fields::Named(fields) => fields, + Fields::Unit => return Err(Error::new_spanned(item.ident, "unit structs are not supported")), + Fields::Unnamed(u) => { + return Ok(Api::TupleStruct(parse_tuple_from_unnamed_struct(name, generics, &u)?)); + } + }; + let mut fields = Vec::new(); for field in named_fields.named { let ident = field.ident.unwrap(); @@ -163,14 +173,6 @@ fn parse_struct(cx: &mut Errors, mut item: ItemStruct, namespace: &Namespace) -> }); } - let struct_token = item.struct_token; - let visibility = visibility_pub(&item.vis, struct_token.span); - let name = pair(namespace, &item.ident, cxx_name, rust_name); - let generics = Lifetimes { - lt_token: item.generics.lt_token, - lifetimes, - gt_token: item.generics.gt_token, - }; let brace_token = named_fields.brace_token; Ok(Api::Struct(Struct { @@ -1073,6 +1075,10 @@ fn parse_impl(cx: &mut Errors, imp: ItemImpl) -> Result { Type::Ident(ident) => ident.generics.clone(), _ => Lifetimes::default(), }, + Type::CxxFunction(ty) => match &ty.first { + Type::Ident(ident) => ident.generics.clone(), + _ => Lifetimes::default(), + }, Type::Ident(_) | Type::Ref(_) | Type::Ptr(_) @@ -1264,6 +1270,21 @@ fn parse_type_path(ty: &TypePath) -> Result { rangle: generic.gt_token, }))); } + } else if ident == "CxxFunction" && generic.args.len() == 2 { + if let GenericArgument::Type(arg) = &generic.args[0] { + if let GenericArgument::Type(ret) = &generic.args[1] { + let first = parse_type(arg)?; + let second = parse_type(ret)?; + return Ok(Type::CxxFunction(Box::new(Ty2 { + name: ident, + langle: generic.lt_token, + first, + comma: **generic.args.pairs().next().unwrap().punct().unwrap(), + second, + rangle: generic.gt_token, + }))); + } + } } else if ident == "Box" && generic.args.len() == 1 { if let GenericArgument::Type(arg) = &generic.args[0] { let inner = parse_type(arg)?; @@ -1436,6 +1457,18 @@ fn parse_type_fn(ty: &TypeBareFn) -> Result { }))) } +fn parse_tuple_from_unnamed_struct(name: Pair, generics: Lifetimes, unnamed_fields: &FieldsUnnamed) -> Result { + let types: Punctuated = unnamed_fields.unnamed + .iter() + .map(|field| { + Ok(parse_type(&field.ty)?) + }) + .collect::>()?; + let paren_token = unnamed_fields.paren_token; + + Ok(TupleStruct { name, types, paren_token, generics }) +} + fn parse_return_type( ty: &ReturnType, throws_tokens: &mut Option<(kw::Result, Token![<], Token![>])>, diff --git a/syntax/pod.rs b/syntax/pod.rs index 0bf152eea..6b86bf781 100644 --- a/syntax/pod.rs +++ b/syntax/pod.rs @@ -28,6 +28,7 @@ impl<'a> Types<'a> { | Type::SharedPtr(_) | Type::WeakPtr(_) | Type::CxxVector(_) + | Type::CxxFunction(_) | Type::Void(_) => false, Type::Ref(_) | Type::Str(_) | Type::Fn(_) | Type::SliceRef(_) | Type::Ptr(_) => true, Type::Array(array) => self.is_guaranteed_pod(&array.inner), diff --git a/syntax/resolve.rs b/syntax/resolve.rs index 3a2635bd3..8d1a4426b 100644 --- a/syntax/resolve.rs +++ b/syntax/resolve.rs @@ -1,5 +1,5 @@ -use crate::syntax::instantiate::NamedImplKey; -use crate::syntax::{Lifetimes, NamedType, Pair, Types}; +use crate::syntax::instantiate::{DoubleNamedImplKey, NamedImplKey}; +use crate::syntax::{Lifetime, Lifetimes, NamedType, Pair, Type, Types, TupleStruct}; use proc_macro2::Ident; #[derive(Copy, Clone)] @@ -9,6 +9,27 @@ pub struct Resolution<'a> { } impl<'a> Types<'a> { + pub fn resolve_tuple_struct(&self, ident: &impl UnresolvedName) -> Option<&TupleStruct> { + self.tuple_structs.get(ident.ident()).map(|t| *t) + } + + pub fn resolve_cxx_arg_type(&self, key: &DoubleNamedImplKey) -> (Option<&Type>, Option<&Lifetime>) { + for t in self.all.iter() { + if let Type::Ident(t_ident) = t { + if &t_ident.rust == key.id1 && key.id1_ampersand.is_none() { + return (Some(t), None) + } + } else if let Type::Ref(t_ref) = t { + if let Type::Ident(t_ident) = &t_ref.inner { + if &t_ident.rust == key.id1 && key.id1_ampersand.is_some() { + return (Some(t), t_ref.lifetime.as_ref()) + } + } + } + } + (None, None) + } + pub fn resolve(&self, ident: &impl UnresolvedName) -> Resolution<'a> { let ident = ident.ident(); match self.try_resolve(ident) { diff --git a/syntax/tokens.rs b/syntax/tokens.rs index 62d37c23d..3c607aaae 100644 --- a/syntax/tokens.rs +++ b/syntax/tokens.rs @@ -1,7 +1,7 @@ use crate::syntax::atom::Atom::*; use crate::syntax::{ Array, Atom, Derive, Enum, EnumRepr, ExternFn, ExternType, Impl, Lifetimes, NamedType, Ptr, - Ref, Signature, SliceRef, Struct, Ty1, Type, TypeAlias, Var, + Ref, Signature, SliceRef, Struct, Ty1, Ty2, Type, TypeAlias, Var, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote_spanned, ToTokens}; @@ -26,6 +26,7 @@ impl ToTokens for Type { | Type::WeakPtr(ty) | Type::CxxVector(ty) | Type::RustVec(ty) => ty.to_tokens(tokens), + Type::CxxFunction(ty) => ty.to_tokens(tokens), Type::Ref(r) | Type::Str(r) => r.to_tokens(tokens), Type::Ptr(p) => p.to_tokens(tokens), Type::Array(a) => a.to_tokens(tokens), @@ -63,7 +64,7 @@ impl ToTokens for Ty1 { } = self; let span = name.span(); match name.to_string().as_str() { - "UniquePtr" | "SharedPtr" | "WeakPtr" | "CxxVector" => { + "UniquePtr" | "SharedPtr" | "WeakPtr" | "CxxVector" | "CxxFunction" => { tokens.extend(quote_spanned!(span=> ::cxx::)); } "Vec" => { @@ -78,6 +79,32 @@ impl ToTokens for Ty1 { } } +impl ToTokens for Ty2 { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Ty2 { + name, + langle, + first, + comma, + second, + rangle, + } = self; + let span = name.span(); + match name.to_string().as_str() { + "CxxFunction" => { + tokens.extend(quote_spanned!(span=> ::cxx::)); + } + _ => {} + } + name.to_tokens(tokens); + langle.to_tokens(tokens); + first.to_tokens(tokens); + comma.to_tokens(tokens); + second.to_tokens(tokens); + rangle.to_tokens(tokens); + } +} + impl ToTokens for Ref { fn to_tokens(&self, tokens: &mut TokenStream) { let Ref { diff --git a/syntax/types.rs b/syntax/types.rs index 82b453008..fca345c33 100644 --- a/syntax/types.rs +++ b/syntax/types.rs @@ -7,7 +7,7 @@ use crate::syntax::set::{OrderedSet, UnorderedSet}; use crate::syntax::trivial::{self, TrivialReason}; use crate::syntax::visit::{self, Visit}; use crate::syntax::{ - toposort, Api, Atom, Enum, EnumRepr, ExternType, Impl, Lifetimes, Pair, Struct, Type, TypeAlias, + toposort, Api, Atom, Enum, EnumRepr, ExternType, Impl, Lifetimes, Pair, Struct, TupleStruct, Type, TypeAlias, }; use proc_macro2::Ident; use quote::ToTokens; @@ -15,6 +15,7 @@ use quote::ToTokens; pub struct Types<'a> { pub all: OrderedSet<&'a Type>, pub structs: UnorderedMap<&'a Ident, &'a Struct>, + pub tuple_structs: OrderedMap<&'a Ident, &'a TupleStruct>, pub enums: UnorderedMap<&'a Ident, &'a Enum>, pub cxx: UnorderedSet<&'a Ident>, pub rust: UnorderedSet<&'a Ident>, @@ -31,6 +32,7 @@ impl<'a> Types<'a> { pub fn collect(cx: &mut Errors, apis: &'a [Api]) -> Self { let mut all = OrderedSet::new(); let mut structs = UnorderedMap::new(); + let mut tuple_structs = OrderedMap::new(); let mut enums = UnorderedMap::new(); let mut cxx = UnorderedSet::new(); let mut rust = UnorderedSet::new(); @@ -87,6 +89,13 @@ impl<'a> Types<'a> { } add_resolution(&strct.name, &strct.generics); } + Api::TupleStruct(tstrct) => { + tuple_structs.insert(&tstrct.name.rust, tstrct); + for ty in &tstrct.types { + visit(&mut all, &ty); + } + add_resolution(&tstrct.name, &tstrct.generics); + } Api::Enum(enm) => { match &enm.repr { EnumRepr::Native { atom: _, repr_type } => { @@ -183,7 +192,8 @@ impl<'a> Types<'a> { | ImplKey::WeakPtr(ident) | ImplKey::CxxVector(ident) => { Atom::from(ident.rust).is_none() && !aliases.contains_key(ident.rust) - } + }, + ImplKey::CxxFunction(_) => true, }; if implicit_impl && !impls.contains_key(&impl_key) { impls.insert(impl_key, None); @@ -200,6 +210,7 @@ impl<'a> Types<'a> { let mut types = Types { all, structs, + tuple_structs, enums, cxx, rust, diff --git a/syntax/visit.rs b/syntax/visit.rs index 2f31378f2..5f4e87451 100644 --- a/syntax/visit.rs +++ b/syntax/visit.rs @@ -18,6 +18,10 @@ where | Type::WeakPtr(ty) | Type::CxxVector(ty) | Type::RustVec(ty) => visitor.visit_type(&ty.inner), + | Type::CxxFunction(ty) => { + visitor.visit_type(&ty.first); + visitor.visit_type(&ty.second); + } Type::Ref(r) => visitor.visit_type(&r.inner), Type::Ptr(p) => visitor.visit_type(&p.inner), Type::Array(a) => visitor.visit_type(&a.inner), diff --git a/tests/ffi/lib.rs b/tests/ffi/lib.rs index c3174bbca..fa66698a8 100644 --- a/tests/ffi/lib.rs +++ b/tests/ffi/lib.rs @@ -16,7 +16,7 @@ pub mod cast; pub mod module; -use cxx::{type_id, CxxString, CxxVector, ExternType, SharedPtr, UniquePtr}; +use cxx::{type_id, CxxFunction, CxxString, CxxVector, ExternType, SharedPtr, UniquePtr}; use std::fmt::{self, Display}; use std::mem::MaybeUninit; use std::os::raw::c_char; @@ -254,6 +254,19 @@ pub mod ffi { type Buffer = crate::Buffer; } + pub struct TestArgs<'a>( + u8, + &'a R, + Vec, + &'a Vec, + String, + &'a String, + SharedString, + &'a SharedString, + Box, + ); + pub struct NoArgs(); + extern "Rust" { type R; @@ -290,6 +303,10 @@ pub mod ffi { fn r_take_rust_string(s: String); fn r_take_unique_ptr_string(s: UniquePtr); fn r_take_ref_vector(v: &CxxVector); + fn r_take_ref_func_tuple_args(v: &CxxFunction); + unsafe fn r_take_ref_func_single_arg_opaque<'a>(f: &CxxFunction<&'a R, ()>); + fn r_take_unique_ptr_func(v: UniquePtr>); + fn r_take_ref_func_no_args(v: &'static CxxFunction); fn r_take_ref_empty_vector(v: &CxxVector); fn r_take_rust_vec(v: Vec); fn r_take_rust_vec_string(v: Vec); @@ -593,6 +610,35 @@ fn r_take_ref_vector(v: &CxxVector) { assert_eq!(slice, [20, 2, 0]); } +fn r_take_ref_func_tuple_args(f: &CxxFunction) { + let retval = f.call( + ffi::TestArgs( + 2, + &R(911), + vec![42], + &vec![64], + "malin".to_owned(), + &"iladalen".to_string(), + ffi::SharedString { msg: "sagene".to_owned() }, + &ffi::SharedString { msg: "torshov".to_owned() }, + Box::new(R(777)), + ) + ); + assert_eq!(200, retval); +} + +unsafe fn r_take_ref_func_single_arg_opaque<'a>(f: &CxxFunction<&'a R, ()>) { + f.call(&R(128)); +} + +fn r_take_unique_ptr_func(f: UniquePtr>) { + f.call(ffi::NoArgs()); +} + +fn r_take_ref_func_no_args(f: &'static CxxFunction) { + f.call(ffi::NoArgs()); +} + fn r_take_ref_empty_vector(v: &CxxVector) { assert!(v.as_slice().is_empty()); assert!(v.is_empty()); diff --git a/tests/ffi/tests.cc b/tests/ffi/tests.cc index 4a94f4a7f..392e1f13b 100644 --- a/tests/ffi/tests.cc +++ b/tests/ffi/tests.cc @@ -794,6 +794,63 @@ extern "C" const char *cxx_run_test() noexcept { r_take_unique_ptr_string( std::unique_ptr(new std::string("2020"))); r_take_ref_vector(std::vector{20, 2, 0}); + + uint8_t capture_val1; + size_t capture_val2; + uint8_t capture_vec_elem; + uint8_t capture_vec_ref_elem; + std::string capture_string; + std::string capture_string_ref; + SharedString capture_shared; + SharedString capture_shared_ref; + size_t capture_box_value; + auto fn = [&]( + uint8_t arg, + const tests::R& arg2, + rust::Vec vec, + const rust::Vec& vec_ref, + rust::String str, + const rust::String& str_ref, + SharedString shared_str, + const SharedString& shared_str_ref, + rust::Box box + ) { + capture_val1 = arg; + capture_val2 = arg2.get(); + capture_vec_elem = vec.at(0); + capture_vec_ref_elem = vec_ref.at(0); + capture_string = static_cast(str); + capture_string_ref = static_cast(str_ref); + capture_shared = shared_str; + capture_shared_ref = shared_str_ref; + capture_box_value = box->get(); + return (uint8_t)200; + }; + r_take_ref_func_tuple_args(fn); + ASSERT(capture_val1 == 2); + ASSERT(capture_val2 == 911); + ASSERT(capture_vec_elem == 42); + ASSERT(capture_vec_ref_elem == 64); + ASSERT(capture_string == "malin"); + ASSERT(capture_string_ref == "iladalen"); + ASSERT(capture_shared.msg == "sagene"); + ASSERT(capture_shared_ref.msg == "torshov"); + ASSERT(capture_box_value == 777); + + uint8_t capture_from_noarg_lambda; + auto noarg_fn = [&]() { + capture_from_noarg_lambda = 66; + }; + r_take_ref_func_no_args(noarg_fn); + ASSERT(capture_from_noarg_lambda == 66); + + + uint8_t capture_from_struct; + r_take_ref_func_single_arg_opaque([&](const tests::R& arg) { + capture_from_struct = arg.get(); + }); + ASSERT(capture_from_struct == 128); + std::vector empty_vector; r_take_ref_empty_vector(empty_vector); empty_vector.reserve(10);