Skip to content

Commit 8a6b737

Browse files
ZachNagengast, bpkeeneargmax, chen-argmax, Keith Ha
authored
Release v0.3.2 (#68)
* Add multilingual model support with PerLayerKVDecoder, handling diverse encoder output names, int64 token bindings, and timestamp postprocessing; lift KV-cache logic into the decoder for improved modularity.
* Various SDK and build improvements: conditional Bazel download, updated API with segments, improved model download/handling, version bumps (0.3.2), and README cleanup.

---------

Co-authored-by: Brian Keene <b@argmaxinc.com>
Co-authored-by: chen-argmax <chen@argmaxinc.com>
Co-authored-by: Keith Ha <keith@argmaxinc.com>
1 parent bfce0d5 commit 8a6b737

28 files changed

+1242
-332
lines changed

README.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
<img src="https://github.com/user-attachments/assets/1be5e31c-de42-40ab-9b85-790cb911ed47" alt="WhisperKit" width="20%" />
99
</a>
1010

11-
# WhisperKit Android (Beta)
11+
# WhisperKit Android
12+
13+
[![Tests](https://github.com/argmaxinc/whisperkitandroid/actions/workflows/pr-checks.yml/badge.svg)](https://github.com/argmaxinc/whisperkitandroid/actions/workflows/pr-checks.yml)
14+
[![License](https://img.shields.io/github/license/argmaxinc/whisperkitandroid?logo=github&logoColor=969da4&label=License&labelColor=353a41&color=32d058)](LICENSE.md)
15+
[![Maven Central](https://img.shields.io/maven-central/v/com.argmaxinc/whisperkit?logo=sonatype&logoColor=969da4&label=Maven%20Central&labelColor=353a41&color=32d058)](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit)
16+
[![Discord](https://img.shields.io/discord/1171912382512115722?style=flat&logo=discord&logoColor=969da4&label=Discord&labelColor=353a41&color=32d058&link=https%3A%2F%2Fdiscord.gg%2FG5F5GZGecC)](https://discord.gg/G5F5GZGecC)
1217

13-
[![Maven Central](https://img.shields.io/maven-central/v/com.argmaxinc/whisperkit?color=32d058)](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit)
1418
</div>
1519

1620
WhisperKit Android brings Foundation Models On Device for Automatic Speech Recognition. It extends the performance and feature set of [WhisperKit](https://github.com/argmaxinc/WhisperKit) from Apple platforms to Android and Linux. The current feature set is a subset of the iOS counterpart,
1721
but we are continuing to invest in Android and now welcome contributions from the community.
1822

19-
[Example App (Coming Soon)] [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools)
23+
[[Example App]](https://play.google.com/store/apps/details?id=com.argmaxinc.whisperax) [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools)
2024

2125
## Table of Contents
2226

@@ -37,7 +41,7 @@ To use WhisperKit in your Android app, you need to:
3741
```kotlin
3842
dependencies {
3943
// 1. WhisperKit SDK
40-
implementation("com.argmaxinc:whisperkit:0.3.0") // Check badge above for latest version
44+
implementation("com.argmaxinc:whisperkit:0.3.2") // Check badge above for latest version
4145

4246
// 2. QNN dependencies for hardware acceleration
4347
implementation("com.qualcomm.qnn:qnn-runtime:2.34.0")
@@ -73,18 +77,22 @@ class YourActivity : AppCompatActivity() {
7377
whisperKit = WhisperKit.Builder()
7478
.setModel(WhisperKit.OPENAI_TINY_EN)
7579
.setApplicationContext(applicationContext)
76-
.setCallback { what, timestamp, msg ->
80+
.setCallback { what, result ->
7781
// Handle transcription output
7882
when (what) {
7983
WhisperKit.TextOutputCallback.MSG_INIT -> {
8084
// Model initialized successfully
8185
}
8286
WhisperKit.TextOutputCallback.MSG_TEXT_OUT -> {
8387
// New transcription available
84-
val text = msg
85-
val time = timestamp
88+
val fullText = result.text
89+
val segments = result.segments
8690
// Process the transcribed text as it becomes available
8791
// This callback will be called multiple times as more audio is processed
92+
segments.forEach { segment ->
93+
// Process each segment
94+
val segmentText = segment.text
95+
}
8896
}
8997
WhisperKit.TextOutputCallback.MSG_CLOSE -> {
9098
// Cleanup complete

android/config/detekt.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ complexity:
1414
thresholdInObjects: 10
1515
LongParameterList:
1616
functionThreshold: 8
17-
constructorThreshold: 7
17+
constructorThreshold: 8
1818
CyclomaticComplexMethod:
1919
threshold: 20
2020
NestedBlockDepth:

android/examples/WhisperAX/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ android {
1212
applicationId = "com.argmaxinc.whisperax"
1313
minSdk = 26
1414
targetSdk = 35
15-
versionCode = 3
15+
versionCode = 6
1616
versionName = "0.1.0"
1717

1818
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"

android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import androidx.compose.material3.Surface
3131
import androidx.compose.material3.Text
3232
import androidx.compose.runtime.Composable
3333
import androidx.compose.runtime.collectAsState
34+
import androidx.compose.runtime.derivedStateOf
3435
import androidx.compose.runtime.getValue
3536
import androidx.compose.runtime.mutableStateOf
3637
import androidx.compose.runtime.remember
@@ -40,6 +41,7 @@ import androidx.compose.ui.Modifier
4041
import androidx.compose.ui.draw.alpha
4142
import androidx.compose.ui.draw.rotate
4243
import androidx.compose.ui.unit.dp
44+
import com.argmaxinc.whisperax.WhisperViewModel.Companion.MODELS_SUPPORTING_NPU
4345
import com.argmaxinc.whisperkit.ExperimentalWhisperKit
4446
import com.argmaxinc.whisperkit.WhisperKit
4547

@@ -50,11 +52,18 @@ enum class ComputeUnits(val displayName: String, val backendValue: Int) {
5052
CPU_AND_NPU("NPU", WhisperKit.Builder.CPU_AND_NPU),
5153
}
5254

55+
@OptIn(ExperimentalWhisperKit::class)
5356
@Composable
5457
fun ComputeUnitsView(viewModel: WhisperViewModel) {
5558
val modelState by viewModel.modelState.collectAsState()
5659
val encoderState by viewModel.encoderState.collectAsState()
5760
val decoderState by viewModel.decoderState.collectAsState()
61+
val selectedModel by viewModel.selectedModel.collectAsState()
62+
val shouldEnableNPUForEncoderDecoder by remember {
63+
derivedStateOf {
64+
selectedModel in MODELS_SUPPORTING_NPU
65+
}
66+
}
5867
val isEnabled = modelState == ModelState.LOADED || modelState == ModelState.UNLOADED
5968

6069
var whisperKitExpanded by remember { mutableStateOf(true) }
@@ -75,6 +84,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) {
7584
currentState = encoderState,
7685
currentUnit = viewModel.encoderComputeUnits.collectAsState().value,
7786
onUnitSelected = { viewModel.setEncoderComputeUnits(it) },
87+
shouldEnableNPU = shouldEnableNPUForEncoderDecoder,
7888
enabled = isEnabled,
7989
)
8090

@@ -85,6 +95,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) {
8595
currentState = decoderState,
8696
currentUnit = viewModel.decoderComputeUnits.collectAsState().value,
8797
onUnitSelected = { viewModel.setDecoderComputeUnits(it) },
98+
shouldEnableNPU = shouldEnableNPUForEncoderDecoder,
8899
enabled = isEnabled,
89100
)
90101
}
@@ -185,6 +196,7 @@ fun ComputeUnitRow(
185196
currentState: ModelState,
186197
currentUnit: ComputeUnits,
187198
onUnitSelected: (ComputeUnits) -> Unit,
199+
shouldEnableNPU: Boolean = true,
188200
enabled: Boolean = true,
189201
) {
190202
val infiniteTransition = rememberInfiniteTransition(label = "loading animation")
@@ -248,7 +260,11 @@ fun ComputeUnitRow(
248260
expanded = expanded,
249261
onDismissRequest = { expanded = false },
250262
) {
251-
ComputeUnits.values().forEach { unit ->
263+
if (shouldEnableNPU) {
264+
listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU, ComputeUnits.CPU_AND_NPU)
265+
} else {
266+
listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU)
267+
}.forEach { unit ->
252268
DropdownMenuItem(
253269
text = { Text(unit.displayName) },
254270
onClick = {

android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import androidx.compose.material3.Icon
3434
import androidx.compose.material3.IconButton
3535
import androidx.compose.material3.LinearProgressIndicator
3636
import androidx.compose.material3.MaterialTheme
37+
import androidx.compose.material3.MenuAnchorType
3738
import androidx.compose.material3.OutlinedTextField
3839
import androidx.compose.material3.Surface
3940
import androidx.compose.material3.Text
@@ -111,7 +112,8 @@ fun ModelSelectorView(viewModel: WhisperViewModel) {
111112
},
112113
modifier = Modifier
113114
.fillMaxWidth()
114-
.weight(1f),
115+
.weight(1f)
116+
.menuAnchor(MenuAnchorType.PrimaryNotEditable),
115117
)
116118

117119
ExposedDropdownMenu(

android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import androidx.compose.runtime.mutableStateListOf
1414
import androidx.lifecycle.ViewModel
1515
import androidx.lifecycle.viewModelScope
1616
import com.argmaxinc.whisperkit.ExperimentalWhisperKit
17+
import com.argmaxinc.whisperkit.TranscriptionResult
18+
import com.argmaxinc.whisperkit.TranscriptionSegment
1719
import com.argmaxinc.whisperkit.WhisperKit
1820
import com.argmaxinc.whisperkit.WhisperKit.TextOutputCallback
1921
import com.argmaxinc.whisperkit.WhisperKitException
@@ -33,22 +35,13 @@ import java.text.SimpleDateFormat
3335
import java.util.Date
3436
import java.util.Locale
3537

36-
data class TranscriptionSegment(
37-
val text: String,
38-
val start: Float,
39-
val end: Float,
40-
val tokens: List<Int> = emptyList(),
41-
)
42-
43-
data class TranscriptionResult(
44-
val text: String = "",
45-
val segments: List<TranscriptionSegment> = emptyList(),
46-
)
47-
4838
@OptIn(ExperimentalWhisperKit::class)
4939
class WhisperViewModel : ViewModel() {
5040
companion object {
5141
const val TAG = "WhisperViewModel"
42+
43+
// Models that currently support the NPU backend; the NPU option is not offered for any other model.
44+
val MODELS_SUPPORTING_NPU = listOf(WhisperKit.Builder.QUALCOMM_TINY_EN, WhisperKit.Builder.QUALCOMM_BASE_EN)
5245
}
5346

5447
private lateinit var appContext: Context
@@ -190,25 +183,25 @@ class WhisperViewModel : ViewModel() {
190183
cacheDir = context.cacheDir.absolutePath
191184
}
192185

193-
fun onTextOutput(what: Int, timestamp: Float, msg: String) {
186+
fun onTextOutput(what: Int, result: TranscriptionResult) {
187+
val segments = result.segments
194188
when (what) {
195189
TextOutputCallback.MSG_INIT -> {
196-
Log.i(MainActivity.TAG, "TFLite initialized: $msg")
190+
Log.i(MainActivity.TAG, "TFLite initialized: ${result.text}")
197191
startTime = System.currentTimeMillis()
198192
_pipelineStart.value = startTime.toDouble() / 1000.0
199193
_isInitializing.value = false
200194
}
201195

202196
TextOutputCallback.MSG_TEXT_OUT -> {
203197
Log.i(MainActivity.TAG, "TEXT OUT THREAD")
204-
if (msg.isNotEmpty()) {
198+
if (segments.isNotEmpty()) {
205199
if (!firstTokenReceived) {
206200
firstTokenReceived = true
207201
firstTokenTimestamp = System.currentTimeMillis()
208202
_firstTokenTime.value = (firstTokenTimestamp - startTime).toDouble() / 1000.0
209203
}
210-
211-
val newTokens = msg.length / 4
204+
val newTokens = segments.joinToString("") { it.text }.length / 4
212205
totalTokens += newTokens
213206

214207
val currentTime = System.currentTimeMillis()
@@ -220,14 +213,14 @@ class WhisperViewModel : ViewModel() {
220213
}
221214

222215
lastTokenTimestamp = currentTime
223-
updateTranscript(msg)
216+
updateTranscript(segments)
224217
}
225218
}
226219

227220
TextOutputCallback.MSG_CLOSE -> {
228221
Log.i(MainActivity.TAG, "Transcription completed.")
229-
if (msg.isNotEmpty()) {
230-
val newTokens = msg.length / 4
222+
if (segments.isNotEmpty()) {
223+
val newTokens = segments.joinToString("") { it.text }.length / 4
231224
totalTokens += newTokens
232225

233226
val totalTime = (System.currentTimeMillis() - startTime).toDouble() / 1000.0
@@ -236,8 +229,7 @@ class WhisperViewModel : ViewModel() {
236229

237230
updateRealtimeMetrics(totalTime)
238231
}
239-
240-
updateTranscript(msg)
232+
updateTranscript(segments)
241233
}
242234
}
243235

@@ -247,25 +239,8 @@ class WhisperViewModel : ViewModel() {
247239
}
248240
}
249241

250-
private fun updateTranscript(chunkText: String, withTimestamps: Boolean = false) {
251-
var processedText = chunkText
252-
253-
val timestamps = if (withTimestamps) {
254-
val timestampPattern = "<\\|(\\d+\\.\\d+)\\|>".toRegex()
255-
val timestampMatches = timestampPattern.findAll(chunkText).toList()
256-
timestampMatches.map { it.groupValues[1].toFloat() }
257-
} else {
258-
emptyList()
259-
}
260-
261-
if (!withTimestamps) {
262-
processedText = processedText
263-
.replace("<\\|[^>]*\\|>".toRegex(), "")
264-
.trim()
265-
} else {
266-
processedText = processedText.trim()
267-
}
268-
242+
private fun updateTranscript(segments: List<TranscriptionSegment>) {
243+
val processedText = segments.joinToString("") { it.text }
269244
if (processedText.isNotEmpty()) {
270245
if (allText.isNotEmpty()) {
271246
allText.append("\n")
@@ -284,13 +259,12 @@ class WhisperViewModel : ViewModel() {
284259
fun listModels() {
285260
viewModelScope.launch {
286261
val modelDirs = listOf(
287-
// TODO: enable when models are ready
288-
// WhisperKit.Builder.OPENAI_TINY_EN,
289-
// WhisperKit.Builder.OPENAI_BASE_EN,
290-
// WhisperKit.Builder.OPENAI_SMALL_EN,
291262
WhisperKit.Builder.QUALCOMM_TINY_EN,
292263
WhisperKit.Builder.QUALCOMM_BASE_EN,
293-
// WhisperKit.Builder.QUALCOMM_SMALL_EN
264+
WhisperKit.Builder.OPENAI_TINY_EN,
265+
WhisperKit.Builder.OPENAI_BASE_EN,
266+
WhisperKit.Builder.OPENAI_TINY,
267+
WhisperKit.Builder.OPENAI_BASE,
294268
)
295269
availableModels.clear()
296270
availableModels.addAll(modelDirs)
@@ -364,6 +338,21 @@ class WhisperViewModel : ViewModel() {
364338

365339
fun selectModel(model: String) {
366340
_selectedModel.value = model
341+
if (model in MODELS_SUPPORTING_NPU) {
342+
_encoderComputeUnits.update {
343+
ComputeUnits.CPU_AND_NPU
344+
}
345+
_decoderComputeUnits.update {
346+
ComputeUnits.CPU_AND_NPU
347+
}
348+
} else {
349+
_encoderComputeUnits.update {
350+
ComputeUnits.CPU_ONLY
351+
}
352+
_decoderComputeUnits.update {
353+
ComputeUnits.CPU_ONLY
354+
}
355+
}
367356
_modelState.value = ModelState.UNLOADED
368357
_encoderState.value = ModelState.UNLOADED
369358
_decoderState.value = ModelState.UNLOADED

android/whisperkit/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ dependencies {
6666

6767
mavenPublishing {
6868

69-
coordinates("com.argmaxinc", "whisperkit", "0.3.0")
69+
coordinates("com.argmaxinc", "whisperkit", "0.3.2")
7070
pom {
7171
name.set("WhisperKit")
7272
description.set("On-device Speech Recognition for Android")

android/whisperkit/detekt-baseline.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
<SmellBaseline>
33
<ManuallySuppressedIssues/>
44
<CurrentIssues>
5+
<ID>LargeClass:ArgmaxModelDownloaderImplTest.kt$ArgmaxModelDownloaderImplTest</ID>
56
<ID>ThrowsCount:WhisperKit.kt$WhisperKit.Builder$@Throws(WhisperKitException::class) fun build(): WhisperKit</ID>
67
<ID>TooGenericExceptionCaught:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$e: Exception</ID>
78
<ID>TooGenericExceptionCaught:WhisperKitImpl.kt$WhisperKitImpl$e: Exception</ID>
9+
<ID>UnusedParameter:WhisperKitImpl.kt$WhisperKitImpl$timestamp: Float</ID>
810
</CurrentIssues>
911
</SmellBaseline>

0 commit comments

Comments
 (0)