@@ -15,21 +15,23 @@ import androidx.appcompat.app.AlertDialog
15
15
import androidx.appcompat.app.AppCompatActivity
16
16
import androidx.core.app.ActivityCompat
17
17
import androidx.core.content.ContextCompat
18
- import com.engineer.ai.databinding.ActivityGanBinding
19
18
import com.engineer.ai.databinding.ActivityTansStyleBinding
20
19
import com.engineer.ai.util.AndroidAssetsFileUtil
21
- import com.engineer.ai.util.Utils
20
+ import kotlinx.coroutines.Dispatchers
21
+ import kotlinx.coroutines.GlobalScope
22
+ import kotlinx.coroutines.launch
22
23
import org.pytorch.IValue
23
24
import org.pytorch.LiteModuleLoader
24
25
import org.pytorch.Module
25
26
import org.pytorch.Tensor
26
- import java.util.Random
27
+ import org.pytorch.torchvision.TensorImageUtils
27
28
28
29
29
30
class FastStyleTransActivity : AppCompatActivity () {
30
31
private val TAG = " FastStyleTransActivity"
31
32
private lateinit var module: Module
32
- private val modelName = " dcgan.pt"
33
+ private val modelName = " mosaic.pt"
34
+ private var currentBitmap: Bitmap ? = null
33
35
34
36
private lateinit var viewBinding: ActivityTansStyleBinding
35
37
override fun onCreate (savedInstanceState : Bundle ? ) {
@@ -41,44 +43,119 @@ class FastStyleTransActivity : AppCompatActivity() {
41
43
pickImage()
42
44
}
43
45
viewBinding.gen.setOnClickListener {
44
- genImage()
46
+ GlobalScope .launch(Dispatchers .IO ) {
47
+ genImage()
48
+ }
45
49
}
46
50
}
47
51
48
52
private fun showBitmap (bitmap : Bitmap ) {
49
53
viewBinding.pickResult.setImageBitmap(bitmap)
54
+ currentBitmap = bitmap
50
55
}
51
56
52
- private fun genImage () {
53
- val zDim = intArrayOf(1 , 100 )
54
- val outDims = intArrayOf(64 , 64 , 3 )
55
- Log .d(TAG , zDim.contentToString())
56
- val z = FloatArray (zDim[0 ] * zDim[1 ])
57
- Log .d(TAG , " z = ${z.contentToString()} " )
58
- val rand = Random ()
59
- // 生成高斯随机数
60
- for (c in 0 until zDim[0 ] * zDim[1 ]) {
61
- z[c] = rand.nextGaussian().toFloat()
57
+ fun multiplyTensorBy255 (inputTensor : Tensor ): Tensor {
58
+ // 获取 Tensor 的浮点数组
59
+ val inputArray = inputTensor.dataAsFloatArray
60
+
61
+ // 创建新数组并乘以255
62
+ val outputArray = FloatArray (inputArray.size) { i ->
63
+ inputArray[i] * 255.0f
64
+ }
65
+
66
+ // 创建新的 Tensor(保持原始形状)
67
+ return Tensor .fromBlob(outputArray, inputTensor.shape())
68
+ }
69
+
70
+ fun divTensorBy255 (inputTensor : Tensor ): Tensor {
71
+ // 获取 Tensor 的浮点数组
72
+ val inputArray = inputTensor.dataAsFloatArray
73
+
74
+ // 创建新数组并乘以255
75
+ val outputArray = FloatArray (inputArray.size) { i ->
76
+ inputArray[i] * 255.0f
62
77
}
63
- Log .d(TAG , " z = ${z.contentToString()} " )
64
- val shape = longArrayOf(1 , 100 )
65
- val tensor = Tensor .fromBlob(z, shape)
66
- Log .d(TAG , tensor.dataAsFloatArray.contentToString())
67
- val resultT = module.forward(IValue .from(tensor)).toTensor()
68
- val resultArray = resultT.dataAsFloatArray
69
- val resultImg = Array (outDims[0 ]) { Array (outDims[1 ]) { FloatArray (outDims[2 ]) { 0.0f } } }
70
- var index = 0
71
- // 根据输出的一维数组,解析生成的卡通图像
72
- for (j in 0 until outDims[2 ]) {
73
- for (k in 0 until outDims[0 ]) {
74
- for (m in 0 until outDims[1 ]) {
75
- resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
76
- index++
78
+
79
+ // 创建新的 Tensor(保持原始形状)
80
+ return Tensor .fromBlob(outputArray, inputTensor.shape())
81
+ }
82
+
83
+ private fun genImage () {
84
+
85
+ val inDims: IntArray = intArrayOf(224 , 224 , 3 )
86
+ val outDims: IntArray = intArrayOf(224 , 224 , 3 )
87
+ val bmp: Bitmap ? = null
88
+ var scaledBmp: Bitmap ? = null
89
+ val filePath = " "
90
+ currentBitmap?.let {
91
+ scaledBmp = Bitmap .createScaledBitmap(it, inDims[0 ], inDims[1 ], true );
92
+
93
+
94
+ // Android更简洁的实现
95
+ // 转换为张量并归一化到[0,1]
96
+ val inputTensor: Tensor = TensorImageUtils .bitmapToFloat32Tensor(
97
+ currentBitmap, floatArrayOf(0f , 0f , 0f ), // 不减去均值
98
+ floatArrayOf(1f , 1f , 1f ) // 不除以标准差
99
+ )
100
+
101
+ val tensor = multiplyTensorBy255(inputTensor);
102
+
103
+ Log .i(TAG , " 1" )
104
+
105
+ val resultTensor = module.forward(IValue .from(tensor)).toTensor()
106
+ val out = divTensorBy255(resultTensor)
107
+
108
+ Log .i(TAG , " 2" )
109
+
110
+ val outputArray = out .dataAsFloatArray
111
+ val width = outDims[0 ]
112
+ val height = outDims[1 ]
113
+ val outputBitmap = Bitmap .createBitmap(width, height, Bitmap .Config .ARGB_8888 )
114
+
115
+ // 将浮点数组转换为Bitmap (简化实现,实际可能需要更复杂的转换)
116
+ for (y in 0 until height) {
117
+ for (x in 0 until width) {
118
+ val r = (outputArray[y * width * 3 + x * 3 + 0 ] * 255 ).toInt().coerceIn(0 , 255 )
119
+ val g = (outputArray[y * width * 3 + x * 3 + 1 ] * 255 ).toInt().coerceIn(0 , 255 )
120
+ val b = (outputArray[y * width * 3 + x * 3 + 2 ] * 255 ).toInt().coerceIn(0 , 255 )
121
+ outputBitmap.setPixel(x, y, android.graphics.Color .rgb(r, g, b))
77
122
}
78
123
}
124
+ Log .i(TAG , " 3" )
125
+ GlobalScope .launch(Dispatchers .Main ) {
126
+ viewBinding.transResult.setImageBitmap(outputBitmap)
127
+ }
79
128
}
80
- val bitmap = Utils .getBitmap(resultImg, outDims)
81
- viewBinding.transResult.setImageBitmap(bitmap)
129
+
130
+ // val zDim = intArrayOf(1, 100)
131
+ // val outDims = intArrayOf(64, 64, 3)
132
+ // Log.d(TAG, zDim.contentToString())
133
+ // val z = FloatArray(zDim[0] * zDim[1])
134
+ // Log.d(TAG, "z = ${z.contentToString()}")
135
+ // val rand = Random()
136
+ // // 生成高斯随机数
137
+ // for (c in 0 until zDim[0] * zDim[1]) {
138
+ // z[c] = rand.nextGaussian().toFloat()
139
+ // }
140
+ // Log.d(TAG, "z = ${z.contentToString()}")
141
+ // val shape = longArrayOf(1, 100)
142
+ // val tensor = Tensor.fromBlob(z, shape)
143
+ // Log.d(TAG, tensor.dataAsFloatArray.contentToString())
144
+ // val resultT = module.forward(IValue.from(tensor)).toTensor()
145
+ // val resultArray = resultT.dataAsFloatArray
146
+ // val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
147
+ // var index = 0
148
+ // // 根据输出的一维数组,解析生成的卡通图像
149
+ // for (j in 0 until outDims[2]) {
150
+ // for (k in 0 until outDims[0]) {
151
+ // for (m in 0 until outDims[1]) {
152
+ // resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
153
+ // index++
154
+ // }
155
+ // }
156
+ // }
157
+ // val bitmap = Utils.getBitmap(resultImg, outDims)
158
+ // viewBinding.transResult.setImageBitmap(bitmap)
82
159
}
83
160
84
161
private fun initModel () {
@@ -96,15 +173,16 @@ class FastStyleTransActivity : AppCompatActivity() {
96
173
}
97
174
}
98
175
99
- private val pickMedia = registerForActivityResult(ActivityResultContracts .PickVisualMedia ()) { uri ->
100
- // 处理选择的图片
101
- uri?.let {
102
- val inputStream = contentResolver.openInputStream(uri)
103
- val bitmap = BitmapFactory .decodeStream(inputStream)
104
- inputStream?.close()
105
- showBitmap(bitmap)
176
+ private val pickMedia =
177
+ registerForActivityResult(ActivityResultContracts .PickVisualMedia ()) { uri ->
178
+ // 处理选择的图片
179
+ uri?.let {
180
+ val inputStream = contentResolver.openInputStream(uri)
181
+ val bitmap = BitmapFactory .decodeStream(inputStream)
182
+ inputStream?.close()
183
+ showBitmap(bitmap)
184
+ }
106
185
}
107
- }
108
186
109
187
// 定义权限请求和图片选择启动器
110
188
private val requestPermissionLauncher = registerForActivityResult(
@@ -158,7 +236,8 @@ class FastStyleTransActivity : AppCompatActivity() {
158
236
this , permission
159
237
) -> {
160
238
// 解释为什么需要权限
161
- AlertDialog .Builder (this ).setTitle(" 需要权限" ).setMessage(" 需要存储权限才能从相册选择图片" )
239
+ AlertDialog .Builder (this ).setTitle(" 需要权限" )
240
+ .setMessage(" 需要存储权限才能从相册选择图片" )
162
241
.setPositiveButton(" 确定" ) { _, _ ->
163
242
requestPermissionLauncher.launch(permission)
164
243
}.setNegativeButton(" 取消" , null ).show()
0 commit comments