File tree Expand file tree Collapse file tree 13 files changed +97
-27
lines changed
examples/demo-apps/android/LlamaDemo/app
androidTest/java/com/example/executorchllamademo
main/java/com/example/executorchllamademo
androidTest/java/org/pytorch/executorch
main/java/org/pytorch/executorch/extension/llm
benchmark/android/benchmark/app
src/main/java/org/pytorch/minibench Expand file tree Collapse file tree 13 files changed +97
-27
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,7 @@ dependencies {
60
60
implementation(files(" libs/executorch.aar" ))
61
61
implementation(" com.google.android.material:material:1.12.0" )
62
62
implementation(" androidx.activity:activity:1.9.0" )
63
+ implementation(" org.json:json:20250107" )
63
64
testImplementation(" junit:junit:4.13.2" )
64
65
androidTestImplementation(" androidx.test.ext:junit:1.1.5" )
65
66
androidTestImplementation(" androidx.test.espresso:espresso-core:3.5.1" )
Original file line number Diff line number Diff line change 18
18
import java .util .ArrayList ;
19
19
import java .util .Arrays ;
20
20
import java .util .List ;
21
+ import org .json .JSONException ;
22
+ import org .json .JSONObject ;
21
23
import org .junit .Test ;
22
24
import org .junit .runner .RunWith ;
23
25
import org .pytorch .executorch .extension .llm .LlmCallback ;
@@ -64,8 +66,16 @@ public void onResult(String result) {
64
66
}
65
67
66
68
@ Override
67
- public void onStats (float tps ) {
68
- tokensPerSecond .add (tps );
69
+ public void onStats (String result ) {
70
+ try {
71
+ JSONObject jsonObject = new JSONObject (result );
72
+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
73
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
74
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
75
+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
76
+ tokensPerSecond .add (tps );
77
+ } catch (JSONException e ) {
78
+ }
69
79
}
70
80
71
81
private void report (final String metric , final Float value ) {
Original file line number Diff line number Diff line change 49
49
import java .util .List ;
50
50
import java .util .concurrent .Executor ;
51
51
import java .util .concurrent .Executors ;
52
+ import org .json .JSONException ;
53
+ import org .json .JSONObject ;
52
54
import org .pytorch .executorch .extension .llm .LlmCallback ;
53
55
import org .pytorch .executorch .extension .llm .LlmModule ;
54
56
@@ -97,10 +99,20 @@ public void onResult(String result) {
97
99
}
98
100
99
101
@ Override
100
- public void onStats (float tps ) {
102
+ public void onStats (String stats ) {
101
103
runOnUiThread (
102
104
() -> {
103
105
if (mResultMessage != null ) {
106
+ float tps = 0 ;
107
+ try {
108
+ JSONObject jsonObject = new JSONObject (stats );
109
+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
110
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
111
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
112
+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
113
+ } catch (JSONException e ) {
114
+ Log .e ("LLM" , "Error parsing JSON: " + e .getMessage ());
115
+ }
104
116
mResultMessage .setTokensPerSecond (tps );
105
117
mMessageAdapter .notifyDataSetChanged ();
106
118
}
Original file line number Diff line number Diff line change 13
13
import android .os .Looper ;
14
14
import android .os .Message ;
15
15
import androidx .annotation .NonNull ;
16
+ import org .json .JSONException ;
17
+ import org .json .JSONObject ;
16
18
import org .pytorch .executorch .extension .llm .LlmCallback ;
17
19
import org .pytorch .executorch .extension .llm .LlmModule ;
18
20
@@ -69,7 +71,16 @@ public void onResult(String result) {
69
71
}
70
72
71
73
@ Override
72
- public void onStats (float tps ) {
74
+ public void onStats (String stats ) {
75
+ float tps = 0 ;
76
+ try {
77
+ JSONObject jsonObject = new JSONObject (stats );
78
+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
79
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
80
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
81
+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
82
+ } catch (JSONException e ) {
83
+ }
73
84
mCallback .onStats ("tokens/second: " + tps );
74
85
}
75
86
}
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
18
18
19
19
void onTokenGenerated (String token );
20
20
21
- void onStats (String token );
21
+ void onStats (String stats );
22
22
23
23
void onGenerationStopped ();
24
24
}
Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ dependencies {
47
47
androidTestImplementation ' androidx.test.ext:junit:1.1.5'
48
48
androidTestImplementation ' androidx.test:rules:1.2.0'
49
49
androidTestImplementation ' commons-io:commons-io:2.4'
50
+ androidTestImplementation ' org.json:json:20250107'
50
51
}
51
52
52
53
import com.vanniktech.maven.publish.SonatypeHost
Original file line number Diff line number Diff line change 34
34
import org .apache .commons .io .FileUtils ;
35
35
import androidx .test .ext .junit .runners .AndroidJUnit4 ;
36
36
import androidx .test .InstrumentationRegistry ;
37
+ import org .json .JSONException ;
38
+ import org .json .JSONObject ;
37
39
import org .pytorch .executorch .extension .llm .LlmCallback ;
38
40
import org .pytorch .executorch .extension .llm .LlmModule ;
39
41
@@ -94,8 +96,17 @@ public void onResult(String result) {
94
96
}
95
97
96
98
@ Override
97
- public void onStats (float tps ) {
98
- LlmModuleInstrumentationTest .this .onStats (tps );
99
+ public void onStats (String stats ) {
100
+ float tps = 0 ;
101
+ try {
102
+ JSONObject jsonObject = new JSONObject (stats );
103
+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
104
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
105
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
106
+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
107
+ LlmModuleInstrumentationTest .this .onStats (tps );
108
+ } catch (JSONException e ) {
109
+ }
99
110
}
100
111
});
101
112
Original file line number Diff line number Diff line change @@ -31,8 +31,22 @@ public interface LlmCallback {
31
31
/**
32
32
* Called when the statistics for the generate() is available.
33
33
*
34
+ * Note: This is a deprecated API and will be removed in the future. Please use onStats(String stats)
35
+ *
34
36
* @param tps Tokens/second for generated tokens.
35
37
*/
38
+ @ Deprecated
39
+ @ DoNotStrip
40
+ default public void onStats (float tps ) {}
41
+
42
+ /**
43
+ * Called when the statistics for the generate() is available.
44
+ *
45
+ * The result will be a JSON string. See extension/llm/stats.h for the field
46
+ * definitions.
47
+ *
48
+ * @param stats JSON string containing the statistics for the generate()
49
+ */
36
50
@ DoNotStrip
37
- public void onStats (float tps );
51
+ default public void onStats (String stats ) {}
38
52
}
Original file line number Diff line number Diff line change @@ -100,14 +100,20 @@ class ExecuTorchLlmCallbackJni
100
100
101
101
void onStats (const llm::Stats& result) const {
102
102
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic ();
103
- static const auto method = cls->getMethod <void (jfloat)>(" onStats" );
103
+ static const auto tps_method = cls->getMethod <void (jfloat)>(" onStats" );
104
104
double eval_time =
105
105
(double )(result.inference_end_ms - result.prompt_eval_end_ms );
106
106
107
107
float tps = result.num_generated_tokens / eval_time *
108
108
result.SCALING_FACTOR_UNITS_PER_SECOND ;
109
-
110
- method (self (), tps);
109
+ tps_method (self (), tps);
110
+
111
+ static const auto on_stats_method =
112
+ cls->getMethod <void (facebook::jni::local_ref<jstring>)>(" onStats" );
113
+ on_stats_method (
114
+ self (),
115
+ facebook::jni::make_jstring (
116
+ executorch::extension::llm::stats_to_json_string (result)));
111
117
}
112
118
};
113
119
Original file line number Diff line number Diff line change @@ -39,6 +39,7 @@ dependencies {
39
39
implementation(" com.facebook.soloader:soloader:0.10.5" )
40
40
implementation(" com.facebook.fbjni:fbjni:0.5.1" )
41
41
implementation(" com.google.code.gson:gson:2.8.6" )
42
+ implementation(" org.json:json:20250107" )
42
43
testImplementation(" junit:junit:4.13.2" )
43
44
androidTestImplementation(" androidx.test.ext:junit:1.2.1" )
44
45
androidTestImplementation(" androidx.test.espresso:espresso-core:3.6.1" )
You can’t perform that action at this time.
0 commit comments