[RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow #12247

Status: Draft. Wants to merge 1 commit into base: main.
1 change: 1 addition & 0 deletions .gitignore
@@ -26,6 +26,7 @@ pip-out/
*.model
tokenizer.json
*.pte
*.ptd
!test_bpe_tokenizer.bin
!test_tiktoken_tokenizer.model

2 changes: 2 additions & 0 deletions extension/android/BUCK
@@ -12,6 +12,8 @@ non_fbcode_target(_kind = fb_android_library,
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
],
autoglob = False,
3 changes: 2 additions & 1 deletion extension/android/CMakeLists.txt
@@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
find_package(executorch CONFIG REQUIRED)
target_link_options_shared_lib(executorch)

add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_layer_training.cpp)

set(link_libraries)
list(
@@ -77,6 +77,7 @@ list(
extension_runner_util
extension_tensor
extension_threadpool
extension_training
fbjni
)

11 changes: 11 additions & 0 deletions extension/android/executorch_android/android_test_setup.sh
@@ -15,7 +15,17 @@ which "${PYTHON_EXECUTABLE}"
BASEDIR=$(dirname "$(realpath $0)")

prepare_add() {
  pushd "${BASEDIR}/../../../"
  python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/"
  popd
}

prepare_xor() {
  pushd "${BASEDIR}/../../../extension/training/"
  # Export with inlined weights, then rename so it does not collide with the
  # external-weight export below.
  python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/"
  mv "${BASEDIR}/src/androidTest/resources/xor.pte" "${BASEDIR}/src/androidTest/resources/xor_full.pte"
  # --external splits the weights out of xor.pte into a separate xor.ptd file.
  python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external
  popd
}

prepare_tinyllama() {
@@ -43,5 +53,6 @@ prepare_vision() {
}

prepare_add
prepare_xor
prepare_tinyllama
prepare_vision
2 changes: 2 additions & 0 deletions extension/android/executorch_android/build.gradle
@@ -50,10 +50,12 @@ dependencies {
implementation libs.core.ktx
testImplementation 'junit:junit:4.12'
testImplementation 'org.assertj:assertj-core:3.27.2'
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test:rules:1.2.0'
androidTestImplementation 'commons-io:commons-io:2.4'
androidTestImplementation 'org.json:json:20250107'
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
}

import com.vanniktech.maven.publish.SonatypeHost
189 changes: 189 additions & 0 deletions extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt
@@ -0,0 +1,189 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.pytorch.executorch

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import kotlin.random.Random
import kotlin.test.assertContains

/** End-to-end tests for [TrainingModule]. */
@RunWith(AndroidJUnit4::class)
class TrainingModuleE2ETest {
@get:Rule
var runtimePermissionRule: GrantPermissionRule =
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testTrainXOR() {
val pteFilePath = "/xor.pte"
val ptdFilePath = "/xor.ptd"

val pteFile = File(getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val ptdFile = File(getTestFilePath(ptdFilePath))
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
ptdInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
assertContains(params, LIN_WEIGHT)
assertContains(params, LIN_BIAS)
assertContains(params, LIN2_WEIGHT)
assertContains(params, LIN2_BIAS)

val sgd = SGD.create(params, 0.5)
val dataset = listOf<Tensor>(
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
)

val numEpochs = 5000
var finalLoss = Float.MAX_VALUE

for (i in 0 until numEpochs) {
val inputDex = 2 * Random.nextInt(dataset.size / 2)
val targetDex = inputDex + 1
val input = dataset[inputDex]
val target = dataset[targetDex]
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val gradients = module.namedGradients("forward")

if (i == 0) {
Assert.assertEquals(4, gradients.size)
assertContains(gradients, LIN_WEIGHT)
assertContains(gradients, LIN_BIAS)
assertContains(gradients, LIN2_WEIGHT)
assertContains(gradients, LIN2_BIAS)
}

if (i % 500 == 0 || i == numEpochs - 1) {
Log.i(
"testTrainXOR",
String.format(
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
i,
out[0].toTensor().getDataAsFloatArray()[0],
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
}

sgd.step(gradients)

if (i == numEpochs - 1) {
finalLoss = out[0].toTensor().dataAsFloatArray[0]
}
}
Assert.assertTrue(finalLoss < 0.1f)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testTrainXOR_PTEOnly() {
val pteFilePath = "/xor_full.pte"

val pteFile = File(getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath))
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
assertContains(params, LIN_WEIGHT)
assertContains(params, LIN_BIAS)
assertContains(params, LIN2_WEIGHT)
assertContains(params, LIN2_BIAS)

val sgd = SGD.create(params, 0.5)
val dataset = listOf<Tensor>(
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
)

val numEpochs = 5000
var finalLoss = Float.MAX_VALUE

for (i in 0 until numEpochs) {
val inputDex = 2 * Random.nextInt(dataset.size / 2)
val targetDex = inputDex + 1
val input = dataset[inputDex]
val target = dataset[targetDex]
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val gradients = module.namedGradients("forward")

if (i == 0) {
Assert.assertEquals(4, gradients.size)
assertContains(gradients, LIN_WEIGHT)
assertContains(gradients, LIN_BIAS)
assertContains(gradients, LIN2_WEIGHT)
assertContains(gradients, LIN2_BIAS)
}

if (i % 500 == 0 || i == numEpochs - 1) {
Log.i(
"testTrainXOR_PTEOnly",
String.format(
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
i,
out[0].toTensor().getDataAsFloatArray()[0],
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
}

sgd.step(gradients)

if (i == numEpochs - 1) {
finalLoss = out[0].toTensor().dataAsFloatArray[0]
}
}
Assert.assertTrue(finalLoss < 0.1f)
}

companion object {
private const val LIN_WEIGHT = "net.linear.weight"
private const val LIN_BIAS = "net.linear.bias"
private const val LIN2_WEIGHT = "net.linear2.weight"
private const val LIN2_BIAS = "net.linear2.bias"
}
}
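
For reviewers who want the shape of the new API without reading both 90-line test bodies, here is the pattern the Kotlin test exercises, condensed into a minimal Java sketch. Everything in it is taken from the test above, except two assumptions: the return type of executeForwardBackward is written as EValue[] (inferred from the Kotlin out[0] indexing, not shown in this diff), and the model path is hypothetical.

import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.SGD;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.TrainingModule;

public class XorTrainingSketch {
  public static void main(String[] args) {
    // PTE-only workflow: weights live inside the .pte. For the external-weight
    // workflow, pass the .ptd path as a second argument to load().
    TrainingModule module = TrainingModule.load("/data/local/tmp/xor_full.pte");

    // Trainable parameters keyed by fully qualified name, e.g. "net.linear.weight".
    Map<String, Tensor> params = module.namedParameters("forward");
    SGD sgd = SGD.create(params, 0.5);

    // One hard-coded XOR row stands in for the test's random dataset sampling.
    Tensor input = Tensor.fromBlob(new float[] {1.0f, 0.0f}, new long[] {1, 2});
    Tensor label = Tensor.fromBlob(new long[] {1}, new long[] {1});

    for (int step = 0; step < 5000; step++) {
      // Runs forward and backward in one call; out[0] is the loss and out[1]
      // the predicted class, matching the XOR test above.
      EValue[] out =
          module.executeForwardBackward("forward", EValue.from(input), EValue.from(label));
      // Gradients computed by the backward pass, keyed like the parameters.
      Map<String, Tensor> grads = module.namedGradients("forward");
      sgd.step(grads);

      if (step % 500 == 0) {
        float loss = out[0].toTensor().getDataAsFloatArray()[0];
        System.out.println("step " + step + " loss " + loss);
      }
    }
  }
}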
102 changes: 102 additions & 0 deletions extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java
@@ -0,0 +1,102 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
* Java wrapper for ExecuTorch SGD Optimizer.
*
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class SGD {

static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
// Loads libexecutorch.so from jniLibs
NativeLoader.loadLibrary("executorch");
}

private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov);

private SGD(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov) {
mHybridData =
initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
}

/**
* Creates a new SGD optimizer with the specified parameters and options.
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @param momentum The momentum value
* @param dampening The dampening value
* @param weightDecay The weight decay value
* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov) {
return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
}

/**
* Creates a new SGD optimizer with default options.
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
}

/**
* Performs a single optimization step using the provided gradients.
*
* @param namedGradients Map of parameter names to gradient tensors
*/
public void step(Map<String, Tensor> namedGradients) {
if (!mHybridData.isValid()) {
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
}
stepNative(namedGradients);
}

@DoNotStrip
private native void stepNative(Map<String, Tensor> namedGradients);
}
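
The option names above mirror torch.optim.SGD, so a reasonable reading is that the native optimizer applies the standard update rule. This diff does not include the C++ side, so take the following as an assumption of parity with PyTorch's SGD semantics rather than a statement about the JNI implementation. One step for a parameter $\theta$ with gradient $g_t$:

$$
\begin{aligned}
g_t &\leftarrow g_t + \lambda\,\theta_{t-1} && \text{weightDecay } \lambda \\
b_t &\leftarrow \mu\,b_{t-1} + (1-\tau)\,g_t && \text{momentum } \mu,\ \text{dampening } \tau \\
g_t &\leftarrow \begin{cases} g_t + \mu\,b_t & \text{if nesterov} \\ b_t & \text{otherwise} \end{cases} \\
\theta_t &\leftarrow \theta_{t-1} - \eta\,g_t && \text{learningRate } \eta
\end{aligned}
$$

With the two-argument create overload (momentum, dampening, and weight decay all 0.0, nesterov false), this collapses to plain $\theta_t \leftarrow \theta_{t-1} - \eta\,g_t$, which is the configuration the XOR tests use.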