-
I'm working on Local Diffusion, using stable-diffusion.cpp on Android. Vulkan performance on Mali GPUs is currently very poor. Disabling …
Questions:
Looking for guidance on improving Vulkan matmul performance on Mali without breaking correctness.
-
What are the warp size and shared memory size for this GPU? They should be printed out on startup. The first value of the warptile is the workgroup size. I'm surprised this broke things, unless the workgroup size is smaller than the warp size. Which one is currently faster, m_warptile or s_warptile?
-
Thanks for the response. Here are the warp size and shared memory size of the GPU:
I pretty much brute-forced all possible combinations while tuning. In my case, …
-
@0cc4m On my GTX 1060, I was able to reproduce the ops failure and the broken sd.cpp inference when switching to a workgroup size of 64, so it's not a Mali-specific issue. You should be able to reproduce it on your end as well. I'm not sure how to dig deeper into this; if you could help or take a look yourself, that would be greatly appreciated. I believe a properly tuned Vulkan backend on Mali could already match or even surpass CPU performance on sd.cpp.
-
You need to look into the meaning of the warptile parameters; they are not independent. I'll try to summarize what I remember:
The 11 parameters are: BLOCK_SIZE, BM, BN, BK, WM, WN, WMITER, TM, TN, TK and WARP.
They originate from this CUDA article; look at the section on kernel 10: https://siboehm.com/articles/22/CUDA-MMM
The diagram there is especially helpful.
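To keep the order of the 11 values straight, here is a minimal sketch of them as a named record. The field comments are my reading of the warptiling kernel in the linked article and are only assumed to carry over to the Vulkan shader. `Warptile` and `wniter` are hypothetical names, and the WNITER formula is the article's iteration-count formula with the M and N roles swapped, which is also an assumption.

```python
from typing import NamedTuple


class Warptile(NamedTuple):
    """The 11 warptile values, in the order listed in this thread.

    Field meanings follow the warptiling kernel (kernel 10) of the linked
    CUDA article; they are a reading of that article, not of the ggml source.
    """
    BLOCK_SIZE: int  # invocations per workgroup
    BM: int          # rows of the output tile per workgroup
    BN: int          # columns of the output tile per workgroup
    BK: int          # K-slice loaded into shared memory per outer step
    WM: int          # rows of the output tile per warp
    WN: int          # columns of the output tile per warp
    WMITER: int      # warp sub-iterations along M
    TM: int          # output rows per thread per sub-iteration
    TN: int          # output columns per thread per sub-iteration
    TK: int          # not covered in this thread; check the shader source
    WARP: int        # warp (subgroup) size the values were chosen for

    @classmethod
    def from_list(cls, values):
        """Build from a plain list given in the order above."""
        if len(values) != len(cls._fields):
            raise ValueError(f"expected {len(cls._fields)} values, got {len(values)}")
        return cls(*values)

    def wniter(self) -> int:
        """Warp sub-iterations along N (assumed derivation).

        A warp computes WM * WN outputs in WMITER * WNITER sub-tiles of
        WARP * TM * TN outputs each, so
        WNITER = WM * WN / (WARP * TM * TN * WMITER).
        """
        per_step = self.WARP * self.TM * self.TN * self.WMITER
        if (self.WM * self.WN) % per_step:
            raise ValueError("WM * WN is not divisible by WARP * TM * TN * WMITER")
        return (self.WM * self.WN) // per_step


# BLOCK_SIZE, BM, BN, WM, WN and WARP are the Nvidia values quoted in the thread;
# BK, WMITER, TM, TN and TK are hypothetical placeholders.
nv = Warptile.from_list([128, 64, 64, 16, 32, 32, 2, 2, 4, 1, 32])
print(nv.BLOCK_SIZE // nv.WARP, nv.wniter())  # 4 2
```

Using a named record rather than a bare 11-element list makes it harder to swap two neighbouring values by accident while brute-forcing configurations.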
For your problem: you need to make sure that the number of warps in the workgroup (BLOCK_SIZE / WARP) is identical to the number of warp tiles ((BM / WM) * (BN / WN)). For example, in the Nvidia case (warps of size 32), we have a workgroup of size BLOCK_SIZE=128, meaning 4 warps. BM=64, BN=64 and WM=32, WN=32 means we have 4 warp tiles. This is why it works.
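That rule can be written as a small config check. This is a minimal sketch, assuming BLOCK_SIZE counts invocations per workgroup and WARP is the subgroup size. `check_warptile` is a hypothetical helper, not part of ggml or stable-diffusion.cpp, and it checks only this rule plus basic divisibility, not every constraint the shader may impose.

```python
def check_warptile(block_size: int, bm: int, bn: int,
                   wm: int, wn: int, warp: int) -> list[str]:
    """Return a list of problems; an empty list means the values agree.

    Rule from the thread: the number of warps in a workgroup,
    BLOCK_SIZE / WARP, must equal the number of warp tiles,
    (BM / WM) * (BN / WN).
    """
    problems = []
    if block_size < warp:
        problems.append(f"BLOCK_SIZE={block_size} is smaller than WARP={warp}")
    elif block_size % warp:
        problems.append(f"BLOCK_SIZE={block_size} is not a multiple of WARP={warp}")
    if bm % wm or bn % wn:
        problems.append(f"{bm}x{bn} block tile does not split into {wm}x{wn} warp tiles")
    if problems:
        return problems
    warps = block_size // warp
    tiles = (bm // wm) * (bn // wn)
    if warps != tiles:
        problems.append(f"{warps} warps per workgroup but {tiles} warp tiles")
    return problems


# Nvidia values from the thread: 128 / 32 = 4 warps, (64 / 32) * (64 / 32) = 4 tiles.
print(check_warptile(128, 64, 64, 32, 32, 32))  # []
# Workgroup size 64, other values assumed unchanged: 2 warps vs 4 tiles.
print(check_warptile(64, 64, 64, 32, 32, 32))   # ['2 warps per workgroup but 4 warp tiles']
# Hypothetical WARP=16 with the Nvidia tile sizes: 8 warps vs 4 tiles.
print(check_warptile(128, 64, 64, 32, 32, 16))  # ['8 warps per workgroup but 4 warp tiles']
```

If the GTX 1060 experiment changed only the workgroup size (an assumption), the second call shows the mismatch the rule predicts.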
In your WARP=16 …