Skip to content

Commit f2dab35

Browse files
committed
update trans
1 parent df14a47 commit f2dab35

File tree

5 files changed

+297
-96
lines changed

5 files changed

+297
-96
lines changed

subs/ai/src/main/java/com/engineer/ai/FastStyleTransActivity.kt

Lines changed: 51 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -8,154 +8,109 @@ import android.os.Build
88
import android.os.Bundle
99
import android.util.Log
1010
import android.widget.Toast
11+
import androidx.activity.enableEdgeToEdge
1112
import androidx.activity.result.PickVisualMediaRequest
1213
import androidx.activity.result.contract.ActivityResultContracts
1314
import androidx.activity.result.contract.ActivityResultContracts.PickVisualMedia.Companion.isPhotoPickerAvailable
1415
import androidx.appcompat.app.AlertDialog
1516
import androidx.appcompat.app.AppCompatActivity
1617
import androidx.core.app.ActivityCompat
1718
import androidx.core.content.ContextCompat
19+
import androidx.lifecycle.lifecycleScope
1820
import com.engineer.ai.databinding.ActivityTansStyleBinding
1921
import com.engineer.ai.util.AndroidAssetsFileUtil
20-
import kotlinx.coroutines.Dispatchers
22+
import com.engineer.ai.util.AsyncExecutor
23+
import com.engineer.ai.util.StyleTransferProcessor
24+
import com.engineer.ai.util.gone
25+
import com.engineer.ai.util.show
2126
import kotlinx.coroutines.GlobalScope
27+
import kotlinx.coroutines.cancel
2228
import kotlinx.coroutines.launch
23-
import org.pytorch.IValue
2429
import org.pytorch.LiteModuleLoader
2530
import org.pytorch.Module
26-
import org.pytorch.Tensor
27-
import org.pytorch.torchvision.TensorImageUtils
2831

2932

3033
class FastStyleTransActivity : AppCompatActivity() {
31-
private val TAG = "FastStyleTransActivity"
34+
private val TAG = "FastStyleTransActivity_TAG"
3235
private lateinit var module: Module
3336
private val modelName = "mosaic.pt"
3437
private var currentBitmap: Bitmap? = null
3538

3639
private lateinit var viewBinding: ActivityTansStyleBinding
3740
override fun onCreate(savedInstanceState: Bundle?) {
3841
super.onCreate(savedInstanceState)
42+
enableEdgeToEdge()
3943
viewBinding = ActivityTansStyleBinding.inflate(layoutInflater)
4044
setContentView(viewBinding.root)
4145
initModel()
4246
viewBinding.pickImg.setOnClickListener {
4347
pickImage()
4448
}
4549
viewBinding.gen.setOnClickListener {
46-
GlobalScope.launch(Dispatchers.IO) {
50+
refreshLoading(true)
51+
// lifecycleScope.launch {
52+
// genImage()
53+
// }
54+
GlobalScope.launch {
4755
genImage()
4856
}
4957
}
5058
}
5159

60+
override fun onDestroy() {
61+
super.onDestroy()
62+
lifecycleScope.cancel()
63+
}
64+
5265
private fun showBitmap(bitmap: Bitmap) {
5366
viewBinding.pickResult.setImageBitmap(bitmap)
67+
68+
Log.i(TAG, "ori = ${bitmap.width},${bitmap.height}")
5469
currentBitmap = bitmap
5570
}
5671

57-
fun multiplyTensorBy255(inputTensor: Tensor): Tensor {
58-
// 获取 Tensor 的浮点数组
59-
val inputArray = inputTensor.dataAsFloatArray
60-
61-
// 创建新数组并乘以255
62-
val outputArray = FloatArray(inputArray.size) { i ->
63-
inputArray[i] * 255.0f
64-
}
65-
66-
// 创建新的 Tensor(保持原始形状)
67-
return Tensor.fromBlob(outputArray, inputTensor.shape())
72+
private fun refreshLoading(show: Boolean) {
73+
if (show) viewBinding.loading.show() else viewBinding.loading.gone()
6874
}
6975

70-
fun divTensorBy255(inputTensor: Tensor): Tensor {
71-
// 获取 Tensor 的浮点数组
72-
val inputArray = inputTensor.dataAsFloatArray
73-
74-
// 创建新数组并乘以255
75-
val outputArray = FloatArray(inputArray.size) { i ->
76-
inputArray[i] * 255.0f
77-
}
78-
79-
// 创建新的 Tensor(保持原始形状)
80-
return Tensor.fromBlob(outputArray, inputTensor.shape())
81-
}
8276

8377
private fun genImage() {
8478

85-
val inDims: IntArray = intArrayOf(224, 224, 3)
86-
val outDims: IntArray = intArrayOf(224, 224, 3)
87-
val bmp: Bitmap? = null
88-
var scaledBmp: Bitmap? = null
89-
val filePath = ""
9079
currentBitmap?.let {
91-
scaledBmp = Bitmap.createScaledBitmap(it, inDims[0], inDims[1], true);
92-
93-
94-
// Android更简洁的实现
95-
// 转换为张量并归一化到[0,1]
96-
val inputTensor: Tensor = TensorImageUtils.bitmapToFloat32Tensor(
97-
currentBitmap, floatArrayOf(0f, 0f, 0f), // 不减去均值
98-
floatArrayOf(1f, 1f, 1f) // 不除以标准差
99-
)
100-
101-
val tensor = multiplyTensorBy255(inputTensor);
102-
103-
Log.i(TAG, "1")
104-
105-
val resultTensor = module.forward(IValue.from(tensor)).toTensor()
106-
val out = divTensorBy255(resultTensor)
107-
108-
Log.i(TAG, "2")
80+
// AsyncExecutor.fromIO().execute {
81+
// StyleTransferProcessor.initModule(module)
82+
// StyleTransferProcessor.transferStyle(it, 1.0f)
83+
// }.awaitResult<Bitmap>(onSuccess = {
84+
// Log.i(TAG, "onSuccess")
85+
// refreshLoading(false)
86+
// Log.i(TAG, "output ${it.width},${it.height}")
87+
// viewBinding.transResult.setImageBitmap(it)
88+
// }, onError = {
89+
// refreshLoading(false)
90+
// Log.i(TAG, it.stackTraceToString())
91+
// })
92+
93+
AsyncExecutor.fromIO().execute {
94+
StyleTransferProcessor.initModule(module)
95+
// val it = StyleTransferProcessor.transferStyle(it, 1.0f)
96+
//
97+
// withContext(Dispatchers.Main) {
98+
// refreshLoading(false)
99+
// Log.i(TAG, "output ${it.width},${it.height}")
100+
// viewBinding.transResult.setImageBitmap(it)
101+
// }
109102

110-
val outputArray = out.dataAsFloatArray
111-
val width = outDims[0]
112-
val height = outDims[1]
113-
val outputBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
103+
StyleTransferProcessor.transferStyleAsync(it, 0.5f) {
104+
runOnUiThread {
105+
refreshLoading(false)
114106

115-
// 将浮点数组转换为Bitmap (简化实现,实际可能需要更复杂的转换)
116-
for (y in 0 until height) {
117-
for (x in 0 until width) {
118-
val r = (outputArray[y * width * 3 + x * 3 + 0] * 255).toInt().coerceIn(0, 255)
119-
val g = (outputArray[y * width * 3 + x * 3 + 1] * 255).toInt().coerceIn(0, 255)
120-
val b = (outputArray[y * width * 3 + x * 3 + 2] * 255).toInt().coerceIn(0, 255)
121-
outputBitmap.setPixel(x, y, android.graphics.Color.rgb(r, g, b))
107+
viewBinding.transResult.setImageBitmap(it)
108+
}
122109
}
123110
}
124-
Log.i(TAG, "3")
125-
GlobalScope.launch(Dispatchers.Main) {
126-
viewBinding.transResult.setImageBitmap(outputBitmap)
127-
}
128111
}
129112

130-
// val zDim = intArrayOf(1, 100)
131-
// val outDims = intArrayOf(64, 64, 3)
132-
// Log.d(TAG, zDim.contentToString())
133-
// val z = FloatArray(zDim[0] * zDim[1])
134-
// Log.d(TAG, "z = ${z.contentToString()}")
135-
// val rand = Random()
136-
// // 生成高斯随机数
137-
// for (c in 0 until zDim[0] * zDim[1]) {
138-
// z[c] = rand.nextGaussian().toFloat()
139-
// }
140-
// Log.d(TAG, "z = ${z.contentToString()}")
141-
// val shape = longArrayOf(1, 100)
142-
// val tensor = Tensor.fromBlob(z, shape)
143-
// Log.d(TAG, tensor.dataAsFloatArray.contentToString())
144-
// val resultT = module.forward(IValue.from(tensor)).toTensor()
145-
// val resultArray = resultT.dataAsFloatArray
146-
// val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
147-
// var index = 0
148-
// // 根据输出的一维数组,解析生成的卡通图像
149-
// for (j in 0 until outDims[2]) {
150-
// for (k in 0 until outDims[0]) {
151-
// for (m in 0 until outDims[1]) {
152-
// resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
153-
// index++
154-
// }
155-
// }
156-
// }
157-
// val bitmap = Utils.getBitmap(resultImg, outDims)
158-
// viewBinding.transResult.setImageBitmap(bitmap)
113+
159114
}
160115

161116
private fun initModel() {

subs/ai/src/main/java/com/engineer/ai/util/AndroidAssetsFileUtil.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.engineer.ai.util
22

33
import android.content.Context
4+
import android.view.View
45
import java.io.File
56
import java.io.FileOutputStream
67

@@ -25,4 +26,16 @@ object AndroidAssetsFileUtil {
2526
file.absolutePath
2627
}
2728
}
29+
}
30+
31+
fun View?.show() {
32+
this?.let {
33+
visibility = View.VISIBLE
34+
}
35+
}
36+
37+
fun View?.gone() {
38+
this?.let {
39+
visibility = View.GONE
40+
}
2841
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package com.engineer.ai.util
2+
3+
import android.util.Log
4+
import kotlinx.coroutines.*
5+
import kotlin.coroutines.CoroutineContext
6+
7+
/**
8+
* 协程线程切换框架
9+
* 示例用法:
10+
* AsyncExecutor(Dispatchers.IO) // 初始线程
11+
* .execute { fetchDataFromApi() } // 执行耗时操作
12+
* .switchTo(Dispatchers.Main) // 切换线程
13+
* .execute { updateUI() } // 更新UI
14+
* .onError { handleError(it) } // 错误处理
15+
*/
16+
class AsyncExecutor(
17+
private val context: CoroutineContext = Dispatchers.Default, private val parentJob: Job? = null
18+
) {
19+
private val scope: CoroutineScope by lazy {
20+
CoroutineScope(context + Job(parentJob ?: SupervisorJob()))
21+
}
22+
internal var deferred: Deferred<Any?>? = null
23+
private var errorHandler: (Throwable) -> Unit = {}
24+
25+
// 核心执行方法
26+
fun execute(block: suspend () -> Unit): AsyncExecutor {
27+
deferred = scope.async {
28+
try {
29+
block()
30+
} catch (e: Throwable) {
31+
errorHandler(e)
32+
throw e
33+
}
34+
}
35+
return this
36+
}
37+
38+
// 线程切换方法
39+
fun switchTo(newContext: CoroutineContext): AsyncExecutor {
40+
return AsyncExecutor(newContext, deferred?.let { Job(it) })
41+
}
42+
43+
// 错误处理
44+
fun onError(handler: (Throwable) -> Unit): AsyncExecutor {
45+
errorHandler = handler
46+
return this
47+
}
48+
49+
// 取消任务
50+
fun cancel() {
51+
scope.cancel()
52+
}
53+
54+
companion object {
55+
// 快捷入口
56+
fun fromIO() = AsyncExecutor(Dispatchers.IO)
57+
fun fromMain() = AsyncExecutor(Dispatchers.Main)
58+
fun fromDefault() = AsyncExecutor(Dispatchers.Default)
59+
}
60+
}
61+
62+
// 扩展函数:自动处理结果到主线程
63+
suspend fun <T> AsyncExecutor.awaitResult(
64+
onSuccess: (T) -> Unit, onError: (Throwable) -> Unit = {}
65+
): AsyncExecutor {
66+
return this.apply {
67+
try {
68+
val result = deferred?.await() as T
69+
Log.d("StyleTransferProcessor", "result is is ok")
70+
switchTo(Dispatchers.Main).execute { onSuccess(result) }
71+
} catch (e: Throwable) {
72+
Log.d("StyleTransferProcessor","onError ${e.stackTraceToString()}")
73+
switchTo(Dispatchers.Main).execute { onError(e) }
74+
}
75+
}
76+
}

0 commit comments

Comments
 (0)