[Metal][Performance]: Add split-K for quantized matmul (small M) #3120

Ziqiao-git wants to merge 1 commit into ml-explore:main
Conversation
As a quick note on why these specific dimensions matter: the performance bottleneck for small …

Thanks, that is great! I'll take a look asap.
angeloskath left a comment:
This is great but unfortunately unfinished.
The fp quantizations are not implemented (it should be trivial to add them based on qmm_t_splitk_impl), and the qmv split-K is not used by anything, so it can be removed for starters, or you can finish the implementation if you want.
Finally, you don't need both a qmm_t_splitk_impl and a qmv_splitk_impl: the point of qmv_impl, qmv_fast_impl, and qmm_t_impl is that one can adjust the input matrix offsets and call the implementation. See qvm_splitk for an example.
Force-pushed e7c8390 to 884b86f
Addressed all feedback:

Build passes, pre-commit is clean, and the benchmark confirms split-K is working correctly for both the affine and fp paths.

Benchmark Results (M1/M2/M3 / applegpu_g15s)
Device: applegpu_g15s, Memory: 52 GB
qmv_batch_limit(D=4096, O=4096) = 12

```
  1    0.475ms    0.036ms    0.08x    gemv    qmv
```

(Tested across affine, mxfp8, and mxfp4; truncated for brevity, but all show similar smooth transitions.)

Let me know if there is anything I missed.
jagrit06 left a comment:
Requesting a few small changes, thanks for putting this together!
```cpp
// Choose split_k to target ~512 threadgroups
int bm = 32, bn = 32;
int n_tiles = (N + bn - 1) / bn;
int m_tiles = (M + bm - 1) / bm;
int current_tgs = n_tiles * m_tiles;
int split_k = std::max(1, 512 / current_tgs);

// Ensure K divides evenly by split_k * group_size
while (split_k > 1 && (K % (split_k * group_size) != 0)) {
  split_k--;
}
if (split_k <= 1) {
  return qmm(
      x, w, scales, biases, out, true, group_size, bits, M, N, K, d, s, mode);
}
```
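The heuristic above can be exercised in isolation. The following is a minimal standalone sketch of that selection logic, assuming the same fixed 32x32 tile sizes and the 512-threadgroup target from the diff; `choose_split_k` is a hypothetical name, not a function in the codebase:

```cpp
#include <algorithm>

// Sketch of the split_k selection heuristic: target ~512 threadgroups,
// then step split_k down until every K-partition is a whole multiple of
// the quantization group size.
int choose_split_k(int M, int N, int K, int group_size) {
  const int bm = 32, bn = 32;
  int n_tiles = (N + bn - 1) / bn;
  int m_tiles = (M + bm - 1) / bm;
  int current_tgs = n_tiles * m_tiles;
  int split_k = std::max(1, 512 / current_tgs);
  // Step down until K splits into chunks aligned to the quant groups.
  while (split_k > 1 && (K % (split_k * group_size) != 0)) {
    split_k--;
  }
  return split_k; // split_k <= 1 means: fall back to the regular qmm.
}
```

For example, M=12, N=2560, K=4096, group_size=64 gives 80 threadgroups, an initial split_k of 6, and an aligned final split_k of 4; a 1024x1024 problem already yields 1024 threadgroups, so split_k stays 1 and the regular kernel runs.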
Could you help me understand why this approach was chosen?

I can make some sense of 512 threadgroups being the target (if M or N is larger, we go to the regular matmul; if K is larger still, we just loop more inside the kernel), but I would like to know how you ended up at that number.

Also, we already employ checks such that K is always divisible by group_size, so wouldn't the looping step then just come down to split_k = std::min(split_k, K / group_size)?
For the 512 target, it was mostly an empirical choice, trying to balance GPU occupancy against atomic-reduction overhead. My mental math was that on the larger Apple Silicon chips (like the 76-core M2 Ultra), we generally need a handful of active threadgroups per core to effectively hide memory latency (roughly …
I originally avoided just using split_k = std::min(split_k, K / group_size) because we need to guarantee quantization-group alignment, not just an upper bound.

If we strictly use std::min, it might result in a split_k that doesn't divide K into multiples of group_size. For example, if K = 1024, group_size = 64, and the calculated split_k is 12, the chunk size 1024 / 12 does not align with 64 (the quantization group size). If a threadgroup's chunk starts unaligned with group_size, it will read misaligned scale and bias values, leading to incorrect numerical results.

The while loop ensures we step down until (K / split_k) % group_size == 0, so that every threadgroup boundary aligns perfectly with the quantization blocks. Let me know if you think there is a more elegant way to enforce this modulo alignment mathematically instead of with the loop!
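The K = 1024 example can be checked directly. This sketch reproduces just the step-down loop (under the same assumptions as the comment above; `align_split_k` is a hypothetical helper name):

```cpp
// Step split_k down until every K-chunk is a whole number of quant groups.
int align_split_k(int K, int group_size, int split_k) {
  while (split_k > 1 && (K % (split_k * group_size) != 0)) {
    split_k--;
  }
  return split_k;
}
```

With K = 1024, group_size = 64, and an initial split_k of 12, the loop walks 12, 11, 10, 9 (none of which give group-aligned chunks) and stops at 8, where each chunk is 1024 / 8 = 128 elements, exactly two quantization groups.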
Good point.

I think the way around this would be to support an unaligned K_eff for the last partition (that logic doesn't exist in the QMMs but does in the regular MM code). We certainly don't want a K that divides into 31 quantized groups to miss the split-K variant just because 31 is prime and we require the number of splits to perfectly divide K / group_size.

That said, we can merge this for now and do that in a follow-up PR.
Add a split-K variant for quantized matrix multiplication that partitions the K dimension across threadgroups when GPU occupancy is low (small M).

- Reuse qmm_t_impl with a K_eff parameter for the loop bound; pre-offset pointers in the split-K wrapper (following the qvm_splitk pattern)
- Remove unused qmv_split_k code
- Add fp quantization support (fp_qmm_t_splitk)
- Dynamic split_k selection targeting ~512 threadgroups
- Fall back to regular qmm when split_k <= 1
Force-pushed 884b86f to c83ba7f
```cpp
int bm = 32, bn = 32;
int n_tiles = (N + bn - 1) / bn;
int m_tiles = (M + bm - 1) / bm;
int current_tgs = n_tiles * m_tiles;
int split_k = std::max(1, 512 / current_tgs);
```
For now, just so we are safe, could you add a check here to make sure that K is large enough (relative to M and N) to warrant going through the loop of splits?

We naturally short-circuit if n_tiles * m_tiles >= 512, so large M and N are covered; it would be good to similarly short-circuit out if K isn't too large compared to M and N.

Until we have a fix for the loop, I would like to avoid, say, a 64x64x128 matmul going through ~100 iterations when it could have just short-circuited earlier.

After that, we should be ready to merge!
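One possible shape of the requested guard, purely as a hypothetical sketch (not the merged implementation; `worth_splitting` and the minimum-groups threshold of 4 are invented here for illustration):

```cpp
// Hypothetical pre-check: only attempt split-K when the M*N grid is small
// AND there are enough quantization groups along K to split meaningfully.
bool worth_splitting(int m_tiles, int n_tiles, int K, int group_size) {
  int k_groups = K / group_size;
  // Require a few K groups before entering the split_k step-down loop, so
  // a shallow-K problem (e.g. 64x64x128) bails out immediately.
  return (m_tiles * n_tiles < 512) && (k_groups >= 4);
}
```

With group_size = 64, a 64x64x128 matmul (2x2 tiles, only 2 K groups) is rejected up front, while the small-M case from the PR (M=12, N=2560, K=4096, i.e. 1x80 tiles and 64 K groups) still qualifies.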
Proposed changes

In issue #3086, it was observed that the quantized qmm kernel severely underutilizes the GPU for small M (e.g., M = 12-32). For example, a configuration of D = 2560 and M = 12 yields only 80 threadgroups (assuming BM = BN = 32), which is insufficient to saturate the GPU grid.

This PR introduces a split-K variant (qmm_t_splitk) that partitions the K dimension across multiple threadgroups. This safely improves GPU occupancy and execution speed for small-batch inference scenarios, while falling back to the standard kernel for larger batches to prevent any performance regression.

What changed

- New split-K kernel (qmm_t_splitk) in the Metal backend, conceptually similar to the existing fp16 steel_gemm_splitk.
- Updated quantized.cpp to dynamically calculate the split factor, targeting ~512 threadgroups for optimal occupancy.
- Falls back to the standard qmm kernel when split_k <= 1 (e.g., for large M).
- Benchmarks (group_size = 64):
  - D=2560, M=12: 0.079ms -> 0.055ms (~30% faster)
  - D=4096, M=16: 0.155ms -> 0.117ms (~25% faster)
  - No regression for large M configurations.

Checklist
Put an x in the boxes that apply.

- [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes