Skip to content

Commit 7f4a62d

Browse files
committed
Add TrainingModule and SGD JNI
Adds Java wrappers for TrainingModule and the SGD optimizer, together with a unit test based on the XOR train.cpp example.
1 parent 2f55193 commit 7f4a62d

File tree

11 files changed

+744
-48
lines changed

11 files changed

+744
-48
lines changed

extension/android/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ non_fbcode_target(_kind = fb_android_library,
1212
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1313
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
15+
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
16+
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1517
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
1618
],
1719
autoglob = False,

extension/android/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
6464
find_package(executorch CONFIG REQUIRED)
6565
target_link_options_shared_lib(executorch)
6666

67-
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)
67+
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_layer_training.cpp)
6868

6969
set(link_libraries)
7070
list(
@@ -77,6 +77,7 @@ list(
7777
extension_runner_util
7878
extension_tensor
7979
extension_threadpool
80+
extension_training
8081
fbjni
8182
)
8283

extension/android/executorch_android/android_test_setup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ prepare_add() {
1818
python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/"
1919
}
2020

21+
prepare_xor() {
22+
python3 -m extension.training.examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external
23+
}
24+
2125
prepare_tinyllama() {
2226
pushd "${BASEDIR}/../../../"
2327
curl -C - -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt" --output stories15M.pt
@@ -43,5 +47,6 @@ prepare_vision() {
4347
}
4448

4549
prepare_add
50+
prepare_xor
4651
prepare_tinyllama
4752
prepare_vision

extension/android/executorch_android/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ dependencies {
5050
implementation libs.core.ktx
5151
testImplementation 'junit:junit:4.12'
5252
testImplementation 'org.assertj:assertj-core:3.27.2'
53+
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5354
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
5455
androidTestImplementation 'androidx.test:rules:1.2.0'
5556
androidTestImplementation 'commons-io:commons-io:2.4'
5657
androidTestImplementation 'org.json:json:20250107'
58+
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5759
}
5860

5961
import com.vanniktech.maven.publish.SonatypeHost
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
package org.pytorch.executorch

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import kotlin.random.Random
import kotlin.test.assertContains
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath

/** End-to-end test that trains the XOR example model via [TrainingModule] and [SGD]. */
@RunWith(AndroidJUnit4::class)
class TrainingModuleE2ETest {
  @get:Rule
  var runtimePermissionRule: GrantPermissionRule =
      GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

  /**
   * Copies the classpath resource at [resourcePath] into the on-device test directory and returns
   * the resulting absolute file path. The stream is closed even if the copy throws.
   */
  private fun stageResource(resourcePath: String): String {
    val outFile = File(getTestFilePath(resourcePath))
    javaClass.getResourceAsStream(resourcePath)!!.use { stream ->
      FileUtils.copyInputStreamToFile(stream, outFile)
    }
    return outFile.path
  }

  @Test
  @Throws(IOException::class, URISyntaxException::class)
  fun testTrainXOR() {
    val ptePath = stageResource("/xor.pte")
    val ptdPath = stageResource("/xor.ptd")

    val module = TrainingModule.load(ptePath, ptdPath)
    val params = module.namedParameters("forward")

    // The XOR net has two linear layers, each contributing a weight and a bias.
    Assert.assertEquals(4, params.size)
    assertContains(params, LIN_WEIGHT)
    assertContains(params, LIN_BIAS)
    assertContains(params, LIN2_WEIGHT)
    assertContains(params, LIN2_BIAS)

    val sgd = SGD.create(params, 0.5)

    // Flat list of (input, label) pairs at even/odd indices: the XOR truth table.
    val dataset =
        listOf<Tensor>(
            Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
            Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
            Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
            Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
            Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
            Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
            Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
            Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
        )

    val numEpochs = 5000
    var finalLoss = Float.MAX_VALUE

    for (i in 0 until numEpochs) {
      // Pick a random (input, label) pair; inputs live at even indices, labels right after.
      val inputDex = 2 * Random.nextInt(dataset.size / 2)
      val targetDex = inputDex + 1
      val input = dataset[inputDex]
      val target = dataset[targetDex]
      val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
      val gradients = module.namedGradients("forward")

      if (i == 0) {
        // Gradients should exist for exactly the same four parameters.
        Assert.assertEquals(4, gradients.size)
        assertContains(gradients, LIN_WEIGHT)
        assertContains(gradients, LIN_BIAS)
        assertContains(gradients, LIN2_WEIGHT)
        assertContains(gradients, LIN2_BIAS)
      }

      if (i % 500 == 0 || i == numEpochs - 1) {
        Log.i(
            "testTrainXOR",
            String.format(
                "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
                i,
                out[0].toTensor().dataAsFloatArray[0],
                input.dataAsFloatArray[0],
                input.dataAsFloatArray[1],
                out[1].toTensor().dataAsLongArray[0],
                target.dataAsLongArray[0]))
      }

      sgd.step(gradients)

      if (i == numEpochs - 1) {
        finalLoss = out[0].toTensor().dataAsFloatArray[0]
      }
    }
    // After 5000 steps the model should have essentially solved XOR.
    Assert.assertTrue(finalLoss < 0.1f)
  }

  companion object {
    private const val LIN_WEIGHT = "net.linear.weight"
    private const val LIN_BIAS = "net.linear.bias"
    private const val LIN2_WEIGHT = "net.linear2.weight"
    private const val LIN2_BIAS = "net.linear2.bias"
  }
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
 * Java wrapper for ExecuTorch SGD Optimizer.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice
 */
@Experimental
public class SGD {

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    // Loads libexecutorch.so from jniLibs
    NativeLoader.loadLibrary("executorch");
  }

  /** Handle to the native SGD instance; invalid once the hybrid object is destroyed. */
  private final HybridData mHybridData;

  @DoNotStrip
  private static native HybridData initHybrid(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov);

  private SGD(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov) {
    mHybridData =
        initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
  }

  /**
   * Creates a new SGD optimizer with the specified parameters and options.
   *
   * @param namedParameters Map of parameter names to tensors to be optimized
   * @param learningRate The learning rate for the optimizer
   * @param momentum The momentum factor (default: 0)
   * @param dampening The dampening for momentum (default: 0)
   * @param weightDecay The weight decay (L2 penalty) (default: 0)
   * @param nesterov Whether to use Nesterov momentum (default: false)
   * @return new {@link org.pytorch.executorch.SGD} object
   */
  public static SGD create(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov) {
    return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
  }

  /**
   * Creates a new SGD optimizer with default options (no momentum, dampening, weight decay, or
   * Nesterov).
   *
   * @param namedParameters Map of parameter names to tensors to be optimized
   * @param learningRate The learning rate for the optimizer
   * @return new {@link org.pytorch.executorch.SGD} object
   */
  public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
    return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
  }

  /**
   * Performs a single optimization step using the provided gradients.
   *
   * @param namedGradients Map of parameter names to gradient tensors
   * @throws IllegalStateException if the native optimizer has already been destroyed
   */
  public void step(Map<String, Tensor> namedGradients) {
    if (!mHybridData.isValid()) {
      // IllegalStateException (a RuntimeException subclass, so backward-compatible for callers)
      // is the idiomatic type for use-after-destroy.
      throw new IllegalStateException("Attempt to use a destroyed SGD optimizer");
    }
    stepNative(namedGradients);
  }

  @DoNotStrip
  private native void stepNative(Map<String, Tensor> namedGradients);
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
 * Java wrapper for ExecuTorch TrainingModule.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice
 */
@Experimental
public class TrainingModule {

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    // Loads libexecutorch.so from jniLibs
    NativeLoader.loadLibrary("executorch");
  }

  // Handle to the native module; invalid once the hybrid object is destroyed.
  private final HybridData mHybridData;

  @DoNotStrip
  private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);

  private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
    mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
  }

  /**
   * Loads a serialized ExecuTorch module from the specified path on the disk.
   *
   * @param modelPath path to file that contains the serialized ExecuTorch module.
   * @param dataPath path to file that contains the external weights/data (e.g. a .ptd file).
   * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
   * @throws RuntimeException if either path does not point to a readable regular file.
   */
  public static TrainingModule load(final String modelPath, final String dataPath) {
    File modelFile = new File(modelPath);
    if (!modelFile.canRead() || !modelFile.isFile()) {
      throw new RuntimeException("Cannot load model path!! " + modelPath);
    }
    File dataFile = new File(dataPath);
    if (!dataFile.canRead() || !dataFile.isFile()) {
      throw new RuntimeException("Cannot load data path!! " + dataPath);
    }
    return new TrainingModule(modelPath, dataPath);
  }

  /**
   * Runs the specified method of this module with the specified arguments, computing both the
   * forward pass and the backward pass (gradients).
   *
   * <p>If the underlying native module has already been destroyed, an error is logged and an empty
   * array is returned instead of throwing.
   *
   * @param methodName name of the ExecuTorch method to run.
   * @param inputs arguments that will be passed to ExecuTorch method.
   * @return return value from the method.
   */
  public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new EValue[0];
    }
    return executeForwardBackwardNative(methodName, inputs);
  }

  @DoNotStrip
  private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);

  /**
   * Returns the named trainable parameters of the specified method.
   *
   * <p>If the underlying native module has already been destroyed, an error is logged and an empty
   * map is returned instead of throwing.
   *
   * @param methodName name of the ExecuTorch method whose parameters to fetch.
   * @return map from parameter name to parameter tensor.
   */
  public Map<String, Tensor> namedParameters(String methodName) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new HashMap<String, Tensor>();
    }
    return namedParametersNative(methodName);
  }

  @DoNotStrip
  private native Map<String, Tensor> namedParametersNative(String methodName);

  /**
   * Returns the named gradients of the specified method, as computed by the most recent call to
   * {@link #executeForwardBackward}.
   *
   * <p>If the underlying native module has already been destroyed, an error is logged and an empty
   * map is returned instead of throwing.
   *
   * @param methodName name of the ExecuTorch method whose gradients to fetch.
   * @return map from parameter name to gradient tensor.
   */
  public Map<String, Tensor> namedGradients(String methodName) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new HashMap<String, Tensor>();
    }
    return namedGradientsNative(methodName);
  }

  @DoNotStrip
  private native Map<String, Tensor> namedGradientsNative(String methodName);
}

0 commit comments

Comments
 (0)