Skip to content

Commit aba87ed

Browse files
committed
Add TrainingModule and SGD JNI
As titled, this adds TrainingModule and SGD JNI wrappers, together with a unit test based on the XOR train.cpp example.
1 parent ba19c75 commit aba87ed

File tree

12 files changed

+850
-48
lines changed

12 files changed

+850
-48
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pip-out/
2626
*.model
2727
tokenizer.json
2828
*.pte
29+
*.ptd
2930
!test_bpe_tokenizer.bin
3031
!test_tiktoken_tokenizer.model
3132

extension/android/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ non_fbcode_target(_kind = fb_android_library,
1212
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1313
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
15+
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
16+
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1517
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
1618
],
1719
autoglob = False,

extension/android/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
6464
find_package(executorch CONFIG REQUIRED)
6565
target_link_options_shared_lib(executorch)
6666

67-
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)
67+
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_layer_training.cpp)
6868

6969
set(link_libraries)
7070
list(
@@ -77,6 +77,7 @@ list(
7777
extension_runner_util
7878
extension_tensor
7979
extension_threadpool
80+
extension_training
8081
fbjni
8182
)
8283

extension/android/executorch_android/android_test_setup.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@ which "${PYTHON_EXECUTABLE}"
1515
BASEDIR=$(dirname "$(realpath $0)")
1616

1717
prepare_add() {
18+
pushd "${BASEDIR}/../../../"
1819
python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/"
20+
popd
21+
}
22+
23+
prepare_xor() {
  # Export the XOR training example into the androidTest resources directory.
  # Must run from extension/training/ so the example's relative module
  # imports (examples.XOR.export_model) resolve.
  pushd "${BASEDIR}/../../../extension/training/"
  # First export: monolithic program with weights embedded in the .pte.
  # Use the interpreter validated at the top of this script rather than a
  # hardcoded python3.
  "${PYTHON_EXECUTABLE}" -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/"
  mv "${BASEDIR}/src/androidTest/resources/xor.pte" "${BASEDIR}/src/androidTest/resources/xor_full.pte"
  # Second export: split program — weights externalized to xor.ptd next to xor.pte.
  "${PYTHON_EXECUTABLE}" -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external
  popd
}
2030

2131
prepare_tinyllama() {
@@ -43,5 +53,6 @@ prepare_vision() {
4353
}
4454

4555
prepare_add
56+
prepare_xor
4657
prepare_tinyllama
4758
prepare_vision

extension/android/executorch_android/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ dependencies {
5050
implementation libs.core.ktx
5151
testImplementation 'junit:junit:4.12'
5252
testImplementation 'org.assertj:assertj-core:3.27.2'
53+
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5354
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
5455
androidTestImplementation 'androidx.test:rules:1.2.0'
5556
androidTestImplementation 'commons-io:commons-io:2.4'
5657
androidTestImplementation 'org.json:json:20250107'
58+
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
5759
}
5860

5961
import com.vanniktech.maven.publish.SonatypeHost
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.Manifest
11+
import android.util.Log
12+
import androidx.test.ext.junit.runners.AndroidJUnit4
13+
import androidx.test.rule.GrantPermissionRule
14+
import java.io.File
15+
import java.io.IOException
16+
import java.net.URISyntaxException
17+
import org.apache.commons.io.FileUtils
18+
import org.junit.Assert
19+
import org.junit.Rule
20+
import org.junit.Test
21+
import org.junit.runner.RunWith
22+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
23+
import kotlin.random.Random
24+
import kotlin.test.assertContains
25+
26+
/** Unit tests for [TrainingModule]. */
27+
@RunWith(AndroidJUnit4::class)
28+
class TrainingModuleE2ETest {
29+
@get:Rule
30+
var runtimePermissionRule: GrantPermissionRule =
31+
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
32+
33+
@Test
34+
@Throws(IOException::class, URISyntaxException::class)
35+
fun testTrainXOR() {
36+
val pteFilePath = "/xor.pte"
37+
val ptdFilePath = "/xor.ptd"
38+
39+
val pteFile = File(getTestFilePath(pteFilePath))
40+
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
41+
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
42+
pteInputStream.close()
43+
44+
val ptdFile = File(getTestFilePath(ptdFilePath))
45+
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
46+
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
47+
ptdInputStream.close()
48+
49+
val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
50+
val params = module.namedParameters("forward")
51+
52+
Assert.assertEquals(4, params.size)
53+
assertContains(params, LIN_WEIGHT)
54+
assertContains(params, LIN_BIAS)
55+
assertContains(params, LIN2_WEIGHT)
56+
assertContains(params, LIN2_BIAS)
57+
58+
val sgd = SGD.create(params, 0.5);
59+
val dataset = listOf<Tensor>(
60+
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
61+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
62+
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
63+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
64+
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
65+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
66+
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
67+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
68+
)
69+
70+
val numEpochs = 5000;
71+
var finalLoss = Float.MAX_VALUE
72+
73+
for (i in 0 until numEpochs) {
74+
val inputDex = 2 * Random.nextInt(dataset.size / 2)
75+
val targetDex = inputDex + 1
76+
val input = dataset.get(inputDex)
77+
val target = dataset.get(targetDex)
78+
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
79+
val gradients = module.namedGradients("forward")
80+
81+
if (i == 0) {
82+
Assert.assertEquals(4, gradients.size)
83+
assertContains(gradients, LIN_WEIGHT)
84+
assertContains(gradients, LIN_BIAS)
85+
assertContains(gradients, LIN2_WEIGHT)
86+
assertContains(gradients, LIN2_BIAS)
87+
}
88+
89+
if (i % 500 == 0 || i == numEpochs - 1) {
90+
Log.i(
91+
"testTrainXOR",
92+
String.format(
93+
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
94+
i,
95+
out[0].toTensor().getDataAsFloatArray()[0],
96+
input.getDataAsFloatArray()[0],
97+
input.getDataAsFloatArray()[1],
98+
out[1].toTensor().getDataAsLongArray()[0],
99+
target.getDataAsLongArray()[0]));
100+
}
101+
102+
sgd.step(gradients)
103+
104+
if (i == numEpochs - 1) {
105+
finalLoss = out[0].toTensor().dataAsFloatArray[0]
106+
}
107+
}
108+
Assert.assertTrue(finalLoss < 0.1f)
109+
}
110+
111+
@Test
112+
@Throws(IOException::class, URISyntaxException::class)
113+
fun testTrainXOR_PTEOnly() {
114+
val pteFilePath = "/xor_full.pte"
115+
116+
val pteFile = File(getTestFilePath(pteFilePath))
117+
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118+
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
119+
pteInputStream.close()
120+
121+
val module = TrainingModule.load(getTestFilePath(pteFilePath));
122+
val params = module.namedParameters("forward")
123+
124+
Assert.assertEquals(4, params.size)
125+
assertContains(params, LIN_WEIGHT)
126+
assertContains(params, LIN_BIAS)
127+
assertContains(params, LIN2_WEIGHT)
128+
assertContains(params, LIN2_BIAS)
129+
130+
val sgd = SGD.create(params, 0.5);
131+
val dataset = listOf<Tensor>(
132+
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
133+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
134+
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
135+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
136+
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
137+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
138+
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
139+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
140+
)
141+
142+
val numEpochs = 5000;
143+
var finalLoss = Float.MAX_VALUE
144+
145+
for (i in 0 until numEpochs) {
146+
val inputDex = 2 * Random.nextInt(dataset.size / 2)
147+
val targetDex = inputDex + 1
148+
val input = dataset.get(inputDex)
149+
val target = dataset.get(targetDex)
150+
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
151+
val gradients = module.namedGradients("forward")
152+
153+
if (i == 0) {
154+
Assert.assertEquals(4, gradients.size)
155+
assertContains(gradients, LIN_WEIGHT)
156+
assertContains(gradients, LIN_BIAS)
157+
assertContains(gradients, LIN2_WEIGHT)
158+
assertContains(gradients, LIN2_BIAS)
159+
}
160+
161+
if (i % 500 == 0 || i == numEpochs - 1) {
162+
Log.i(
163+
"testTrainXOR_PTEOnly",
164+
String.format(
165+
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
166+
i,
167+
out[0].toTensor().getDataAsFloatArray()[0],
168+
input.getDataAsFloatArray()[0],
169+
input.getDataAsFloatArray()[1],
170+
out[1].toTensor().getDataAsLongArray()[0],
171+
target.getDataAsLongArray()[0]));
172+
}
173+
174+
sgd.step(gradients)
175+
176+
if (i == numEpochs - 1) {
177+
finalLoss = out[0].toTensor().dataAsFloatArray[0]
178+
}
179+
}
180+
Assert.assertTrue(finalLoss < 0.1f)
181+
}
182+
183+
companion object {
184+
private const val LIN_WEIGHT = "net.linear.weight"
185+
private const val LIN_BIAS = "net.linear.bias"
186+
private const val LIN2_WEIGHT = "net.linear2.weight"
187+
private const val LIN2_BIAS = "net.linear2.bias"
188+
}
189+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import com.facebook.jni.HybridData;
12+
import com.facebook.jni.annotations.DoNotStrip;
13+
import com.facebook.soloader.nativeloader.NativeLoader;
14+
import com.facebook.soloader.nativeloader.SystemDelegate;
15+
import java.util.Map;
16+
import org.pytorch.executorch.annotations.Experimental;
17+
18+
/**
19+
* Java wrapper for ExecuTorch SGD Optimizer.
20+
*
21+
* <p>Warning: These APIs are experimental and subject to change without notice
22+
*/
23+
@Experimental
24+
public class SGD {
25+
26+
static {
27+
if (!NativeLoader.isInitialized()) {
28+
NativeLoader.init(new SystemDelegate());
29+
}
30+
// Loads libexecutorch.so from jniLibs
31+
NativeLoader.loadLibrary("executorch");
32+
}
33+
34+
private final HybridData mHybridData;
35+
36+
@DoNotStrip
37+
private static native HybridData initHybrid(
38+
Map<String, Tensor> namedParameters,
39+
double learningRate,
40+
double momentum,
41+
double dampening,
42+
double weightDecay,
43+
boolean nesterov);
44+
45+
private SGD(
46+
Map<String, Tensor> namedParameters,
47+
double learningRate,
48+
double momentum,
49+
double dampening,
50+
double weightDecay,
51+
boolean nesterov) {
52+
mHybridData =
53+
initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
54+
}
55+
56+
/**
57+
* Creates a new SGD optimizer with the specified parameters and options.
58+
*
59+
* @param namedParameters Map of parameter names to tensors to be optimized
60+
* @param learningRate The learning rate for the optimizer
61+
* @param momentum The momentum value
62+
* @param dampening The dampening value
63+
* @param weightDecay The weight decay value
64+
* @param nesterov Whether to use Nesterov momentum
65+
* @return new {@link org.pytorch.executorch.SGD} object
66+
*/
67+
public static SGD create(
68+
Map<String, Tensor> namedParameters,
69+
double learningRate,
70+
double momentum,
71+
double dampening,
72+
double weightDecay,
73+
boolean nesterov) {
74+
return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
75+
}
76+
77+
/**
78+
* Creates a new SGD optimizer with default options.
79+
*
80+
* @param namedParameters Map of parameter names to tensors to be optimized
81+
* @param learningRate The learning rate for the optimizer
82+
* @return new {@link org.pytorch.executorch.SGD} object
83+
*/
84+
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
85+
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
86+
}
87+
88+
/**
89+
* Performs a single optimization step using the provided gradients.
90+
*
91+
* @param namedGradients Map of parameter names to gradient tensors
92+
*/
93+
public void step(Map<String, Tensor> namedGradients) {
94+
if (!mHybridData.isValid()) {
95+
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
96+
}
97+
stepNative(namedGradients);
98+
}
99+
100+
@DoNotStrip
101+
private native void stepNative(Map<String, Tensor> namedGradients);
102+
}

0 commit comments

Comments
 (0)