Skip to content

Commit a6e15a7

Browse files
committed
Add entrypoints for training with PTE-only
As title.
1 parent 7f4a62d commit a6e15a7

File tree

4 files changed

+96
-3
lines changed

4 files changed

+96
-3
lines changed

extension/android/executorch_android/android_test_setup.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ prepare_add() {
1919
}
2020

2121
prepare_xor() {
22+
python3 -m extension.training.examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/"
23+
mv "${BASEDIR}/src/androidTest/resources/xor.pte" "${BASEDIR}/src/androidTest/resources/xor_only.pte"
2224
python3 -m extension.training.examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external
2325
}
2426

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,78 @@ class TrainingModuleE2ETest {
108108
Assert.assertTrue(finalLoss < 0.1f)
109109
}
110110

111+
@Test
112+
@Throws(IOException::class, URISyntaxException::class)
113+
fun testTrainXOR_PTEOnly() {
114+
val pteFilePath = "/xor_only.pte"
115+
116+
val pteFile = File(getTestFilePath(pteFilePath))
117+
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118+
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
119+
pteInputStream.close()
120+
121+
val module = TrainingModule.load(getTestFilePath(pteFilePath));
122+
val params = module.namedParameters("forward")
123+
124+
Assert.assertEquals(4, params.size)
125+
assertContains(params, LIN_WEIGHT)
126+
assertContains(params, LIN_BIAS)
127+
assertContains(params, LIN2_WEIGHT)
128+
assertContains(params, LIN2_BIAS)
129+
130+
val sgd = SGD.create(params, 0.5);
131+
val dataset = listOf<Tensor>(
132+
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
133+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
134+
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
135+
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
136+
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
137+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
138+
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
139+
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
140+
)
141+
142+
val numEpochs = 5000;
143+
var finalLoss = Float.MAX_VALUE
144+
145+
for (i in 0 until numEpochs) {
146+
val inputDex = 2 * Random.nextInt(dataset.size / 2)
147+
val targetDex = inputDex + 1
148+
val input = dataset.get(inputDex)
149+
val target = dataset.get(targetDex)
150+
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
151+
val gradients = module.namedGradients("forward")
152+
153+
if (i == 0) {
154+
Assert.assertEquals(4, gradients.size)
155+
assertContains(gradients, LIN_WEIGHT)
156+
assertContains(gradients, LIN_BIAS)
157+
assertContains(gradients, LIN2_WEIGHT)
158+
assertContains(gradients, LIN2_BIAS)
159+
}
160+
161+
if (i % 500 == 0 || i == numEpochs - 1) {
162+
Log.i(
163+
"testTrainXOR_PTEOnly",
164+
String.format(
165+
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
166+
i,
167+
out[0].toTensor().getDataAsFloatArray()[0],
168+
input.getDataAsFloatArray()[0],
169+
input.getDataAsFloatArray()[1],
170+
out[1].toTensor().getDataAsLongArray()[0],
171+
target.getDataAsLongArray()[0]));
172+
}
173+
174+
sgd.step(gradients)
175+
176+
if (i == numEpochs - 1) {
177+
finalLoss = out[0].toTensor().dataAsFloatArray[0]
178+
}
179+
}
180+
Assert.assertTrue(finalLoss < 0.1f)
181+
}
182+
111183
companion object {
112184
private const val LIN_WEIGHT = "net.linear.weight"
113185
private const val LIN_BIAS = "net.linear.bias"

extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
4444
}
4545

4646
/**
47-
* Loads a serialized ExecuTorch module from the specified path on the disk.
47+
* Loads a serialized ExecuTorch training module from the specified path on the disk.
4848
*
4949
* @param modelPath path to file that contains the serialized ExecuTorch module.
50+
* @param dataPath path to file that contains the ExecuTorch module external weights.
5051
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
5152
*/
5253
public static TrainingModule load(final String modelPath, final String dataPath) {
@@ -61,6 +62,21 @@ public static TrainingModule load(final String modelPath, final String dataPath)
6162
return new TrainingModule(modelPath, dataPath);
6263
}
6364

65+
/**
66+
* Loads a serialized ExecuTorch training module from the specified path on the disk.
67+
*
68+
* @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not
69+
* rely on external weights.
70+
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
71+
*/
72+
public static TrainingModule load(final String modelPath) {
73+
File modelFile = new File(modelPath);
74+
if (!modelFile.canRead() || !modelFile.isFile()) {
75+
throw new RuntimeException("Cannot load model path!! " + modelPath);
76+
}
77+
return new TrainingModule(modelPath, "");
78+
}
79+
6480
/**
6581
* Runs the specified method of this module with the specified arguments.
6682
*

extension/android/jni/jni_layer_training.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,16 @@ class ExecuTorchTrainingJni
7373
facebook::jni::alias_ref<jstring> modelPath,
7474
facebook::jni::alias_ref<jstring> dataPath) {
7575
auto modelLoader = FileDataLoader::from(modelPath->toStdString().c_str());
76-
auto dataLoader = FileDataLoader::from(dataPath->toStdString().c_str());
76+
auto stdStringDataPath = dataPath->toStdString();
7777
module_ = std::make_unique<training::TrainingModule>(
7878
std::make_unique<FileDataLoader>(std::move(modelLoader.get())),
7979
nullptr,
8080
nullptr,
8181
nullptr,
82-
std::make_unique<FileDataLoader>(std::move(dataLoader.get())));
82+
stdStringDataPath.empty()
83+
? nullptr
84+
: std::make_unique<FileDataLoader>(std::move(
85+
FileDataLoader::from(stdStringDataPath.c_str()).get())));
8386
}
8487

8588
static facebook::jni::local_ref<jhybriddata> initHybrid(

0 commit comments

Comments
 (0)