Skip to content

[JExtract] Static 'call' method in binding descriptor classes #246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 59 additions & 42 deletions Sources/JExtractSwift/Swift2JavaTranslator+JavaBindingsPrinting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,27 @@ extension Swift2JavaTranslator {
printJavaBindingDescriptorClass(&printer, decl)

// Render the "make the downcall" functions.
printFuncDowncallMethod(&printer, decl)
printJavaBindingWrapperMethod(&printer, decl)
}

/// Print FFM Java binding descriptors for the imported Swift API.
func printJavaBindingDescriptorClass(
package func printJavaBindingDescriptorClass(
_ printer: inout CodePrinter,
_ decl: ImportedFunc
) {
let thunkName = thunkNameRegistry.functionThunkName(decl: decl)
let cFunc = decl.cFunctionDecl(cName: thunkName)

printer.printBraceBlock("private static class \(cFunc.name)") { printer in
printer.printBraceBlock(
"""
/**
* {@snippet lang=c :
* \(cFunc.description)
* }
*/
private static class \(cFunc.name)
"""
) { printer in
printFunctionDescriptorValue(&printer, cFunc)
printer.print(
"""
Expand All @@ -44,11 +53,12 @@ extension Swift2JavaTranslator {
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
"""
)
printJavaBindingDowncallMethod(&printer, cFunc)
}
}

/// Print the 'FunctionDescriptor' of the lowered cdecl thunk.
public func printFunctionDescriptorValue(
func printFunctionDescriptorValue(
_ printer: inout CodePrinter,
_ cFunc: CFunction
) {
Expand All @@ -74,9 +84,42 @@ extension Swift2JavaTranslator {
printer.print(");")
}

func printJavaBindingDowncallMethod(
_ printer: inout CodePrinter,
_ cFunc: CFunction
) {
let returnTy = cFunc.resultType.javaType
let maybeReturn = cFunc.resultType.isVoid ? "" : "return (\(returnTy)) "

var params: [String] = []
var args: [String] = []
for param in cFunc.parameters {
// ! unwrapping because cdecl lowering guarantees the parameter named.
params.append("\(param.type.javaType) \(param.name!)")
args.append(param.name!)
}
let paramsStr = params.joined(separator: ", ")
let argsStr = args.joined(separator: ", ")

printer.print(
"""
public static \(returnTy) call(\(paramsStr)) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(\(argsStr));
}
\(maybeReturn)HANDLE.invokeExact(\(argsStr));
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
"""
)
}

/// Print the calling body that forwards all the parameters to the `methodName`,
/// with adding `SwiftArena.ofAuto()` at the end.
public func printFuncDowncallMethod(
public func printJavaBindingWrapperMethod(
_ printer: inout CodePrinter,
_ decl: ImportedFunc) {
let methodName: String = switch decl.kind {
Expand Down Expand Up @@ -130,19 +173,11 @@ extension Swift2JavaTranslator {
_ printer: inout CodePrinter,
_ decl: ImportedFunc
) {
//=== Part 1: MethodHandle
let descriptorClassIdentifier = thunkNameRegistry.functionThunkName(decl: decl)
printer.print(
"var mh$ = \(descriptorClassIdentifier).HANDLE;"
)

let tryHead = if decl.translatedSignature.requiresTemporaryArena {
"try(var arena$ = Arena.ofConfined()) {"
} else {
"try {"
//=== Part 1: prepare temporary arena if needed.
if decl.translatedSignature.requiresTemporaryArena {
printer.print("try(var arena$ = Arena.ofConfined()) {")
printer.indent();
}
printer.print(tryHead);
printer.indent();

//=== Part 2: prepare all arguments.
var downCallArguments: [String] = []
Expand All @@ -151,15 +186,7 @@ extension Swift2JavaTranslator {
for (i, parameter) in decl.translatedSignature.parameters.enumerated() {
let original = decl.swiftSignature.parameters[i]
let parameterName = original.parameterName ?? "_\(i)"
let converted = parameter.conversion.render(&printer, parameterName)
let lowered: String
if parameter.conversion.isTrivial {
lowered = converted
} else {
// Store the conversion to a temporary variable.
lowered = "\(parameterName)$"
printer.print("var \(lowered) = \(converted);")
}
let lowered = parameter.conversion.render(&printer, parameterName)
downCallArguments.append(lowered)
}

Expand Down Expand Up @@ -191,14 +218,8 @@ extension Swift2JavaTranslator {
}

//=== Part 3: Downcall.
printer.print(
"""
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(\(downCallArguments.joined(separator: ", ")));
}
"""
)
let downCall = "mh$.invokeExact(\(downCallArguments.joined(separator: ", ")))"
let thunkName = thunkNameRegistry.functionThunkName(decl: decl)
let downCall = "\(thunkName).call(\(downCallArguments.joined(separator: ", ")))"

//=== Part 4: Convert the return value.
if decl.translatedSignature.result.javaResultType == .void {
Expand All @@ -221,14 +242,10 @@ extension Swift2JavaTranslator {
}
}

printer.outdent()
printer.print(
"""
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
"""
)
if decl.translatedSignature.requiresTemporaryArena {
printer.outdent()
printer.print("}")
}
}

func renderMemoryLayoutValue(for javaType: JavaType) -> String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ struct JavaTranslation {
return TranslatedResult(
javaResultType: javaType,
outParameters: [],
conversion: .cast(javaType)
conversion: .pass
)
}

Expand Down
11 changes: 2 additions & 9 deletions Tests/JExtractSwiftTests/FuncCallbackImportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ final class FuncCallbackImportTests {
let funcDecl = st.importedGlobalFuncs.first { $0.name == "callMe" }!

let output = CodePrinter.toString { printer in
st.printFuncDowncallMethod(&printer, funcDecl)
st.printJavaBindingWrapperMethod(&printer, funcDecl)
}

assertOutput(
Expand All @@ -59,15 +59,8 @@ final class FuncCallbackImportTests {
* }
*/
public static void callMe(java.lang.Runnable callback) {
var mh$ = swiftjava___FakeModule_callMe_callback.HANDLE;
try(var arena$ = Arena.ofConfined()) {
var callback$ = SwiftKit.toUpcallStub(callback, arena$);
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(callback$);
}
mh$.invokeExact(callback$);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
swiftjava___FakeModule_callMe_callback.call(SwiftKit.toUpcallStub(callback, arena$))
}
}
"""
Expand Down
146 changes: 121 additions & 25 deletions Tests/JExtractSwiftTests/FunctionDescriptorImportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,29 @@ final class FunctionDescriptorTests {
output,
expected:
"""
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* i: */SwiftValueLayout.SWIFT_INT
);
/**
* {@snippet lang=c :
* void swiftjava_SwiftModule_globalTakeInt_i(ptrdiff_t i)
* }
*/
private static class swiftjava_SwiftModule_globalTakeInt_i {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* i: */SwiftValueLayout.SWIFT_INT
);
public static final MemorySegment ADDR =
SwiftModule.findOrThrow("swiftjava_SwiftModule_globalTakeInt_i");
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
public static void call(long i) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(i);
}
HANDLE.invokeExact(i);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}
"""
)
}
Expand All @@ -66,10 +86,30 @@ final class FunctionDescriptorTests {
output,
expected:
"""
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* l: */SwiftValueLayout.SWIFT_INT64,
/* i32: */SwiftValueLayout.SWIFT_INT32
);
/**
* {@snippet lang=c :
* void swiftjava_SwiftModule_globalTakeLongInt_l_i32(int64_t l, int32_t i32)
* }
*/
private static class swiftjava_SwiftModule_globalTakeLongInt_l_i32 {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* l: */SwiftValueLayout.SWIFT_INT64,
/* i32: */SwiftValueLayout.SWIFT_INT32
);
public static final MemorySegment ADDR =
SwiftModule.findOrThrow("swiftjava_SwiftModule_globalTakeLongInt_l_i32");
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
public static void call(long l, int i32) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(l, i32);
}
HANDLE.invokeExact(l, i32);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}
"""
)
}
Expand All @@ -82,10 +122,30 @@ final class FunctionDescriptorTests {
output,
expected:
"""
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
/* -> */SwiftValueLayout.SWIFT_INT,
/* i: */SwiftValueLayout.SWIFT_INT
);
/**
* {@snippet lang=c :
* ptrdiff_t swiftjava_SwiftModule_echoInt_i(ptrdiff_t i)
* }
*/
private static class swiftjava_SwiftModule_echoInt_i {
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
/* -> */SwiftValueLayout.SWIFT_INT,
/* i: */SwiftValueLayout.SWIFT_INT
);
public static final MemorySegment ADDR =
SwiftModule.findOrThrow("swiftjava_SwiftModule_echoInt_i");
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
public static long call(long i) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(i);
}
return (long) HANDLE.invokeExact(i);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}
"""
)
}
Expand All @@ -98,10 +158,30 @@ final class FunctionDescriptorTests {
output,
expected:
"""
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
/* -> */SwiftValueLayout.SWIFT_INT32,
/* self: */SwiftValueLayout.SWIFT_POINTER
);
/**
* {@snippet lang=c :
* int32_t swiftjava_SwiftModule_MySwiftClass_counter$get(const void *self)
* }
*/
private static class swiftjava_SwiftModule_MySwiftClass_counter$get {
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
/* -> */SwiftValueLayout.SWIFT_INT32,
/* self: */SwiftValueLayout.SWIFT_POINTER
);
public static final MemorySegment ADDR =
SwiftModule.findOrThrow("swiftjava_SwiftModule_MySwiftClass_counter$get");
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
public static int call(java.lang.foreign.MemorySegment self) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(self);
}
return (int) HANDLE.invokeExact(self);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}
"""
)
}
Expand All @@ -113,10 +193,30 @@ final class FunctionDescriptorTests {
output,
expected:
"""
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* newValue: */SwiftValueLayout.SWIFT_INT32,
/* self: */SwiftValueLayout.SWIFT_POINTER
);
/**
* {@snippet lang=c :
* void swiftjava_SwiftModule_MySwiftClass_counter$set(int32_t newValue, const void *self)
* }
*/
private static class swiftjava_SwiftModule_MySwiftClass_counter$set {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
/* newValue: */SwiftValueLayout.SWIFT_INT32,
/* self: */SwiftValueLayout.SWIFT_POINTER
);
public static final MemorySegment ADDR =
SwiftModule.findOrThrow("swiftjava_SwiftModule_MySwiftClass_counter$set");
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
public static void call(int newValue, java.lang.foreign.MemorySegment self) {
try {
if (SwiftKit.TRACE_DOWNCALLS) {
SwiftKit.traceDowncall(newValue, self);
}
HANDLE.invokeExact(newValue, self);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}
"""
)
}
Expand Down Expand Up @@ -145,10 +245,8 @@ extension FunctionDescriptorTests {
$0.name == methodIdentifier
}!

let thunkName = st.thunkNameRegistry.functionThunkName(decl: funcDecl)
let cFunc = funcDecl.cFunctionDecl(cName: thunkName)
let output = CodePrinter.toString { printer in
st.printFunctionDescriptorValue(&printer, cFunc)
st.printJavaBindingDescriptorClass(&printer, funcDecl)
}

try body(output)
Expand Down Expand Up @@ -180,10 +278,8 @@ extension FunctionDescriptorTests {
fatalError("Cannot find descriptor of: \(identifier)")
}

let thunkName = st.thunkNameRegistry.functionThunkName(decl: accessorDecl)
let cFunc = accessorDecl.cFunctionDecl(cName: thunkName)
let getOutput = CodePrinter.toString { printer in
st.printFunctionDescriptorValue(&printer, cFunc)
st.printJavaBindingDescriptorClass(&printer, accessorDecl)
}

try body(getOutput)
Expand Down
Loading