@@ -8,154 +8,109 @@ import android.os.Build
8
8
import android.os.Bundle
9
9
import android.util.Log
10
10
import android.widget.Toast
11
+ import androidx.activity.enableEdgeToEdge
11
12
import androidx.activity.result.PickVisualMediaRequest
12
13
import androidx.activity.result.contract.ActivityResultContracts
13
14
import androidx.activity.result.contract.ActivityResultContracts.PickVisualMedia.Companion.isPhotoPickerAvailable
14
15
import androidx.appcompat.app.AlertDialog
15
16
import androidx.appcompat.app.AppCompatActivity
16
17
import androidx.core.app.ActivityCompat
17
18
import androidx.core.content.ContextCompat
19
+ import androidx.lifecycle.lifecycleScope
18
20
import com.engineer.ai.databinding.ActivityTansStyleBinding
19
21
import com.engineer.ai.util.AndroidAssetsFileUtil
20
- import kotlinx.coroutines.Dispatchers
22
+ import com.engineer.ai.util.AsyncExecutor
23
+ import com.engineer.ai.util.StyleTransferProcessor
24
+ import com.engineer.ai.util.gone
25
+ import com.engineer.ai.util.show
21
26
import kotlinx.coroutines.GlobalScope
27
+ import kotlinx.coroutines.cancel
22
28
import kotlinx.coroutines.launch
23
- import org.pytorch.IValue
24
29
import org.pytorch.LiteModuleLoader
25
30
import org.pytorch.Module
26
- import org.pytorch.Tensor
27
- import org.pytorch.torchvision.TensorImageUtils
28
31
29
32
30
33
class FastStyleTransActivity : AppCompatActivity () {
31
- private val TAG = " FastStyleTransActivity "
34
+ private val TAG = " FastStyleTransActivity_TAG "
32
35
private lateinit var module: Module
33
36
private val modelName = " mosaic.pt"
34
37
private var currentBitmap: Bitmap ? = null
35
38
36
39
private lateinit var viewBinding: ActivityTansStyleBinding
37
40
override fun onCreate (savedInstanceState : Bundle ? ) {
38
41
super .onCreate(savedInstanceState)
42
+ enableEdgeToEdge()
39
43
viewBinding = ActivityTansStyleBinding .inflate(layoutInflater)
40
44
setContentView(viewBinding.root)
41
45
initModel()
42
46
viewBinding.pickImg.setOnClickListener {
43
47
pickImage()
44
48
}
45
49
viewBinding.gen.setOnClickListener {
46
- GlobalScope .launch(Dispatchers .IO ) {
50
+ refreshLoading(true )
51
+ // lifecycleScope.launch {
52
+ // genImage()
53
+ // }
54
+ GlobalScope .launch {
47
55
genImage()
48
56
}
49
57
}
50
58
}
51
59
60
+ override fun onDestroy () {
61
+ super .onDestroy()
62
+ lifecycleScope.cancel()
63
+ }
64
+
52
65
private fun showBitmap (bitmap : Bitmap ) {
53
66
viewBinding.pickResult.setImageBitmap(bitmap)
67
+
68
+ Log .i(TAG , " ori = ${bitmap.width} ,${bitmap.height} " )
54
69
currentBitmap = bitmap
55
70
}
56
71
57
- fun multiplyTensorBy255 (inputTensor : Tensor ): Tensor {
58
- // 获取 Tensor 的浮点数组
59
- val inputArray = inputTensor.dataAsFloatArray
60
-
61
- // 创建新数组并乘以255
62
- val outputArray = FloatArray (inputArray.size) { i ->
63
- inputArray[i] * 255.0f
64
- }
65
-
66
- // 创建新的 Tensor(保持原始形状)
67
- return Tensor .fromBlob(outputArray, inputTensor.shape())
72
+ private fun refreshLoading (show : Boolean ) {
73
+ if (show) viewBinding.loading.show() else viewBinding.loading.gone()
68
74
}
69
75
70
- fun divTensorBy255 (inputTensor : Tensor ): Tensor {
71
- // 获取 Tensor 的浮点数组
72
- val inputArray = inputTensor.dataAsFloatArray
73
-
74
- // 创建新数组并乘以255
75
- val outputArray = FloatArray (inputArray.size) { i ->
76
- inputArray[i] * 255.0f
77
- }
78
-
79
- // 创建新的 Tensor(保持原始形状)
80
- return Tensor .fromBlob(outputArray, inputTensor.shape())
81
- }
82
76
83
77
private fun genImage () {
84
78
85
- val inDims: IntArray = intArrayOf(224 , 224 , 3 )
86
- val outDims: IntArray = intArrayOf(224 , 224 , 3 )
87
- val bmp: Bitmap ? = null
88
- var scaledBmp: Bitmap ? = null
89
- val filePath = " "
90
79
currentBitmap?.let {
91
- scaledBmp = Bitmap .createScaledBitmap(it, inDims[0 ], inDims[1 ], true );
92
-
93
-
94
- // Android更简洁的实现
95
- // 转换为张量并归一化到[0,1]
96
- val inputTensor: Tensor = TensorImageUtils .bitmapToFloat32Tensor(
97
- currentBitmap, floatArrayOf(0f , 0f , 0f ), // 不减去均值
98
- floatArrayOf(1f , 1f , 1f ) // 不除以标准差
99
- )
100
-
101
- val tensor = multiplyTensorBy255(inputTensor);
102
-
103
- Log .i(TAG , " 1" )
104
-
105
- val resultTensor = module.forward(IValue .from(tensor)).toTensor()
106
- val out = divTensorBy255(resultTensor)
107
-
108
- Log .i(TAG , " 2" )
80
+ // AsyncExecutor.fromIO().execute {
81
+ // StyleTransferProcessor.initModule(module)
82
+ // StyleTransferProcessor.transferStyle(it, 1.0f)
83
+ // }.awaitResult<Bitmap>(onSuccess = {
84
+ // Log.i(TAG, "onSuccess")
85
+ // refreshLoading(false)
86
+ // Log.i(TAG, "output ${it.width},${it.height}")
87
+ // viewBinding.transResult.setImageBitmap(it)
88
+ // }, onError = {
89
+ // refreshLoading(false)
90
+ // Log.i(TAG, it.stackTraceToString())
91
+ // })
92
+
93
+ AsyncExecutor .fromIO().execute {
94
+ StyleTransferProcessor .initModule(module)
95
+ // val it = StyleTransferProcessor.transferStyle(it, 1.0f)
96
+ //
97
+ // withContext(Dispatchers.Main) {
98
+ // refreshLoading(false)
99
+ // Log.i(TAG, "output ${it.width},${it.height}")
100
+ // viewBinding.transResult.setImageBitmap(it)
101
+ // }
109
102
110
- val outputArray = out .dataAsFloatArray
111
- val width = outDims[0 ]
112
- val height = outDims[1 ]
113
- val outputBitmap = Bitmap .createBitmap(width, height, Bitmap .Config .ARGB_8888 )
103
+ StyleTransferProcessor .transferStyleAsync(it, 0.5f ) {
104
+ runOnUiThread {
105
+ refreshLoading(false )
114
106
115
- // 将浮点数组转换为Bitmap (简化实现,实际可能需要更复杂的转换)
116
- for (y in 0 until height) {
117
- for (x in 0 until width) {
118
- val r = (outputArray[y * width * 3 + x * 3 + 0 ] * 255 ).toInt().coerceIn(0 , 255 )
119
- val g = (outputArray[y * width * 3 + x * 3 + 1 ] * 255 ).toInt().coerceIn(0 , 255 )
120
- val b = (outputArray[y * width * 3 + x * 3 + 2 ] * 255 ).toInt().coerceIn(0 , 255 )
121
- outputBitmap.setPixel(x, y, android.graphics.Color .rgb(r, g, b))
107
+ viewBinding.transResult.setImageBitmap(it)
108
+ }
122
109
}
123
110
}
124
- Log .i(TAG , " 3" )
125
- GlobalScope .launch(Dispatchers .Main ) {
126
- viewBinding.transResult.setImageBitmap(outputBitmap)
127
- }
128
111
}
129
112
130
- // val zDim = intArrayOf(1, 100)
131
- // val outDims = intArrayOf(64, 64, 3)
132
- // Log.d(TAG, zDim.contentToString())
133
- // val z = FloatArray(zDim[0] * zDim[1])
134
- // Log.d(TAG, "z = ${z.contentToString()}")
135
- // val rand = Random()
136
- // // 生成高斯随机数
137
- // for (c in 0 until zDim[0] * zDim[1]) {
138
- // z[c] = rand.nextGaussian().toFloat()
139
- // }
140
- // Log.d(TAG, "z = ${z.contentToString()}")
141
- // val shape = longArrayOf(1, 100)
142
- // val tensor = Tensor.fromBlob(z, shape)
143
- // Log.d(TAG, tensor.dataAsFloatArray.contentToString())
144
- // val resultT = module.forward(IValue.from(tensor)).toTensor()
145
- // val resultArray = resultT.dataAsFloatArray
146
- // val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
147
- // var index = 0
148
- // // 根据输出的一维数组,解析生成的卡通图像
149
- // for (j in 0 until outDims[2]) {
150
- // for (k in 0 until outDims[0]) {
151
- // for (m in 0 until outDims[1]) {
152
- // resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
153
- // index++
154
- // }
155
- // }
156
- // }
157
- // val bitmap = Utils.getBitmap(resultImg, outDims)
158
- // viewBinding.transResult.setImageBitmap(bitmap)
113
+
159
114
}
160
115
161
116
private fun initModel () {
0 commit comments