Skip to content

Commit 1bc2c7a

Browse files
committed
add trans-style
1 parent e23fafa commit 1bc2c7a

File tree

3 files changed

+124
-42
lines changed

3 files changed

+124
-42
lines changed

subs/ai/src/main/assets/mosaic.pt

6.43 MB
Binary file not shown.

subs/ai/src/main/java/com/engineer/ai/AIHomeActivity.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ import androidx.appcompat.app.AppCompatActivity
99

1010
class AIHomeActivity : AppCompatActivity() {
1111

12-
private val pages = arrayOf(GanActivity::class.java,
13-
DigitalClassificationActivity::class.java)
12+
private val pages = arrayOf(
13+
GanActivity::class.java,
14+
DigitalClassificationActivity::class.java,
15+
FastStyleTransActivity::class.java
16+
)
1417

1518
override fun onCreate(savedInstanceState: Bundle?) {
1619
super.onCreate(savedInstanceState)

subs/ai/src/main/java/com/engineer/ai/FastStyleTransActivity.kt

Lines changed: 119 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,23 @@ import androidx.appcompat.app.AlertDialog
1515
import androidx.appcompat.app.AppCompatActivity
1616
import androidx.core.app.ActivityCompat
1717
import androidx.core.content.ContextCompat
18-
import com.engineer.ai.databinding.ActivityGanBinding
1918
import com.engineer.ai.databinding.ActivityTansStyleBinding
2019
import com.engineer.ai.util.AndroidAssetsFileUtil
21-
import com.engineer.ai.util.Utils
20+
import kotlinx.coroutines.Dispatchers
21+
import kotlinx.coroutines.GlobalScope
22+
import kotlinx.coroutines.launch
2223
import org.pytorch.IValue
2324
import org.pytorch.LiteModuleLoader
2425
import org.pytorch.Module
2526
import org.pytorch.Tensor
26-
import java.util.Random
27+
import org.pytorch.torchvision.TensorImageUtils
2728

2829

2930
class FastStyleTransActivity : AppCompatActivity() {
3031
private val TAG = "FastStyleTransActivity"
3132
private lateinit var module: Module
32-
private val modelName = "dcgan.pt"
33+
private val modelName = "mosaic.pt"
34+
private var currentBitmap: Bitmap? = null
3335

3436
private lateinit var viewBinding: ActivityTansStyleBinding
3537
override fun onCreate(savedInstanceState: Bundle?) {
@@ -41,44 +43,119 @@ class FastStyleTransActivity : AppCompatActivity() {
4143
pickImage()
4244
}
4345
viewBinding.gen.setOnClickListener {
44-
genImage()
46+
GlobalScope.launch(Dispatchers.IO) {
47+
genImage()
48+
}
4549
}
4650
}
4751

4852
private fun showBitmap(bitmap: Bitmap) {
4953
viewBinding.pickResult.setImageBitmap(bitmap)
54+
currentBitmap = bitmap
5055
}
5156

52-
private fun genImage() {
53-
val zDim = intArrayOf(1, 100)
54-
val outDims = intArrayOf(64, 64, 3)
55-
Log.d(TAG, zDim.contentToString())
56-
val z = FloatArray(zDim[0] * zDim[1])
57-
Log.d(TAG, "z = ${z.contentToString()}")
58-
val rand = Random()
59-
// 生成高斯随机数
60-
for (c in 0 until zDim[0] * zDim[1]) {
61-
z[c] = rand.nextGaussian().toFloat()
57+
fun multiplyTensorBy255(inputTensor: Tensor): Tensor {
58+
// 获取 Tensor 的浮点数组
59+
val inputArray = inputTensor.dataAsFloatArray
60+
61+
// 创建新数组并乘以255
62+
val outputArray = FloatArray(inputArray.size) { i ->
63+
inputArray[i] * 255.0f
64+
}
65+
66+
// 创建新的 Tensor(保持原始形状)
67+
return Tensor.fromBlob(outputArray, inputTensor.shape())
68+
}
69+
70+
fun divTensorBy255(inputTensor: Tensor): Tensor {
71+
// 获取 Tensor 的浮点数组
72+
val inputArray = inputTensor.dataAsFloatArray
73+
74+
// 创建新数组并乘以255
75+
val outputArray = FloatArray(inputArray.size) { i ->
76+
inputArray[i] * 255.0f
6277
}
63-
Log.d(TAG, "z = ${z.contentToString()}")
64-
val shape = longArrayOf(1, 100)
65-
val tensor = Tensor.fromBlob(z, shape)
66-
Log.d(TAG, tensor.dataAsFloatArray.contentToString())
67-
val resultT = module.forward(IValue.from(tensor)).toTensor()
68-
val resultArray = resultT.dataAsFloatArray
69-
val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
70-
var index = 0
71-
// 根据输出的一维数组,解析生成的卡通图像
72-
for (j in 0 until outDims[2]) {
73-
for (k in 0 until outDims[0]) {
74-
for (m in 0 until outDims[1]) {
75-
resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
76-
index++
78+
79+
// 创建新的 Tensor(保持原始形状)
80+
return Tensor.fromBlob(outputArray, inputTensor.shape())
81+
}
82+
83+
private fun genImage() {
84+
85+
val inDims: IntArray = intArrayOf(224, 224, 3)
86+
val outDims: IntArray = intArrayOf(224, 224, 3)
87+
val bmp: Bitmap? = null
88+
var scaledBmp: Bitmap? = null
89+
val filePath = ""
90+
currentBitmap?.let {
91+
scaledBmp = Bitmap.createScaledBitmap(it, inDims[0], inDims[1], true);
92+
93+
94+
// Android更简洁的实现
95+
// 转换为张量并归一化到[0,1]
96+
val inputTensor: Tensor = TensorImageUtils.bitmapToFloat32Tensor(
97+
currentBitmap, floatArrayOf(0f, 0f, 0f), // 不减去均值
98+
floatArrayOf(1f, 1f, 1f) // 不除以标准差
99+
)
100+
101+
val tensor = multiplyTensorBy255(inputTensor);
102+
103+
Log.i(TAG, "1")
104+
105+
val resultTensor = module.forward(IValue.from(tensor)).toTensor()
106+
val out = divTensorBy255(resultTensor)
107+
108+
Log.i(TAG, "2")
109+
110+
val outputArray = out.dataAsFloatArray
111+
val width = outDims[0]
112+
val height = outDims[1]
113+
val outputBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
114+
115+
// 将浮点数组转换为Bitmap (简化实现,实际可能需要更复杂的转换)
116+
for (y in 0 until height) {
117+
for (x in 0 until width) {
118+
val r = (outputArray[y * width * 3 + x * 3 + 0] * 255).toInt().coerceIn(0, 255)
119+
val g = (outputArray[y * width * 3 + x * 3 + 1] * 255).toInt().coerceIn(0, 255)
120+
val b = (outputArray[y * width * 3 + x * 3 + 2] * 255).toInt().coerceIn(0, 255)
121+
outputBitmap.setPixel(x, y, android.graphics.Color.rgb(r, g, b))
77122
}
78123
}
124+
Log.i(TAG, "3")
125+
GlobalScope.launch(Dispatchers.Main) {
126+
viewBinding.transResult.setImageBitmap(outputBitmap)
127+
}
79128
}
80-
val bitmap = Utils.getBitmap(resultImg, outDims)
81-
viewBinding.transResult.setImageBitmap(bitmap)
129+
130+
// val zDim = intArrayOf(1, 100)
131+
// val outDims = intArrayOf(64, 64, 3)
132+
// Log.d(TAG, zDim.contentToString())
133+
// val z = FloatArray(zDim[0] * zDim[1])
134+
// Log.d(TAG, "z = ${z.contentToString()}")
135+
// val rand = Random()
136+
// // 生成高斯随机数
137+
// for (c in 0 until zDim[0] * zDim[1]) {
138+
// z[c] = rand.nextGaussian().toFloat()
139+
// }
140+
// Log.d(TAG, "z = ${z.contentToString()}")
141+
// val shape = longArrayOf(1, 100)
142+
// val tensor = Tensor.fromBlob(z, shape)
143+
// Log.d(TAG, tensor.dataAsFloatArray.contentToString())
144+
// val resultT = module.forward(IValue.from(tensor)).toTensor()
145+
// val resultArray = resultT.dataAsFloatArray
146+
// val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
147+
// var index = 0
148+
// // 根据输出的一维数组,解析生成的卡通图像
149+
// for (j in 0 until outDims[2]) {
150+
// for (k in 0 until outDims[0]) {
151+
// for (m in 0 until outDims[1]) {
152+
// resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
153+
// index++
154+
// }
155+
// }
156+
// }
157+
// val bitmap = Utils.getBitmap(resultImg, outDims)
158+
// viewBinding.transResult.setImageBitmap(bitmap)
82159
}
83160

84161
private fun initModel() {
@@ -96,15 +173,16 @@ class FastStyleTransActivity : AppCompatActivity() {
96173
}
97174
}
98175

99-
private val pickMedia = registerForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
100-
// 处理选择的图片
101-
uri?.let {
102-
val inputStream = contentResolver.openInputStream(uri)
103-
val bitmap = BitmapFactory.decodeStream(inputStream)
104-
inputStream?.close()
105-
showBitmap(bitmap)
176+
private val pickMedia =
177+
registerForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
178+
// 处理选择的图片
179+
uri?.let {
180+
val inputStream = contentResolver.openInputStream(uri)
181+
val bitmap = BitmapFactory.decodeStream(inputStream)
182+
inputStream?.close()
183+
showBitmap(bitmap)
184+
}
106185
}
107-
}
108186

109187
// 定义权限请求和图片选择启动器
110188
private val requestPermissionLauncher = registerForActivityResult(
@@ -158,7 +236,8 @@ class FastStyleTransActivity : AppCompatActivity() {
158236
this, permission
159237
) -> {
160238
// 解释为什么需要权限
161-
AlertDialog.Builder(this).setTitle("需要权限").setMessage("需要存储权限才能从相册选择图片")
239+
AlertDialog.Builder(this).setTitle("需要权限")
240+
.setMessage("需要存储权限才能从相册选择图片")
162241
.setPositiveButton("确定") { _, _ ->
163242
requestPermissionLauncher.launch(permission)
164243
}.setNegativeButton("取消", null).show()

0 commit comments

Comments
 (0)