Skip to content

Commit cd2afc3

Browse files
authored
Android E2E with real input (#10230)
Use MV2, MV3, quantized ResNet50 for validation.
1 parent 6301b05 commit cd2afc3

File tree

4 files changed

+201
-5
lines changed

4 files changed

+201
-5
lines changed

extension/android/executorch_android/android_test_setup.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ prepare_tinyllama() {
3636
prepare_vision() {
3737
pushd "${BASEDIR}/../../../"
3838
python3 -m examples.xnnpack.aot_compiler --model_name "mv2" --delegate
39-
cp mv2*.pte "${BASEDIR}/src/androidTest/resources/"
39+
python3 -m examples.xnnpack.aot_compiler --model_name "mv3" --delegate
40+
python3 -m examples.xnnpack.aot_compiler --model_name "resnet50" --quantize --delegate
41+
cp mv2*.pte mv3*.pte resnet50*.pte "${BASEDIR}/src/androidTest/resources/"
4042
popd
4143
}
4244

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static org.junit.Assert.assertNotEquals;
1515
import static org.junit.Assert.fail;
1616

17+
import android.graphics.Bitmap;
18+
import android.graphics.BitmapFactory;
1719
import android.os.Environment;
1820
import androidx.test.rule.GrantPermissionRule;
1921
import android.Manifest;
@@ -45,18 +47,60 @@ private static String getTestFilePath(String fileName) {
4547
@Rule
4648
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
4749

48-
@Test
49-
public void testMv2Fp32() throws IOException, URISyntaxException{
50-
String filePath = "/mv2_xnnpack_fp32.pte";
50+
static int argmax(float[] array) {
51+
if (array.length == 0) {
52+
throw new IllegalArgumentException("Array cannot be empty");
53+
}
54+
int maxIndex = 0;
55+
float maxValue = array[0];
56+
for (int i = 1; i < array.length; i++) {
57+
if (array[i] > maxValue) {
58+
maxValue = array[i];
59+
maxIndex = i;
60+
}
61+
}
62+
return maxIndex;
63+
}
64+
65+
public void testClassification(String filePath) throws IOException, URISyntaxException {
5166
File pteFile = new File(getTestFilePath(filePath));
5267
InputStream inputStream = getClass().getResourceAsStream(filePath);
5368
FileUtils.copyInputStreamToFile(inputStream, pteFile);
5469
inputStream.close();
5570

71+
InputStream imgInputStream = getClass().getResourceAsStream("/banana.jpeg");
72+
Bitmap bitmap = BitmapFactory.decodeStream(imgInputStream);
73+
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
74+
imgInputStream.close();
75+
76+
Tensor inputTensor =
77+
TensorImageUtils.bitmapToFloat32Tensor(
78+
bitmap,
79+
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
80+
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
81+
5682
Module module = Module.load(getTestFilePath(filePath));
5783

58-
EValue[] results = module.forward();
84+
EValue[] results = module.forward(EValue.from(inputTensor));
5985
assertTrue(results[0].isTensor());
86+
float[] scores = results[0].toTensor().getDataAsFloatArray();
87+
88+
int bananaClass = 954; // From ImageNet 1K
89+
assertEquals(bananaClass, argmax(scores));
90+
}
91+
92+
@Test
93+
public void testMv2Fp32() throws IOException, URISyntaxException {
94+
testClassification("/mv2_xnnpack_fp32.pte");
6095
}
6196

97+
@Test
98+
public void testMv3Fp32() throws IOException, URISyntaxException {
99+
testClassification("/mv3_xnnpack_fp32.pte");
100+
}
101+
102+
@Test
103+
public void testResnet50() throws IOException, URISyntaxException {
104+
testClassification("/resnet50_xnnpack_q8.pte");
105+
}
62106
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import android.graphics.Bitmap;
12+
import android.util.Log;
13+
import java.nio.FloatBuffer;
14+
import org.pytorch.executorch.Tensor;
15+
16+
/**
17+
* Contains utility functions for {@link Tensor} creation from {@link android.graphics.Bitmap} or
18+
* {@link android.media.Image} source.
19+
*/
20+
public final class TensorImageUtils {
21+
22+
public static float[] TORCHVISION_NORM_MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f};
23+
public static float[] TORCHVISION_NORM_STD_RGB = new float[] {0.229f, 0.224f, 0.225f};
24+
25+
/**
26+
* Creates new {@link Tensor} from full {@link android.graphics.Bitmap}, normalized with specified
27+
* in parameters mean and std.
28+
*
29+
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
30+
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
31+
* order
32+
*/
33+
public static Tensor bitmapToFloat32Tensor(
34+
final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) {
35+
checkNormMeanArg(normMeanRGB);
36+
checkNormStdArg(normStdRGB);
37+
38+
return bitmapToFloat32Tensor(
39+
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB);
40+
}
41+
42+
/**
43+
* Writes tensor content from specified {@link android.graphics.Bitmap}, normalized with specified
44+
* in parameters mean and std to specified {@link java.nio.FloatBuffer} with specified offset.
45+
*
46+
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
47+
* @param x - x coordinate of top left corner of bitmap's area
48+
* @param y - y coordinate of top left corner of bitmap's area
49+
* @param width - width of bitmap's area
50+
* @param height - height of bitmap's area
51+
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
52+
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
53+
* order
54+
*/
55+
public static void bitmapToFloatBuffer(
56+
final Bitmap bitmap,
57+
final int x,
58+
final int y,
59+
final int width,
60+
final int height,
61+
final float[] normMeanRGB,
62+
final float[] normStdRGB,
63+
final FloatBuffer outBuffer,
64+
final int outBufferOffset) {
65+
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
66+
checkNormMeanArg(normMeanRGB);
67+
checkNormStdArg(normStdRGB);
68+
final int pixelsCount = height * width;
69+
final int[] pixels = new int[pixelsCount];
70+
bitmap.getPixels(pixels, 0, width, x, y, width, height);
71+
final int offset_g = pixelsCount;
72+
final int offset_b = 2 * pixelsCount;
73+
for (int i = 0; i < 100; i++) {
74+
final int c = pixels[i];
75+
Log.i("Image", ": " + i + " " + ((c >> 16) & 0xff));
76+
}
77+
for (int i = 0; i < pixelsCount; i++) {
78+
final int c = pixels[i];
79+
float r = ((c >> 16) & 0xff) / 255.0f;
80+
float g = ((c >> 8) & 0xff) / 255.0f;
81+
float b = ((c) & 0xff) / 255.0f;
82+
outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]);
83+
outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]);
84+
outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]);
85+
}
86+
}
87+
88+
/**
89+
* Creates new {@link Tensor} from specified area of {@link android.graphics.Bitmap}, normalized
90+
* with specified in parameters mean and std.
91+
*
92+
* @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data
93+
* @param x - x coordinate of top left corner of bitmap's area
94+
* @param y - y coordinate of top left corner of bitmap's area
95+
* @param width - width of bitmap's area
96+
* @param height - height of bitmap's area
97+
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
98+
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
99+
* order
100+
*/
101+
public static Tensor bitmapToFloat32Tensor(
102+
final Bitmap bitmap,
103+
int x,
104+
int y,
105+
int width,
106+
int height,
107+
float[] normMeanRGB,
108+
float[] normStdRGB) {
109+
checkNormMeanArg(normMeanRGB);
110+
checkNormStdArg(normStdRGB);
111+
112+
final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height);
113+
bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0);
114+
return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width});
115+
}
116+
117+
private static void checkOutBufferCapacity(
118+
FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
119+
if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) {
120+
throw new IllegalStateException("Buffer underflow");
121+
}
122+
}
123+
124+
private static void checkTensorSize(int tensorWidth, int tensorHeight) {
125+
if (tensorHeight <= 0 || tensorWidth <= 0) {
126+
throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
127+
}
128+
}
129+
130+
private static void checkRotateCWDegrees(int rotateCWDegrees) {
131+
if (rotateCWDegrees != 0
132+
&& rotateCWDegrees != 90
133+
&& rotateCWDegrees != 180
134+
&& rotateCWDegrees != 270) {
135+
throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
136+
}
137+
}
138+
139+
private static void checkNormStdArg(float[] normStdRGB) {
140+
if (normStdRGB.length != 3) {
141+
throw new IllegalArgumentException("normStdRGB length must be 3");
142+
}
143+
}
144+
145+
private static void checkNormMeanArg(float[] normMeanRGB) {
146+
if (normMeanRGB.length != 3) {
147+
throw new IllegalArgumentException("normMeanRGB length must be 3");
148+
}
149+
}
150+
}
Loading

0 commit comments

Comments
 (0)