From b4d7ada78da51125094160e828328e9baa09e3b1 Mon Sep 17 00:00:00 2001 From: George Hong Date: Mon, 7 Jul 2025 23:53:41 -0700 Subject: [PATCH] Add TrainingModule and SGD JNI As title, adds wrappers together with unit test based on XOR train.cpp example. --- .gitignore | 1 + extension/android/BUCK | 2 + extension/android/CMakeLists.txt | 3 +- .../executorch_android/android_test_setup.sh | 11 + .../android/executorch_android/build.gradle | 2 + .../executorch/TrainingModuleE2ETest.kt | 215 +++++++++++ .../main/java/org/pytorch/executorch/SGD.java | 102 +++++ .../pytorch/executorch/TrainingModule.java | 119 ++++++ extension/android/jni/BUCK | 18 +- extension/android/jni/jni_layer.cpp | 97 ++--- extension/android/jni/jni_layer_training.cpp | 350 ++++++++++++++++++ extension/android/jni/selective_jni.buck.bzl | 4 + 12 files changed, 876 insertions(+), 48 deletions(-) create mode 100644 extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt create mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java create mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java create mode 100644 extension/android/jni/jni_layer_training.cpp diff --git a/.gitignore b/.gitignore index 553729e9b68..08d14e13582 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ pip-out/ *.model tokenizer.json *.pte +*.ptd !test_bpe_tokenizer.bin !test_tiktoken_tokenizer.model diff --git a/extension/android/BUCK b/extension/android/BUCK index 0d8462692dd..76377aad08e 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -12,6 +12,8 @@ non_fbcode_target(_kind = fb_android_library, "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", + "executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java", + "executorch_android/src/main/java/org/pytorch/executorch/SGD.java", "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", ], autoglob = False, diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 8f7e19cb172..1a7852072d0 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) target_link_options_shared_lib(executorch) -add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp) +add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_layer_training.cpp) set(link_libraries) list( @@ -77,6 +77,7 @@ list( extension_runner_util extension_tensor extension_threadpool + extension_training fbjni ) diff --git a/extension/android/executorch_android/android_test_setup.sh b/extension/android/executorch_android/android_test_setup.sh index f521dac30c5..1ab8f8ba469 100644 --- a/extension/android/executorch_android/android_test_setup.sh +++ b/extension/android/executorch_android/android_test_setup.sh @@ -15,7 +15,17 @@ which "${PYTHON_EXECUTABLE}" BASEDIR=$(dirname "$(realpath $0)") prepare_add() { + pushd "${BASEDIR}/../../../" python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/" + popd +} + +prepare_xor() { + pushd "${BASEDIR}/../../../extension/training/" + python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" + mv "${BASEDIR}/src/androidTest/resources/xor.pte" "${BASEDIR}/src/androidTest/resources/xor_full.pte" + python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external + popd } prepare_tinyllama() { @@ -43,5 +53,6 @@ prepare_vision() { } prepare_add +prepare_xor prepare_tinyllama prepare_vision diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index 2fa0b9fd57c..83f59d6e5d5 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -50,10 +50,12 @@ dependencies { implementation libs.core.ktx testImplementation 'junit:junit:4.12' testImplementation 'org.assertj:assertj-core:3.27.2' + testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23' androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test:rules:1.2.0' androidTestImplementation 'commons-io:commons-io:2.4' androidTestImplementation 'org.json:json:20250107' + androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23' } import com.vanniktech.maven.publish.SonatypeHost diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt new file mode 100644 index 00000000000..fe519659f5f --- /dev/null +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt @@ -0,0 +1,215 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +package org.pytorch.executorch + +import android.Manifest +import android.util.Log +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException +import org.apache.commons.io.FileUtils +import org.junit.Assert +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.pytorch.executorch.TestFileUtils.getTestFilePath +import kotlin.random.Random +import kotlin.test.assertContains + +/** Unit tests for [TrainingModule]. */ +@RunWith(AndroidJUnit4::class) +class TrainingModuleE2ETest { + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testTrainXOR() { + val pteFilePath = "/xor.pte" + val ptdFilePath = "/xor.ptd" + + val pteFile = File(getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + val ptdFile = File(getTestFilePath(ptdFilePath)) + val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath) + FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile) + ptdInputStream.close() + + val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath)) + val params = module.namedParameters("forward") + + Assert.assertEquals(4, params.size) + assertContains(params, LIN_WEIGHT) + assertContains(params, LIN_BIAS) + assertContains(params, LIN2_WEIGHT) + assertContains(params, LIN2_BIAS) + + val sgd = SGD.create(params, 0.5); + val dataset = listOf( + Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), + ) + + val numEpochs = 5000; + var finalLoss = Float.MAX_VALUE + + for (i in 0 until numEpochs) { + val inputDex = 2 * Random.nextInt(dataset.size / 2) + val targetDex = inputDex + 1 + val input = dataset.get(inputDex) + val target = dataset.get(targetDex) + val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val gradients = module.namedGradients("forward") + + if (i == 0) { + Assert.assertEquals(4, gradients.size) + assertContains(gradients, LIN_WEIGHT) + assertContains(gradients, LIN_BIAS) + assertContains(gradients, LIN2_WEIGHT) + assertContains(gradients, LIN2_BIAS) + } + + if (i % 500 == 0 || i == numEpochs - 1) { + Log.i( + "testTrainXOR", + String.format( + "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", + i, + out[0].toTensor().getDataAsFloatArray()[0], + input.getDataAsFloatArray()[0], + input.getDataAsFloatArray()[1], + out[1].toTensor().getDataAsLongArray()[0], + target.getDataAsLongArray()[0])); + } + + sgd.step(gradients) + + if (i == numEpochs - 1) { + finalLoss = out[0].toTensor().dataAsFloatArray[0] + } + } + Assert.assertTrue(finalLoss < 0.1f) + } + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testTrainXOR_PTEOnly() { + val pteFilePath = "/xor_full.pte" + + val pteFile = File(getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + val module = TrainingModule.load(getTestFilePath(pteFilePath)); + val params = module.namedParameters("forward") + + Assert.assertEquals(4, params.size) + assertContains(params, LIN_WEIGHT) + assertContains(params, LIN_BIAS) + assertContains(params, LIN2_WEIGHT) + assertContains(params, LIN2_BIAS) + + val sgd = SGD.create(params, 0.5); + val dataset = listOf( + Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), + Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)), + Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), + ) + + val numEpochs = 5000; + var finalLoss = Float.MAX_VALUE + + for (i in 0 until numEpochs) { + val inputDex = 2 * Random.nextInt(dataset.size / 2) + val targetDex = inputDex + 1 + val input = dataset.get(inputDex) + val target = dataset.get(targetDex) + val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val gradients = module.namedGradients("forward") + + if (i == 0) { + Assert.assertEquals(4, gradients.size) + assertContains(gradients, LIN_WEIGHT) + assertContains(gradients, LIN_BIAS) + assertContains(gradients, LIN2_WEIGHT) + assertContains(gradients, LIN2_BIAS) + } + + if (i % 500 == 0 || i == numEpochs - 1) { + Log.i( + "testTrainXOR_PTEOnly", + String.format( + "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", + i, + out[0].toTensor().getDataAsFloatArray()[0], + input.getDataAsFloatArray()[0], + input.getDataAsFloatArray()[1], + out[1].toTensor().getDataAsLongArray()[0], + target.getDataAsLongArray()[0])); + } + + sgd.step(gradients) + + if (i == numEpochs - 1) { + finalLoss = out[0].toTensor().dataAsFloatArray[0] + } + } + Assert.assertTrue(finalLoss < 0.1f) + } + + @Test + @Throws(IOException::class) + fun testMissingPteFile() { + val exception = Assert.assertThrows(RuntimeException::class.java) { + TrainingModule.load(getTestFilePath(MISSING_PTE_NAME)) + } + Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME)) + } + + @Test + @Throws(IOException::class) + fun testMissingPtdFile() { + val exception = Assert.assertThrows(RuntimeException::class.java) { + val pteFilePath = "/xor.pte" + val pteFile = File(getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME)) + } + Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME)) + } + + companion object { + private const val LIN_WEIGHT = "net.linear.weight" + private const val LIN_BIAS = "net.linear.bias" + private const val LIN2_WEIGHT = "net.linear2.weight" + private const val LIN2_BIAS = "net.linear2.bias" + private const val MISSING_PTE_NAME = "/missing.pte" + private const val MISSING_PTD_NAME = "/missing.ptd" + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java new file mode 100644 index 00000000000..35dbf5cc54c --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.Map; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Java wrapper for ExecuTorch SGD Optimizer. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class SGD { + + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + // Loads libexecutorch.so from jniLibs + NativeLoader.loadLibrary("executorch"); + } + + private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + Map namedParameters, + double learningRate, + double momentum, + double dampening, + double weightDecay, + boolean nesterov); + + private SGD( + Map namedParameters, + double learningRate, + double momentum, + double dampening, + double weightDecay, + boolean nesterov) { + mHybridData = + initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov); + } + + /** + * Creates a new SGD optimizer with the specified parameters and options. + * + * @param namedParameters Map of parameter names to tensors to be optimized + * @param learningRate The learning rate for the optimizer + * @param momentum The momentum value + * @param dampening The dampening value + * @param weightDecay The weight decay value + * @param nesterov Whether to use Nesterov momentum + * @return new {@link org.pytorch.executorch.SGD} object + */ + public static SGD create( + Map namedParameters, + double learningRate, + double momentum, + double dampening, + double weightDecay, + boolean nesterov) { + return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov); + } + + /** + * Creates a new SGD optimizer with default options. + * + * @param namedParameters Map of parameter names to tensors to be optimized + * @param learningRate The learning rate for the optimizer + * @return new {@link org.pytorch.executorch.SGD} object + */ + public static SGD create(Map namedParameters, double learningRate) { + return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false); + } + + /** + * Performs a single optimization step using the provided gradients. + * + * @param namedGradients Map of parameter names to gradient tensors + */ + public void step(Map namedGradients) { + if (!mHybridData.isValid()) { + throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); + } + stepNative(namedGradients); + } + + @DoNotStrip + private native void stepNative(Map namedGradients); +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java new file mode 100644 index 00000000000..f3c3cdc1219 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import android.util.Log; +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Java wrapper for ExecuTorch TrainingModule. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class TrainingModule { + + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + // Loads libexecutorch.so from jniLibs + NativeLoader.loadLibrary("executorch"); + } + + private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); + + private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { + mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); + } + + /** + * Loads a serialized ExecuTorch Training Module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param dataPath path to file that contains the ExecuTorch module external weights. + * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + */ + public static TrainingModule load(final String modelPath, final String dataPath) { + File modelFile = new File(modelPath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path!! " + modelPath); + } + File dataFile = new File(dataPath); + if (!dataFile.canRead() || !dataFile.isFile()) { + throw new RuntimeException("Cannot load data path!! " + dataPath); + } + return new TrainingModule(modelPath, dataPath); + } + + /** + * Loads a serialized ExecuTorch training module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not + * rely on external weights. + * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + */ + public static TrainingModule load(final String modelPath) { + File modelFile = new File(modelPath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path!! " + modelPath); + } + return new TrainingModule(modelPath, ""); + } + + /** + * Runs the specified joint-graph method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. + * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value(s) from the method. + */ + public EValue[] executeForwardBackward(String methodName, EValue... inputs) { + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return executeForwardBackwardNative(methodName, inputs); + } + + @DoNotStrip + private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); + + public Map namedParameters(String methodName) { + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } + return namedParametersNative(methodName); + } + + @DoNotStrip + private native Map namedParametersNative(String methodName); + + public Map namedGradients(String methodName) { + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } + return namedGradientsNative(methodName); + } + + @DoNotStrip + private native Map namedGradientsNative(String methodName); +} diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 9ffe0525707..6390e07156d 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -28,7 +28,7 @@ non_fbcode_target(_kind = executorch_generated_lib, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni", - srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -39,17 +39,20 @@ non_fbcode_target(_kind = fb_android_cxx_library, "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", + "//xplat/executorch/extension/data_loader:file_data_loader_static", "//xplat/executorch/extension/module:module_static", "//xplat/executorch/extension/runner_util:inputs_static", "//xplat/executorch/extension/tensor:tensor_static", "//xplat/executorch/extension/threadpool:threadpool_static", + "//xplat/executorch/extension/training/module:training_module_static", + "//xplat/executorch/extension/training/optimizer:sgd_static", third_party_dep("cpuinfo"), ], ) non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni_full", - srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -62,9 +65,12 @@ non_fbcode_target(_kind = fb_android_cxx_library, "//fbandroid/native/fb:fb", "//third-party/glog:glog", "//xplat/executorch/backends/xnnpack:xnnpack_backend_static", + "//xplat/executorch/extension/data_loader:file_data_loader_static", "//xplat/executorch/extension/module:module_static", "//xplat/executorch/extension/runner_util:inputs_static", "//xplat/executorch/extension/tensor:tensor_static", + "//xplat/executorch/extension/training/module:training_module_static", + "//xplat/executorch/extension/training/optimizer:sgd_static", "//xplat/executorch/kernels/quantized:generated_lib_static", ], ) @@ -75,6 +81,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, "jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", + "jni_layer_training.cpp", ], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ @@ -89,6 +96,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, "//fbandroid/native/fb:fb", "//third-party/glog:glog", "//xplat/executorch/backends/xnnpack:xnnpack_backend_static", + "//xplat/executorch/extension/data_loader:file_data_loader_static", "//xplat/executorch/examples/models/llama/runner:runner_static", "//xplat/executorch/examples/models/llava/runner:runner_static", "//xplat/executorch/extension/module:module_static", @@ -96,6 +104,8 @@ non_fbcode_target(_kind = fb_android_cxx_library, "//xplat/executorch/extension/tensor:tensor_static", "//xplat/executorch/extension/threadpool:cpuinfo_utils_static", "//xplat/executorch/extension/threadpool:threadpool_static", + "//xplat/executorch/extension/training/module:training_module_static", + "//xplat/executorch/extension/training/optimizer:sgd_static", ], ) @@ -118,6 +128,10 @@ runtime.export_file( name = "jni_layer_runtime.cpp", ) +runtime.export_file( + name = "jni_layer_training.cpp", +) + runtime.cxx_library( name = "jni_headers", exported_headers = [ diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index c3ffe77a0cb..9799083e73f 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -94,6 +94,54 @@ class TensorHybrid : public facebook::jni::HybridClass { cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor)); } + static TensorPtr newTensorFromJTensor( + facebook::jni::alias_ref jtensor) { + static auto cls = TensorHybrid::javaClassStatic(); + static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); + jint jdtype = dtypeMethod(jtensor); + + static const auto shapeField = cls->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); + + static auto dataBufferMethod = cls->getMethod< + facebook::jni::local_ref()>( + "getRawDataBuffer"); + facebook::jni::local_ref jbuffer = + dataBufferMethod(jtensor); + + const auto rank = jshape->size(); + + const auto shapeArr = jshape->getRegion(0, rank); + std::vector shape_vec; + shape_vec.reserve(rank); + + auto numel = 1; + for (int i = 0; i < rank; i++) { + shape_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; + } + JNIEnv* jni = facebook::jni::Environment::current(); + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Unknown Tensor jdtype %d", + jdtype); + } + ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); + const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); + if (dataCapacity != numel) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Tensor dimensions(elements number:%d inconsistent with buffer capacity(%d)", + numel, + dataCapacity); + } + return from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); + } + private: friend HybridBase; }; @@ -163,51 +211,7 @@ class JEValue : public facebook::jni::JavaClass { ->getMethod()>( "toTensor"); auto jtensor = jMethodGetTensor(JEValue); - - static auto cls = TensorHybrid::javaClassStatic(); - static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); - jint jdtype = dtypeMethod(jtensor); - - static const auto shapeField = cls->getField("shape"); - auto jshape = jtensor->getFieldValue(shapeField); - - static auto dataBufferMethod = cls->getMethod< - facebook::jni::local_ref()>( - "getRawDataBuffer"); - facebook::jni::local_ref jbuffer = - dataBufferMethod(jtensor); - - const auto rank = jshape->size(); - - const auto shapeArr = jshape->getRegion(0, rank); - std::vector shape_vec; - shape_vec.reserve(rank); - - auto numel = 1; - for (int i = 0; i < rank; i++) { - shape_vec.push_back(shapeArr[i]); - } - for (int i = rank - 1; i >= 0; --i) { - numel *= shapeArr[i]; - } - JNIEnv* jni = facebook::jni::Environment::current(); - if (java_dtype_to_scalar_type.count(jdtype) == 0) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Unknown Tensor jdtype %d", - jdtype); - } - ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); - const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); - if (dataCapacity != numel) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Tensor dimensions(elements number:%d inconsistent with buffer capacity(%d)", - numel, - dataCapacity); - } - return from_blob( - jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); + return TensorHybrid::newTensorFromJTensor(jtensor); } facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, @@ -492,10 +496,13 @@ extern void register_natives_for_llm(); void register_natives_for_llm() {} #endif extern void register_natives_for_runtime(); +extern void register_natives_for_training(); + JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); register_natives_for_llm(); register_natives_for_runtime(); + register_natives_for_training(); }); } diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp new file mode 100644 index 00000000000..7c66884dcff --- /dev/null +++ b/extension/android/jni/jni_layer_training.cpp @@ -0,0 +1,350 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace executorch::extension; +using namespace executorch::extension::training; +using namespace torch::executor; + +namespace executorch::extension { + +// Forward declarations from jni_layer.cpp +class TensorHybrid : public facebook::jni::HybridClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/Tensor;"; + + static facebook::jni::local_ref + newJTensorFromTensor(const executorch::aten::Tensor& tensor); + + static TensorPtr newTensorFromJTensor( + facebook::jni::alias_ref jtensor); +}; + +class JEValue : public facebook::jni::JavaClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/EValue;"; + + constexpr static int kTypeCodeTensor = 1; + constexpr static int kTypeCodeString = 2; + constexpr static int kTypeCodeDouble = 3; + constexpr static int kTypeCodeInt = 4; + constexpr static int kTypeCodeBool = 5; + + static facebook::jni::local_ref newJEValueFromEValue( + runtime::EValue evalue); + + static TensorPtr JEValueToTensorImpl( + facebook::jni::alias_ref JEValue); +}; + +class ExecuTorchTrainingJni + : public facebook::jni::HybridClass { + private: + friend HybridBase; + std::unique_ptr module_; + + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/TrainingModule;"; + + ExecuTorchTrainingJni( + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref dataPath) { + auto modelPathString = modelPath->toStdString(); + auto modelLoaderRes = FileDataLoader::from(modelPathString.c_str()); + if (modelLoaderRes.error() != Error::Ok) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "Failed to open model file: %s", + modelPathString.c_str()); + } + auto modelLoader = + std::make_unique(std::move(modelLoaderRes.get())); + + std::unique_ptr dataLoader = nullptr; + auto dataPathString = dataPath->toStdString(); + if (!dataPathString.empty()) { + auto dataLoaderRes = FileDataLoader::from(dataPathString.c_str()); + if (dataLoaderRes.error() != Error::Ok) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "Failed to open ptd file: %s", + dataPathString.c_str()); + } + dataLoader = + std::make_unique(std::move(dataLoaderRes.get())); + } + + module_ = std::make_unique( + std::move(modelLoader), + nullptr, + nullptr, + nullptr, + std::move(dataLoader)); + } + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref modelPath, + facebook::jni::alias_ref dataPath) { + return makeCxxInstance(modelPath, dataPath); + } + + facebook::jni::local_ref> + executeForwardBackward( + facebook::jni::alias_ref methodName, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> + jinputs) { + std::vector evalues; + std::vector tensors; + + static const auto typeCodeField = + JEValue::javaClassStatic()->getField("mTypeCode"); + + for (int i = 0; i < jinputs->size(); i++) { + auto jevalue = jinputs->getElement(i); + const auto typeCode = jevalue->getFieldValue(typeCodeField); + if (typeCode == JEValue::kTypeCodeTensor) { + tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); + evalues.emplace_back(tensors.back()); + } else if (typeCode == JEValue::kTypeCodeInt) { + int64_t value = jevalue->getFieldValue(typeCodeField); + evalues.emplace_back(value); + } else if (typeCode == JEValue::kTypeCodeDouble) { + double value = jevalue->getFieldValue(typeCodeField); + evalues.emplace_back(value); + } else if (typeCode == JEValue::kTypeCodeBool) { + bool value = jevalue->getFieldValue(typeCodeField); + evalues.emplace_back(value); + } + } + + auto result = + module_->execute_forward_backward(methodName->toStdString(), evalues); + if (!result.ok()) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "Execution of forward_backward for method %s failed with status 0x%" PRIx32, + methodName->toStdString().c_str(), + static_cast(result.error())); + } + + facebook::jni::local_ref> jresult = + facebook::jni::JArrayClass::newArray(result.get().size()); + + for (int i = 0; i < result.get().size(); i++) { + auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]); + jresult->setElement(i, *jevalue); + } + return jresult; + } + + facebook::jni::local_ref< + facebook::jni::JMap> + namedParameters(facebook::jni::alias_ref methodName) { + auto method = methodName->toStdString(); + auto result = module_->named_parameters(method); + if (!result.ok()) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "Getting named parameters for method %s failed with status 0x%" PRIx32, + method.c_str(), + static_cast(result.error())); + } + facebook::jni::local_ref< + facebook::jni::JHashMap> + parameters = facebook::jni:: + JHashMap::create(); + for (auto& [layer, tensor] : result.get()) { + parameters->put( + facebook::jni::make_jstring(layer.data()), + TensorHybrid::newJTensorFromTensor(tensor)); + } + return parameters; + } + + facebook::jni::local_ref< + facebook::jni::JMap> + namedGradients(facebook::jni::alias_ref methodName) { + auto method = methodName->toStdString(); + auto result = module_->named_gradients(method); + if (!result.ok()) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "Getting named gradients for method %s failed with status 0x%" PRIx32, + method.c_str(), + static_cast(result.error())); + } + facebook::jni::local_ref< + facebook::jni::JHashMap> + gradients = facebook::jni::JHashMap:: + create(); + for (auto& [layer, tensor] : result.get()) { + gradients->put( + facebook::jni::make_jstring(layer.data()), + TensorHybrid::newJTensorFromTensor(tensor)); + } + return gradients; + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchTrainingJni::initHybrid), + makeNativeMethod( + "executeForwardBackwardNative", + ExecuTorchTrainingJni::executeForwardBackward), + makeNativeMethod( + "namedParametersNative", ExecuTorchTrainingJni::namedParameters), + makeNativeMethod( + "namedGradientsNative", ExecuTorchTrainingJni::namedGradients), + }); + } +}; + +class SGDHybrid : public facebook::jni::HybridClass { + public: + constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref< + facebook::jni::JMap> + namedParameters, + jdouble learningRate, + jdouble momentum, + jdouble dampening, + jdouble weightDecay, + jboolean nesterov) { + return makeCxxInstance( + namedParameters, + learningRate, + momentum, + dampening, + weightDecay, + nesterov); + } + + SGDHybrid( + facebook::jni::alias_ref< + facebook::jni::JMap> + namedParameters, + jdouble learningRate, + jdouble momentum, + jdouble dampening, + jdouble weightDecay, + jboolean nesterov) { + std::map cppNamedParameters; + + // Avoid vector reallocation to keep string_views valid. + parameterNames_.reserve(namedParameters->size()); + paramTensorPtrs_.reserve(namedParameters->size()); + + auto iterator = namedParameters->begin(); + auto end = namedParameters->end(); + + while (iterator != end) { + auto key = iterator->first; + auto value = iterator->second; + + std::string paramName = key->toStdString(); + TensorPtr tensor = TensorHybrid::newTensorFromJTensor(value); + + // Store the parameter name and tensor + parameterNames_.push_back(paramName); + paramTensorPtrs_.push_back(tensor); + cppNamedParameters.emplace( + std::string_view(parameterNames_.back()), *tensor); + + ++iterator; + } + + optimizer::SGDOptions options( + learningRate, momentum, dampening, weightDecay, nesterov); + sgdOptimizer_ = + std::make_unique(cppNamedParameters, options); + } + + void + step(facebook::jni::alias_ref< + facebook::jni::JMap> namedGradients) { + std::map cppNamedGradients; + std::vector gradientNames; + std::vector tensorKeepalives; + + gradientNames.reserve(namedGradients->size()); + tensorKeepalives.reserve(namedGradients->size()); + + auto iterator = namedGradients->begin(); + auto end = namedGradients->end(); + + while (iterator != end) { + auto key = iterator->first; + auto value = iterator->second; + + std::string gradName = key->toStdString(); + TensorPtr tensor = TensorHybrid::newTensorFromJTensor(value); + + // Store the gradient name and tensor + gradientNames.push_back(gradName); + tensorKeepalives.push_back(tensor); + cppNamedGradients.emplace( + std::string_view(gradientNames.back()), *tensor); + + ++iterator; + } + + auto result = sgdOptimizer_->step(cppNamedGradients); + if (result != ::executorch::runtime::Error::Ok) { + facebook::jni::throwNewJavaException( + "java/lang/Exception", + "SGD optimization step failed with status 0x%" PRIx32, + static_cast(result)); + } + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", SGDHybrid::initHybrid), + makeNativeMethod("stepNative", SGDHybrid::step), + }); + } + + private: + friend HybridBase; + std::unique_ptr sgdOptimizer_; + std::vector + parameterNames_; // Store parameter names to keep string_view valid + std::vector + paramTensorPtrs_; // Store parameter tensors to keep TensorPtrs valid. +}; + +} // namespace executorch::extension + +// Function to register training module natives +void register_natives_for_training() { + executorch::extension::ExecuTorchTrainingJni::registerNatives(); + executorch::extension::SGDHybrid::registerNatives(); +}; diff --git a/extension/android/jni/selective_jni.buck.bzl b/extension/android/jni/selective_jni.buck.bzl index d557606b7d1..e114d7c0b76 100644 --- a/extension/android/jni/selective_jni.buck.bzl +++ b/extension/android/jni/selective_jni.buck.bzl @@ -9,6 +9,7 @@ def selective_jni_target(name, deps, srcs = [], soname = "libexecutorch.$(ext)") name = name, srcs = [ "//xplat/executorch/extension/android/jni:jni_layer.cpp", + "//xplat/executorch/extension/android/jni:jni_layer_training.cpp", "//xplat/executorch/extension/android/jni:jni_layer_runtime.cpp", ] + srcs, allow_jni_merging = False, @@ -21,10 +22,13 @@ def selective_jni_target(name, deps, srcs = [], soname = "libexecutorch.$(ext)") "//third-party/glog:glog", "//xplat/executorch/extension/android/jni:jni_headers", "//xplat/executorch/extension/android/jni:log_provider_static", + "//xplat/executorch/extension/data_loader:file_data_loader_static", "//xplat/executorch/extension/module:module_static", "//xplat/executorch/extension/runner_util:inputs_static", "//xplat/executorch/extension/tensor:tensor_static", "//xplat/executorch/extension/threadpool:threadpool_static", + "//xplat/executorch/extension/training/module:training_module_static", + "//xplat/executorch/extension/training/optimizer:sgd_static", third_party_dep("cpuinfo"), ] + deps, )