Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import android.app.Activity;
import android.content.Intent;
import android.os.AsyncTask;
import android.os.Bundle;
import android.system.ErrnoException;
import android.system.Os;
Expand Down Expand Up @@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) {
// TODO: Format the string with a parsable format
Stats stats = new Stats();

// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();
new AsyncTask<Void, Void, Void>() {
@Override
protected Void doInBackground(Void... voids) {

for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
return null;
}

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
@Override
protected void onPostExecute(Void aVoid) {

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
(stats.loadEnd - stats.loadStart) * 1e-6,
0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}.execute();
}
}

Expand Down
Loading