Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions mlx/backend/metal/kernels/fp_quantized.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ METAL_FUNC void fp_qmm_t_impl(
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& K_eff,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
Expand Down Expand Up @@ -695,7 +696,7 @@ METAL_FUNC void fp_qmm_t_impl(

if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
Expand All @@ -705,7 +706,7 @@ METAL_FUNC void fp_qmm_t_impl(
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
Expand All @@ -717,7 +718,7 @@ METAL_FUNC void fp_qmm_t_impl(
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
Expand All @@ -727,7 +728,7 @@ METAL_FUNC void fp_qmm_t_impl(
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
Expand Down Expand Up @@ -1219,7 +1220,7 @@ template <
tid);
}
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
}

template <
Expand Down Expand Up @@ -1486,7 +1487,61 @@ template <
s_strides,
tid);
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
}

// Split-K entry point for the fp-quantized transposed matmul.
//
// Each slice of the grid along z handles one partition of the reduction
// dimension: the input pointers are advanced to the first K index of the
// partition (tid.z * k_partition_size) and the output pointer is advanced by
// tid.z * split_k_partition_stride before delegating to fp_qmm_t_impl.
// The full K is forwarded unchanged, while k_partition_size is forwarded as
// the impl's K_eff argument, which bounds its k-loop.
//
// NOTE(review): nothing here clamps the last partition, so this presumably
// relies on the host choosing k_partition_size so that every partition lies
// fully inside K and starts on a group_size / pack boundary -- confirm
// against the dispatch code. The per-partition outputs in y are presumably
// combined by a separate reduction step -- confirm.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void fp_qmm_t_splitk(
    const device uint32_t* w [[buffer(0)]],
    const device uint8_t* scales [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& K [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    const constant int& k_partition_size [[buffer(7)]],
    const constant int& split_k_partition_stride [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int pack_factor = get_pack_factor<8, bits>();
  constexpr int bytes_per_pack = get_bytes_per_pack();

  // Threadgroup staging tiles handed to the shared implementation.
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  // First K index covered by this z-slice.
  const int k_offset = tid.z * k_partition_size;

  // The packed weights are stepped in bytes: pack_factor elements occupy
  // bytes_per_pack bytes, so the element offset is rescaled accordingly.
  const device uint8_t* w_bytes = (const device uint8_t*)w;
  w_bytes += k_offset * bytes_per_pack / pack_factor;
  w = (const device uint32_t*)w_bytes;

  // One scale per group_size elements along K.
  scales += k_offset / group_size;
  x += k_offset;

  // Widen before multiplying so the output offset is computed in 64 bits.
  y += tid.z * static_cast<int64_t>(split_k_partition_stride);

  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w,
      scales,
      x,
      y,
      Xs,
      Ws,
      K,
      N,
      M,
      k_partition_size,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

template <
Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/metal/kernels/fp_quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@

#define instantiate_quantized_all_splitk(type, mode, group_size, bits) \
instantiate_quantized_split_k(mode, qvm_split_k, type, 8, group_size, bits) \
instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits)
instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits) \
instantiate_quantized_aligned(mode, qmm_t_splitk, type, true, group_size, bits) \
instantiate_quantized_aligned(mode, qmm_t_splitk, type, false, group_size, bits)

#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true, mode, group_size, bits) \
Expand Down
100 changes: 94 additions & 6 deletions mlx/backend/metal/kernels/quantized.h
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,7 @@ METAL_FUNC void qmm_t_impl(
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& K_eff,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
Expand Down Expand Up @@ -1156,7 +1157,7 @@ METAL_FUNC void qmm_t_impl(

if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
Expand All @@ -1166,7 +1167,7 @@ METAL_FUNC void qmm_t_impl(
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
Expand All @@ -1178,7 +1179,7 @@ METAL_FUNC void qmm_t_impl(
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
Expand All @@ -1188,7 +1189,7 @@ METAL_FUNC void qmm_t_impl(
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
for (int k = 0; k < K_eff; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
Expand Down Expand Up @@ -1759,7 +1760,80 @@ template <
tid);
}
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w,
scales,
biases,
x,
y,
Xs,
Ws,
K,
N,
M,
K,
tid,
lid,
simd_gid,
simd_lid);
}

// Split-K entry point for the affine-quantized transposed matmul.
//
// Each slice of the grid along z handles one partition of the reduction
// dimension: x, the packed weights, scales and biases are advanced to the
// first K index of the partition (tid.z * k_partition_size), and y is
// advanced by tid.z * split_k_partition_stride, before delegating to
// qmm_t_impl. The full K is forwarded unchanged, while k_partition_size is
// forwarded as the impl's K_eff argument, which bounds its k-loop.
//
// NOTE(review): nothing here clamps the last partition, so this presumably
// relies on the host choosing k_partition_size so that every partition lies
// fully inside K and starts on a group_size / pack boundary -- confirm
// against the dispatch code. The per-partition outputs in y are presumably
// combined by a separate reduction step -- confirm.
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_qmm_t_splitk(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& k_partition_size [[buffer(8)]],
    const constant int& split_k_partition_stride [[buffer(9)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

  // Threadgroup staging tiles handed to the shared implementation.
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  // First K index covered by this z-slice.
  const int k_offset = tid.z * k_partition_size;

  // The packed weights are stepped in bytes: pack_factor elements occupy
  // bytes_per_pack bytes, so the element offset is rescaled accordingly.
  const device uint8_t* w_bytes = (const device uint8_t*)w;
  w_bytes += k_offset * bytes_per_pack / pack_factor;
  w = (const device uint32_t*)w_bytes;

  // Scales and biases share one entry per group_size elements along K.
  const int group_offset = k_offset / group_size;
  scales += group_offset;
  biases += group_offset;
  x += k_offset;

  // Widen before multiplying so the output offset is computed in 64 bits.
  y += tid.z * static_cast<int64_t>(split_k_partition_stride);

  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w,
      scales,
      biases,
      x,
      y,
      Xs,
      Ws,
      K,
      N,
      M,
      k_partition_size,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

template <
Expand Down Expand Up @@ -2073,7 +2147,21 @@ template <
b_strides,
tid);
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w,
scales,
biases,
x,
y,
Xs,
Ws,
K,
N,
M,
K,
tid,
lid,
simd_gid,
simd_lid);
}

template <
Expand Down
16 changes: 15 additions & 1 deletion mlx/backend/metal/kernels/quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,20 @@

#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32) \

// Instantiates the kernel template `name` via instantiate_kernel, passing
// (type, group_size, bits, aligned) as its arguments. The kernel name string
// is built by stringifying the macro arguments:
//   "<name>_<type>_gs_<group_size>_b_<bits>_alN_<aligned>"
// NOTE(review): no tile sizes are passed, so BM/BK/BN presumably fall back to
// the template defaults -- confirm against instantiate_kernel. Host-side
// lookup must build exactly this name string.
#define instantiate_quantized_splitk_qmm(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)

// Emits both variants of affine_qmm_t_splitk for a given (type, group_size,
// bits): one with the last macro argument `true` and one with `false`, which
// instantiate_quantized_splitk_qmm forwards as the `aligned` argument.
#define instantiate_quantized_all_splitk_qmm(type, group_size, bits) \
instantiate_quantized_splitk_qmm(affine_qmm_t_splitk, type, group_size, bits, true) \
instantiate_quantized_splitk_qmm(affine_qmm_t_splitk, type, group_size, bits, false)

#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
Expand All @@ -121,6 +134,7 @@
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_splitk_qmm(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
Expand Down
Loading