From 5a209c93f036ceb2c2cdc9f5bb58c8053d9da309 Mon Sep 17 00:00:00 2001 From: Rintaro Ishizaki Date: Mon, 23 Jun 2025 16:42:45 -0700 Subject: [PATCH] [JExtract] Bridge closures with UnsafeRawBufferPointer parameter First step to bridging closures with conversions. --- .../MySwiftLibrary/MySwiftLibrary.swift | 4 + .../com/example/swift/HelloJava2Swift.java | 3 + ...Swift2JavaGenerator+FunctionLowering.swift | 82 +++++++++++++--- .../JExtractSwiftLib/FFM/ConversionStep.swift | 53 +++++++++- ...t2JavaGenerator+JavaBindingsPrinting.swift | 50 +++++++--- ...MSwift2JavaGenerator+JavaTranslation.swift | 88 +++++++++++++---- .../FuncCallbackImportTests.swift | 97 +++++++++++++++++++ .../FunctionLoweringTests.swift | 18 ++++ 8 files changed, 351 insertions(+), 44 deletions(-) diff --git a/Samples/SwiftKitSampleApp/Sources/MySwiftLibrary/MySwiftLibrary.swift b/Samples/SwiftKitSampleApp/Sources/MySwiftLibrary/MySwiftLibrary.swift index 8b6066cf..18b6546d 100644 --- a/Samples/SwiftKitSampleApp/Sources/MySwiftLibrary/MySwiftLibrary.swift +++ b/Samples/SwiftKitSampleApp/Sources/MySwiftLibrary/MySwiftLibrary.swift @@ -53,6 +53,10 @@ public func globalReceiveRawBuffer(buf: UnsafeRawBufferPointer) -> Int { public var globalBuffer: UnsafeRawBufferPointer = UnsafeRawBufferPointer(UnsafeMutableRawBufferPointer.allocate(byteCount: 124, alignment: 1)) +public func withBuffer(body: (UnsafeRawBufferPointer) -> Void) { + body(globalBuffer) +} + // ==== Internal helpers func p(_ msg: String, file: String = #fileID, line: UInt = #line, function: String = #function) { diff --git a/Samples/SwiftKitSampleApp/src/main/java/com/example/swift/HelloJava2Swift.java b/Samples/SwiftKitSampleApp/src/main/java/com/example/swift/HelloJava2Swift.java index f3073201..c12d82dd 100644 --- a/Samples/SwiftKitSampleApp/src/main/java/com/example/swift/HelloJava2Swift.java +++ b/Samples/SwiftKitSampleApp/src/main/java/com/example/swift/HelloJava2Swift.java @@ -47,6 +47,9 @@ static void examples() { SwiftKit.trace("getGlobalBuffer().byteSize()=" + MySwiftLibrary.getGlobalBuffer().byteSize()); + MySwiftLibrary.withBuffer((buf) -> { + SwiftKit.trace("withBuffer{$0.byteSize()}=" + buf.byteSize()); + }); // Example of using an arena; MyClass.deinit is run at end of scope try (var arena = SwiftArena.ofConfined()) { MySwiftClass obj = MySwiftClass.init(2222, 7777, arena); diff --git a/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift b/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift index 2c7520dd..db764e2e 100644 --- a/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift +++ b/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift @@ -332,14 +332,15 @@ struct CdeclLowering { var parameters: [SwiftParameter] = [] var parameterConversions: [ConversionStep] = [] - for parameter in fn.parameters { - if let _ = try? CType(cdeclType: parameter.type) { - parameters.append(SwiftParameter(convention: .byValue, type: parameter.type)) - parameterConversions.append(.placeholder) - } else { - // Non-trivial types are not yet supported. - throw LoweringError.unhandledType(.function(fn)) - } + for (i, parameter) in fn.parameters.enumerated() { + let parameterName = parameter.parameterName ?? "_\(i)" + let loweredParam = try lowerClosureParameter( + parameter.type, + convention: parameter.convention, + parameterName: parameterName + ) + parameters.append(contentsOf: loweredParam.cdeclParameters) + parameterConversions.append(loweredParam.conversion) } let resultType: SwiftType @@ -352,15 +353,74 @@ struct CdeclLowering { throw LoweringError.unhandledType(.function(fn)) } - // Ignore the conversions for now, since we don't support non-trivial types yet. - _ = (parameterConversions, resultConversion) + let isCompatibleWithC = parameterConversions.allSatisfy(\.isPlaceholder) && resultConversion.isPlaceholder return ( type: .function(SwiftFunctionType(convention: .c, parameters: parameters, resultType: resultType)), - conversion: .placeholder + conversion: isCompatibleWithC ? .placeholder : .closureLowering(parameters: parameterConversions, result: resultConversion) ) } + func lowerClosureParameter( + _ type: SwiftType, + convention: SwiftParameterConvention, + parameterName: String + ) throws -> LoweredParameter { + // If there is a 1:1 mapping between this Swift type and a C type, we just + // return it. + if let _ = try? CType(cdeclType: type) { + return LoweredParameter( + cdeclParameters: [ + SwiftParameter( + convention: .byValue, + parameterName: parameterName, + type: type + ), + ], + conversion: .placeholder + ) + } + + switch type { + case .nominal(let nominal): + if let knownType = nominal.nominalTypeDecl.knownStandardLibraryType { + switch knownType { + case .unsafeRawBufferPointer, .unsafeMutableRawBufferPointer: + // pointer buffers are lowered to (raw-pointer, count) pair. + let isMutable = knownType == .unsafeMutableRawBufferPointer + return LoweredParameter( + cdeclParameters: [ + SwiftParameter( + convention: .byValue, + parameterName: "\(parameterName)_pointer", + type: .optional(isMutable ? knownTypes.unsafeMutableRawPointer : knownTypes.unsafeRawPointer) + ), + SwiftParameter( + convention: .byValue, + parameterName: "\(parameterName)_count", + type: knownTypes.int + ), + ], + conversion: .tuplify([ + .member(.placeholder, member: "baseAddress"), + .member(.placeholder, member: "count") + ]) + ) + + default: + throw LoweringError.unhandledType(type) + } + } + + // Custom types are not supported yet. + throw LoweringError.unhandledType(type) + + case .function, .metatype, .optional, .tuple: + // TODO: Implement + throw LoweringError.unhandledType(type) + } + } + /// Lower a Swift result type to cdecl out parameters and return type. /// /// - Parameters: diff --git a/Sources/JExtractSwiftLib/FFM/ConversionStep.swift b/Sources/JExtractSwiftLib/FFM/ConversionStep.swift index 315acc60..c7fa53e3 100644 --- a/Sources/JExtractSwiftLib/FFM/ConversionStep.swift +++ b/Sources/JExtractSwiftLib/FFM/ConversionStep.swift @@ -58,6 +58,8 @@ enum ConversionStep: Equatable { /// Perform multiple conversions using the same input. case aggregate([ConversionStep], name: String?) + indirect case closureLowering(parameters: [ConversionStep], result: ConversionStep) + indirect case member(ConversionStep, member: String) /// Count the number of times that the placeholder occurs within this @@ -73,13 +75,20 @@ enum ConversionStep: Equatable { inner.placeholderCount case .initialize(_, arguments: let arguments): arguments.reduce(0) { $0 + $1.argument.placeholderCount } - case .placeholder, .tupleExplode: + case .placeholder, .tupleExplode, .closureLowering: 1 case .tuplify(let elements), .aggregate(let elements, _): elements.reduce(0) { $0 + $1.placeholderCount } } } + var isPlaceholder: Bool { + if case .placeholder = self { + return true + } + return false + } + /// Convert the conversion step into an expression with the given /// value as the placeholder value in the expression. func asExprSyntax(placeholder: String, bodyItems: inout [CodeBlockItemSyntax]) -> ExprSyntax? { @@ -165,6 +174,48 @@ enum ConversionStep: Equatable { } } return nil + + case .closureLowering(let parameterSteps, let resultStep): + var body: [CodeBlockItemSyntax] = [] + + // Lower parameters. + var params: [String] = [] + var args: [ExprSyntax] = [] + for (i, parameterStep) in parameterSteps.enumerated() { + let paramName = "_\(i)" + params.append(paramName) + if case .tuplify(let elemSteps) = parameterStep { + for elemStep in elemSteps { + if let elemExpr = elemStep.asExprSyntax(placeholder: paramName, bodyItems: &body) { + args.append(elemExpr) + } + } + } else if let paramExpr = parameterStep.asExprSyntax(placeholder: paramName, bodyItems: &body) { + args.append(paramExpr) + } + } + + // Call the lowered closure with lowered parameters. + let loweredResult = "\(placeholder)(\(args.map(\.description).joined(separator: ", ")))" + + // Raise the lowered result. + let result = resultStep.asExprSyntax(placeholder: loweredResult.description, bodyItems: &body) + body.append("return \(result)") + + // Construct the closure expression. + var closure = ExprSyntax( + """ + { (\(raw: params.joined(separator: ", "))) in + } + """ + ).cast(ClosureExprSyntax.self) + + closure.statements = CodeBlockItemListSyntax { + body.map { + $0.with(\.leadingTrivia, [.newlines(1), .spaces(4)]) + } + } + return ExprSyntax(closure) } } } diff --git a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaBindingsPrinting.swift b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaBindingsPrinting.swift index ba88869d..a6cc6b26 100644 --- a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaBindingsPrinting.swift +++ b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaBindingsPrinting.swift @@ -251,7 +251,6 @@ extension FFMSwift2JavaGenerator { ) } else { // Otherwise, the lambda must be wrapped with the lowered function instance. - assertionFailure("should be unreachable at this point") let apiParams = functionType.parameters.flatMap { $0.javaParameters.map { param in "\(param.type) \(param.name)" } } @@ -262,13 +261,38 @@ extension FFMSwift2JavaGenerator { public interface \(functionType.name) { \(functionType.result.javaResultType) apply(\(apiParams.joined(separator: ", "))); } - private static MemorySegment $toUpcallStub(\(functionType.name) fi, Arena arena) { - return \(cdeclDescriptor).toUpcallStub(() -> { - fi() - }, arena); - } """ ) + + let cdeclParams = functionType.cdeclType.parameters.map( { "\($0.parameterName!)" }) + + printer.printBraceBlock( + """ + private static MemorySegment $toUpcallStub(\(functionType.name) fi, Arena arena) + """ + ) { printer in + printer.print( + """ + return \(cdeclDescriptor).toUpcallStub((\(cdeclParams.joined(separator: ", "))) -> { + """ + ) + printer.indent() + var convertedArgs: [String] = [] + for param in functionType.parameters { + let arg = param.conversion.render(&printer, param.javaParameters[0].name) + convertedArgs.append(arg) + } + + let call = "fi.apply(\(convertedArgs.joined(separator: ", ")))" + let result = functionType.result.conversion.render(&printer, call) + if functionType.result.javaResultType == .void { + printer.print("\(result);") + } else { + printer.print("return \(result);") + } + printer.outdent() + printer.print("}, arena);") + } } } @@ -419,7 +443,7 @@ extension JavaConversionStep { /// Whether the conversion uses SwiftArena. var requiresSwiftArena: Bool { switch self { - case .placeholder, .constant, .readOutParameter: + case .placeholder, .explodedName, .constant, .readMemorySegment: return false case .constructSwiftValue: return true @@ -436,9 +460,9 @@ extension JavaConversionStep { /// Whether the conversion uses temporary Arena. var requiresTemporaryArena: Bool { switch self { - case .placeholder, .constant: + case .placeholder, .explodedName, .constant: return false - case .readOutParameter: + case .readMemorySegment: return true case .cast(let inner, _), .construct(let inner, _), .constructSwiftValue(let inner, _), .swiftValueSelfSegment(let inner): return inner.requiresSwiftArena @@ -459,6 +483,9 @@ extension JavaConversionStep { case .placeholder: return placeholder + case .explodedName(let component): + return "\(placeholder)_\(component)" + case .swiftValueSelfSegment: return "\(placeholder).$memorySegment()" @@ -491,8 +518,9 @@ extension JavaConversionStep { case .constant(let value): return value - case .readOutParameter(let javaType, let name): - return "\(placeholder)_\(name).get(\(ForeignValueLayout(javaType: javaType)!), 0)" + case .readMemorySegment(let inner, let javaType): + let inner = inner.render(&printer, placeholder) + return "\(inner).get(\(ForeignValueLayout(javaType: javaType)!), 0)" } } } diff --git a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift index 6f17158b..af1ddd09 100644 --- a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift +++ b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift @@ -113,6 +113,8 @@ struct TranslatedFunctionType { var name: String var parameters: [TranslatedParameter] var result: TranslatedResult + var swiftType: SwiftFunctionType + var cdeclType: SwiftFunctionType /// Whether or not this functional interface with C ABI compatible. var isCompatibleWithC: Bool { @@ -159,13 +161,19 @@ struct JavaTranslation { case .function, .initializer: decl.name } + // Signature. + let translatedSignature = try translate(loweredFunctionSignature: loweredSignature, methodName: javaName) + // Closures. var funcTypes: [TranslatedFunctionType] = [] for (idx, param) in decl.functionSignature.parameters.enumerated() { switch param.type { case .function(let funcTy): let paramName = param.parameterName ?? "_\(idx)" - let translatedClosure = try translateFunctionType(name: paramName, swiftType: funcTy) + guard case .function( let cdeclTy) = loweredSignature.parameters[idx].cdeclParameters[0].type else { + preconditionFailure("closure parameter wasn't lowered to a function type; \(funcTy)") + } + let translatedClosure = try translateFunctionType(name: paramName, swiftType: funcTy, cdeclType: cdeclTy) funcTypes.append(translatedClosure) case .tuple: // TODO: Implement @@ -175,9 +183,6 @@ struct JavaTranslation { } } - // Signature. - let translatedSignature = try translate(loweredFunctionSignature: loweredSignature, methodName: javaName) - return TranslatedFunctionDecl( name: javaName, functionTypes: funcTypes, @@ -189,23 +194,16 @@ struct JavaTranslation { /// Translate Swift closure type to Java functional interface. func translateFunctionType( name: String, - swiftType: SwiftFunctionType + swiftType: SwiftFunctionType, + cdeclType: SwiftFunctionType ) throws -> TranslatedFunctionType { var translatedParams: [TranslatedParameter] = [] for (i, param) in swiftType.parameters.enumerated() { let paramName = param.parameterName ?? "_\(i)" - if let cType = try? CType(cdeclType: param.type) { - let translatedParam = TranslatedParameter( - javaParameters: [ - JavaParameter(type: cType.javaType, name: paramName) - ], - conversion: .placeholder - ) - translatedParams.append(translatedParam) - continue - } - throw JavaTranslationError.unhandledType(.function(swiftType)) + translatedParams.append( + try translateClosureParameter(param.type, convention: param.convention, parameterName: paramName) + ) } guard let resultCType = try? CType(cdeclType: swiftType.resultType) else { @@ -221,10 +219,55 @@ struct JavaTranslation { return TranslatedFunctionType( name: name, parameters: translatedParams, - result: transltedResult + result: transltedResult, + swiftType: swiftType, + cdeclType: cdeclType ) } + func translateClosureParameter( + _ type: SwiftType, + convention: SwiftParameterConvention, + parameterName: String + ) throws -> TranslatedParameter { + if let cType = try? CType(cdeclType: type) { + return TranslatedParameter( + javaParameters: [ + JavaParameter(type: cType.javaType, name: parameterName) + ], + conversion: .placeholder + ) + } + + switch type { + case .nominal(let nominal): + if let knownType = nominal.nominalTypeDecl.knownStandardLibraryType { + switch knownType { + case .unsafeRawBufferPointer, .unsafeMutableRawBufferPointer: + return TranslatedParameter( + javaParameters: [ + JavaParameter(type: .javaForeignMemorySegment, name: parameterName) + ], + conversion: .method( + .explodedName(component: "pointer"), + methodName: "reinterpret", + arguments: [ + .explodedName(component: "count") + ], + withArena: false + ) + ) + default: + break + } + } + default: + break + } + throw JavaTranslationError.unhandledType(type) + } + + /// Translate a Swift API signature to the user-facing Java API signature. /// /// Note that the result signature is for the high-level Java API, not the @@ -424,10 +467,10 @@ struct JavaTranslation { JavaParameter(type: .long, name: "count"), ], conversion: .method( - .readOutParameter(.javaForeignMemorySegment, component: "pointer"), + .readMemorySegment(.explodedName(component: "pointer"), as: .javaForeignMemorySegment), methodName: "reinterpret", arguments: [ - .readOutParameter(.long, component: "count") + .readMemorySegment(.explodedName(component: "count"), as: .long), ], withArena: false ) @@ -483,9 +526,12 @@ struct JavaTranslation { /// Describes how to convert values between Java types and FFM types. enum JavaConversionStep { - // Pass through. + // The input case placeholder + // The input exploded into components. + case explodedName(component: String) + // A fixed value case constant(String) @@ -513,7 +559,7 @@ enum JavaConversionStep { indirect case commaSeparated([JavaConversionStep]) // Refer an exploded argument suffixed with `_\(name)`. - indirect case readOutParameter(JavaType, component: String) + indirect case readMemorySegment(JavaConversionStep, as: JavaType) var isPlaceholder: Bool { return if case .placeholder = self { true } else { false } diff --git a/Tests/JExtractSwiftTests/FuncCallbackImportTests.swift b/Tests/JExtractSwiftTests/FuncCallbackImportTests.swift index 457e3a7a..c738914a 100644 --- a/Tests/JExtractSwiftTests/FuncCallbackImportTests.swift +++ b/Tests/JExtractSwiftTests/FuncCallbackImportTests.swift @@ -31,6 +31,7 @@ final class FuncCallbackImportTests { public func callMe(callback: () -> Void) public func callMeMore(callback: (UnsafeRawPointer, Float) -> Int, fn: () -> ()) + public func withBuffer(body: (UnsafeRawBufferPointer) -> Int) """ @Test("Import: public func callMe(callback: () -> Void)") @@ -237,4 +238,100 @@ final class FuncCallbackImportTests { ) } + @Test("Import: public func withBuffer(body: (UnsafeRawBufferPointer) -> Int)") + func func_withBuffer_body() throws { + let st = Swift2JavaTranslator( + swiftModuleName: "__FakeModule" + ) + st.log.logLevel = .error + + try st.analyze(file: "Fake.swift", text: Self.class_interfaceFile) + + let funcDecl = st.importedGlobalFuncs.first { $0.name == "withBuffer" }! + + let generator = FFMSwift2JavaGenerator( + translator: st, + javaPackage: "com.example.swift", + swiftOutputDirectory: "/fake", + javaOutputDirectory: "/fake" + ) + + let output = CodePrinter.toString { printer in + generator.printFunctionDowncallMethods(&printer, funcDecl) + } + + assertOutput( + output, + expected: + """ + // ==== -------------------------------------------------- + // withBuffer + /** + * {@snippet lang=c : + * void swiftjava___FakeModule_withBuffer_body(ptrdiff_t (*body)(const void *, ptrdiff_t)) + * } + */ + private static class swiftjava___FakeModule_withBuffer_body { + private static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( + /* body: */SwiftValueLayout.SWIFT_POINTER + ); + private static final MemorySegment ADDR = + __FakeModule.findOrThrow("swiftjava___FakeModule_withBuffer_body"); + private static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC); + public static void call(java.lang.foreign.MemorySegment body) { + try { + if (SwiftKit.TRACE_DOWNCALLS) { + SwiftKit.traceDowncall(body); + } + HANDLE.invokeExact(body); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + /** + * {snippet lang=c : + * ptrdiff_t (*)(const void *, ptrdiff_t) + * } + */ + private static class $body { + @FunctionalInterface + public interface Function { + long apply(java.lang.foreign.MemorySegment _0, long _1); + } + private static final FunctionDescriptor DESC = FunctionDescriptor.of( + /* -> */SwiftValueLayout.SWIFT_INT, + /* _0: */SwiftValueLayout.SWIFT_POINTER, + /* _1: */SwiftValueLayout.SWIFT_INT + ); + private static final MethodHandle HANDLE = SwiftKit.upcallHandle(Function.class, "apply", DESC); + private static MemorySegment toUpcallStub(Function fi, Arena arena) { + return Linker.nativeLinker().upcallStub(HANDLE.bindTo(fi), DESC, arena); + } + } + } + public static class withBuffer { + @FunctionalInterface + public interface body { + long apply(java.lang.foreign.MemorySegment _0); + } + private static MemorySegment $toUpcallStub(body fi, Arena arena) { + return swiftjava___FakeModule_withBuffer_body.$body.toUpcallStub((_0_pointer, _0_count) -> { + return fi.apply(_0_pointer.reinterpret(_0_count)); + }, arena); + } + } + /** + * Downcall to Swift: + * {@snippet lang=swift : + * public func withBuffer(body: (UnsafeRawBufferPointer) -> Int) + * } + */ + public static void withBuffer(withBuffer.body body) { + try(var arena$ = Arena.ofConfined()) { + swiftjava___FakeModule_withBuffer_body.call(withBuffer.$toUpcallStub(body, arena$)); + } + } + """ + ) + } } diff --git a/Tests/JExtractSwiftTests/FunctionLoweringTests.swift b/Tests/JExtractSwiftTests/FunctionLoweringTests.swift index c707732e..9422bbdb 100644 --- a/Tests/JExtractSwiftTests/FunctionLoweringTests.swift +++ b/Tests/JExtractSwiftTests/FunctionLoweringTests.swift @@ -333,6 +333,24 @@ final class FunctionLoweringTests { ) } + @Test("Lowering non-C-compatible closures") + func lowerComplexClosureParameter() throws { + try assertLoweredFunction( + """ + func withBuffer(body: (UnsafeRawBufferPointer) -> Int) {} + """, + expectedCDecl: """ + @_cdecl("c_withBuffer") + public func c_withBuffer(_ body: @convention(c) (UnsafeRawPointer?, Int) -> Int) { + withBuffer(body: { (_0) in + return body(_0.baseAddress, _0.count) + }) + } + """, + expectedCFunction: "void c_withBuffer(ptrdiff_t (*body)(const void *, ptrdiff_t))" + ) + } + @Test("Lowering () -> Void type") func lowerSimpleClosureTypes() throws { try assertLoweredFunction("""