Skip to content

Commit 8e96c94

Browse files
authored
Add TrainingModule and SGD JNI + PTE-only Training Workflow (#12247)
### Summary Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also makes the following change: - Refactor jni_layer.cpp JTensor <--> Tensor conversion to be a general TensorHybrid utility. This is useful for TrainingModule classes that move maps of Tensors around. - Updates `android_test_setup.sh` to match the pushd-popd directory movement for consistency and flexibility. This is also used to fix errors with generating the XOR files. Training dependencies are already enabled for Java JNI library, so we skip adding additional guard flags. ### Test plan Updated XOR tests that check .pte only convergence workflow. ``` sh scripts/build_android_library.sh sh executorch_android/android_test_setup.sh // Creates xor.ptd, xor.pte, and xor_full.pte files. ./gradlew :executorch_android:connectedAndroidTest // Added unit test to check toy model convergence loss < 0.01 ``` For the XOR tests, the device logs will show convergence values: ``` I testTrainXOR: Step 0, Loss 0.683540, Input [1, 0], Prediction 1, Label 1 ... I testTrainXOR: Step 4500, Loss 0.000994, Input [0, 0], Prediction 0, Label 0 ```
1 parent d506312 commit 8e96c94

File tree

11 files changed

+893
-45
lines changed

11 files changed

+893
-45
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pip-out/
2626
*.model
2727
tokenizer.json
2828
*.pte
29+
*.ptd
2930
!test_bpe_tokenizer.bin
3031
!test_tiktoken_tokenizer.model
3132

extension/android/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ non_fbcode_target(_kind = fb_android_library,
1313
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1515
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
16+
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
17+
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1618
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
1719
],
1820
autoglob = False,

extension/android/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY)
145145
)
146146
endif()
147147

148+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
149+
target_sources(executorch_jni PRIVATE jni/jni_layer_training.cpp jni/log.cpp)
150+
list(APPEND link_libraries extension_training)
151+
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_TRAINING=1)
152+
endif()
153+
148154
if(EXECUTORCH_BUILD_LLAMA_JNI)
149155
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
150156
list(APPEND link_libraries llama_runner llava_runner)

extension/android/executorch_android/android_test_setup.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@ which "${PYTHON_EXECUTABLE}"
1515
BASEDIR=$(dirname "$(realpath $0)")
1616

1717
prepare_add() {
  # The test.models.export_program module resolves relative to the repo root,
  # so hop there for the duration of the export.
  local out_dir="${BASEDIR}/src/androidTest/resources/"
  pushd "${BASEDIR}/../../../"
  python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${out_dir}"
  popd
}
22+
23+
prepare_xor() {
  # Exports the XOR training example twice from extension/training:
  #   1. a single self-contained program, kept as xor_full.pte
  #   2. with --external, which presumably splits weights out as xor.ptd
  #      next to a slim xor.pte (commit notes: creates xor.pte/xor.ptd/xor_full.pte)
  local res_dir="${BASEDIR}/src/androidTest/resources/"
  pushd "${BASEDIR}/../../training/"
  python3 -m examples.XOR.export_model --outdir "${res_dir}"
  # Rename before the second export so it is not clobbered.
  mv "${res_dir}xor.pte" "${res_dir}xor_full.pte"
  python3 -m examples.XOR.export_model --outdir "${res_dir}" --external
  popd
}
2030

2131
prepare_tinyllama() {
@@ -43,5 +53,6 @@ prepare_vision() {
4353
}
4454

4555
prepare_add
56+
prepare_xor
4657
prepare_tinyllama
4758
prepare_vision

extension/android/executorch_android/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ dependencies {
5050
implementation libs.core.ktx
5151
testImplementation 'junit:junit:4.12'
5252
testImplementation 'org.assertj:assertj-core:3.27.2'
53+
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5354
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
5455
androidTestImplementation 'androidx.test:rules:1.2.0'
5556
androidTestImplementation 'commons-io:commons-io:2.4'
5657
androidTestImplementation 'org.json:json:20250107'
58+
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5759
}
5860

5961
import com.vanniktech.maven.publish.SonatypeHost
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.Manifest
11+
import android.util.Log
12+
import androidx.test.ext.junit.runners.AndroidJUnit4
13+
import androidx.test.rule.GrantPermissionRule
14+
import java.io.File
15+
import java.io.IOException
16+
import java.net.URISyntaxException
17+
import org.apache.commons.io.FileUtils
18+
import org.junit.Assert
19+
import org.junit.Rule
20+
import org.junit.Test
21+
import org.junit.runner.RunWith
22+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
23+
import kotlin.random.Random
24+
import kotlin.test.assertContains
25+
26+
/** Unit tests for [TrainingModule]. */
@RunWith(AndroidJUnit4::class)
class TrainingModuleE2ETest {
  @get:Rule
  var runtimePermissionRule: GrantPermissionRule =
      GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

  /**
   * Copies the bundled test resource at [resourcePath] into the on-device test directory.
   *
   * @return the absolute on-device path of the staged file.
   * @throws IOException if the resource is missing from the APK (instead of an opaque NPE).
   */
  @Throws(IOException::class)
  private fun stageResource(resourcePath: String): String {
    val destination = File(getTestFilePath(resourcePath))
    val stream =
        javaClass.getResourceAsStream(resourcePath)
            ?: throw IOException("Missing test resource: $resourcePath")
    // copyInputStreamToFile closes the stream; use{} guards against early exceptions.
    stream.use { FileUtils.copyInputStreamToFile(it, destination) }
    return getTestFilePath(resourcePath)
  }

  /** XOR truth table stored flat as alternating (input, label) tensor pairs. */
  private fun xorDataset(): List<Tensor> =
      listOf(
          Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
          Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
          Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
          Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
          Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
          Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
          Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
          Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
      )

  /** Asserts [entries] holds exactly the four XOR model parameter/gradient names. */
  private fun assertXorEntries(entries: Map<String, Tensor>) {
    Assert.assertEquals(4, entries.size)
    assertContains(entries, LIN_WEIGHT)
    assertContains(entries, LIN_BIAS)
    assertContains(entries, LIN2_WEIGHT)
    assertContains(entries, LIN2_BIAS)
  }

  /**
   * Shared XOR convergence loop (previously duplicated verbatim in both train tests).
   *
   * Verifies parameter/gradient names, runs [NUM_EPOCHS] SGD steps on randomly drawn
   * (input, label) pairs, logs progress every 500 steps under [logTag], and asserts the
   * final loss converged below [CONVERGENCE_THRESHOLD].
   */
  private fun trainXor(module: TrainingModule, logTag: String) {
    val params = module.namedParameters("forward")
    assertXorEntries(params)

    val sgd = SGD.create(params, LEARNING_RATE)
    val dataset = xorDataset()
    var finalLoss = Float.MAX_VALUE

    for (i in 0 until NUM_EPOCHS) {
      // Even index = input tensor, odd index = its label.
      val inputDex = 2 * Random.nextInt(dataset.size / 2)
      val targetDex = inputDex + 1
      val input = dataset[inputDex]
      val target = dataset[targetDex]
      val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
      val gradients = module.namedGradients("forward")

      // Gradient map is only checked once; names should match the parameter map.
      if (i == 0) {
        assertXorEntries(gradients)
      }

      if (i % 500 == 0 || i == NUM_EPOCHS - 1) {
        Log.i(
            logTag,
            String.format(
                "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
                i,
                out[0].toTensor().dataAsFloatArray[0],
                input.dataAsFloatArray[0],
                input.dataAsFloatArray[1],
                out[1].toTensor().dataAsLongArray[0],
                target.dataAsLongArray[0]))
      }

      sgd.step(gradients)

      if (i == NUM_EPOCHS - 1) {
        finalLoss = out[0].toTensor().dataAsFloatArray[0]
      }
    }
    Assert.assertTrue(finalLoss < CONVERGENCE_THRESHOLD)
  }

  /** pte + ptd workflow: program and external weights loaded separately. */
  @Test
  @Throws(IOException::class, URISyntaxException::class)
  fun testTrainXOR() {
    val ptePath = stageResource("/xor.pte")
    val ptdPath = stageResource("/xor.ptd")
    val module = TrainingModule.load(ptePath, ptdPath)
    trainXor(module, "testTrainXOR")
  }

  /** PTE-only workflow: weights embedded in a single program file. */
  @Test
  @Throws(IOException::class, URISyntaxException::class)
  fun testTrainXOR_PTEOnly() {
    val module = TrainingModule.load(stageResource("/xor_full.pte"))
    trainXor(module, "testTrainXOR_PTEOnly")
  }

  @Test
  @Throws(IOException::class)
  fun testMissingPteFile() {
    val exception =
        Assert.assertThrows(RuntimeException::class.java) {
          TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
        }
    // JUnit's assertEquals takes (expected, actual) — was reversed before, which
    // flips the labels in the failure message.
    Assert.assertEquals(
        "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME), exception.message)
  }

  @Test
  @Throws(IOException::class)
  fun testMissingPtdFile() {
    val exception =
        Assert.assertThrows(RuntimeException::class.java) {
          val ptePath = stageResource("/xor.pte")
          TrainingModule.load(ptePath, getTestFilePath(MISSING_PTD_NAME))
        }
    Assert.assertEquals(
        "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME), exception.message)
  }

  companion object {
    private const val LIN_WEIGHT = "net.linear.weight"
    private const val LIN_BIAS = "net.linear.bias"
    private const val LIN2_WEIGHT = "net.linear2.weight"
    private const val LIN2_BIAS = "net.linear2.bias"
    private const val MISSING_PTE_NAME = "/missing.pte"
    private const val MISSING_PTD_NAME = "/missing.ptd"
    // Training hyperparameters / convergence bound for the toy XOR model.
    private const val NUM_EPOCHS = 5000
    private const val LEARNING_RATE = 0.5
    private const val CONVERGENCE_THRESHOLD = 0.1f
  }
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import com.facebook.jni.HybridData;
12+
import com.facebook.jni.annotations.DoNotStrip;
13+
import com.facebook.soloader.nativeloader.NativeLoader;
14+
import com.facebook.soloader.nativeloader.SystemDelegate;
15+
import java.util.Map;
16+
import org.pytorch.executorch.annotations.Experimental;
17+
18+
/**
19+
* Java wrapper for ExecuTorch SGD Optimizer.
20+
*
21+
* <p>Warning: These APIs are experimental and subject to change without notice
22+
*/
23+
@Experimental
24+
public class SGD {
25+
26+
static {
27+
if (!NativeLoader.isInitialized()) {
28+
NativeLoader.init(new SystemDelegate());
29+
}
30+
// Loads libexecutorch.so from jniLibs
31+
NativeLoader.loadLibrary("executorch");
32+
}
33+
34+
private final HybridData mHybridData;
35+
36+
@DoNotStrip
37+
private static native HybridData initHybrid(
38+
Map<String, Tensor> namedParameters,
39+
double learningRate,
40+
double momentum,
41+
double dampening,
42+
double weightDecay,
43+
boolean nesterov);
44+
45+
private SGD(
46+
Map<String, Tensor> namedParameters,
47+
double learningRate,
48+
double momentum,
49+
double dampening,
50+
double weightDecay,
51+
boolean nesterov) {
52+
mHybridData =
53+
initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
54+
}
55+
56+
/**
57+
* Creates a new SGD optimizer with the specified parameters and options.
58+
*
59+
* @param namedParameters Map of parameter names to tensors to be optimized
60+
* @param learningRate The learning rate for the optimizer
61+
* @param momentum The momentum value
62+
* @param dampening The dampening value
63+
* @param weightDecay The weight decay value
64+
* @param nesterov Whether to use Nesterov momentum
65+
* @return new {@link org.pytorch.executorch.SGD} object
66+
*/
67+
public static SGD create(
68+
Map<String, Tensor> namedParameters,
69+
double learningRate,
70+
double momentum,
71+
double dampening,
72+
double weightDecay,
73+
boolean nesterov) {
74+
return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
75+
}
76+
77+
/**
78+
* Creates a new SGD optimizer with default options.
79+
*
80+
* @param namedParameters Map of parameter names to tensors to be optimized
81+
* @param learningRate The learning rate for the optimizer
82+
* @return new {@link org.pytorch.executorch.SGD} object
83+
*/
84+
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
85+
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
86+
}
87+
88+
/**
89+
* Performs a single optimization step using the provided gradients.
90+
*
91+
* @param namedGradients Map of parameter names to gradient tensors
92+
*/
93+
public void step(Map<String, Tensor> namedGradients) {
94+
if (!mHybridData.isValid()) {
95+
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
96+
}
97+
stepNative(namedGradients);
98+
}
99+
100+
@DoNotStrip
101+
private native void stepNative(Map<String, Tensor> namedGradients);
102+
}

0 commit comments

Comments
 (0)