diff --git a/setup.py b/setup.py index 5560ab877e..88669e7b3b 100644 --- a/setup.py +++ b/setup.py @@ -385,20 +385,29 @@ def get_extensions(): extra_compile_args["cxx"].extend( ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"] ) - if ( - use_cpu_kernels - and is_linux - and hasattr(torch._C._cpu, "_is_avx512_supported") - and torch._C._cpu._is_avx512_supported() - ): - extra_compile_args["cxx"].extend( - [ - "-DCPU_CAPABILITY_AVX512", - "-march=native", - "-mfma", - "-fopenmp", - ] - ) + + if use_cpu_kernels and is_linux: + if ( + hasattr(torch._C._cpu, "_is_avx512_supported") + and torch._C._cpu._is_avx512_supported() + ): + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512", + "-march=native", + "-mfma", + "-fopenmp", + ] + ) + if ( + hasattr(torch._C._cpu, "_is_avx512_vnni_supported") + and torch._C._cpu._is_avx512_vnni_supported() + ): + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512_VNNI", + ] + ) if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 0435a6c59b..2bb20d5afd 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -29,6 +29,7 @@ AffineQuantizedTensor, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, PlainLayout, QDQLayout, TensorCoreTiledLayout, @@ -70,6 +71,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, @@ -695,6 +697,72 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("sym_quant_a", [True, False]) + def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): + if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8: + # not supported until PT 2.8 + return + device = "cpu" + m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) + m2 = copy.deepcopy(m) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout + # is that the former packs two int4 weights into one int8, while the latter does not. 
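For reference, the nibble-packing idea described in the comment above can be sketched in plain PyTorch as follows. This is an illustrative sketch only: pack_uint4_pairs/unpack_uint4_pairs are hypothetical helper names, and the real da8w4 prepack kernel additionally reorders the weight into [Nc, Kc, block_k, block_n/2] blocks (and into VNNI order when cpublas packing is available).

    import torch

    def pack_uint4_pairs(w: torch.Tensor) -> torch.Tensor:
        # w holds uint4 values (0..15) in a uint8 tensor with an even last dim.
        # Even columns land in the low nibble, odd columns in the high nibble,
        # which matches the unpacking done by the scalar fallback in da8w4_linear.cpp.
        assert w.dtype == torch.uint8 and w.size(-1) % 2 == 0
        return (w[..., 0::2] & 0x0F) | (w[..., 1::2] << 4)

    def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
        lo = packed & 0x0F
        hi = (packed >> 4) & 0x0F
        return torch.stack([lo, hi], dim=-1).flatten(-2)

    w = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
    assert torch.equal(unpack_uint4_pairs(pack_uint4_pairs(w)), w)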
+ quantize_( + m, + Int8DynamicActivationInt4WeightConfig( + group_size=32, + layout=Int8DynamicActInt4WeightCPULayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] + quantize_( + m2, + int8_dynamic_activation_int4_weight( + group_size=32, + layout=PlainLayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + atol, rtol = 4e-7, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-2, 3e-3 + elif dtype == torch.half: + atol, rtol = 6e-3, 2e-3 + assert torch.allclose(y, y2, atol=atol, rtol=rtol) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) + # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp new file mode 100644 index 0000000000..537aa0fce9 --- /dev/null +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -0,0 +1,745 @@ +#include +#include +#include +#include + +namespace torchao { + +namespace { + +#define BLOCK_N 32 + +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; + +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; + } + cpublas_can_pack = at::native::cpublas::could_pack(at::kByte); + cpublas_checked = true; + return cpublas_can_pack; +} + +/* +return: packed_weight, packed_scales, packed_qzeros, compensation +*/ +std::tuple +da8w4_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales, + const at::Tensor& qzeros) { + // weight shape = [N, K] + // scales shape = [N, G] + // qzeros shape = [N, G] + TORCH_CHECK(weight.dim() == 2, + "DA8W4 CPU: Weight should be a 2D tensor for packing"); + TORCH_CHECK(weight.size(1) % 2 == 0, + "DA8W4 CPU: Weight should have even number of columns for packing"); + + auto new_scales = scales; + auto new_qzeros = qzeros; + if (new_scales.dim() == 1) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + if (new_qzeros.dim() == 1) { + new_qzeros.unsqueeze_(1); + } + new_qzeros = new_qzeros.to(at::kChar); + int N = weight.size(0); + int K = weight.size(1); + int G = scales.size(1); + int group_size = K / G; + int block_k = group_size > 128 ? 
128 : group_size; + constexpr int block_n = BLOCK_N; + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales/qzeros to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + at::Tensor blocked_qzeros = new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + // Compensation = Σ(k)(W[k][n] - ZP[n]) for each block. + auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) - new_qzeros.view({Nc, block_n, G, -1}); + weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, block_k}); + at::Tensor compensation = weight_sub_qzero.sum(-1); + compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt); + + if (cpublas_could_pack()) { + blocked_weight = at::empty({Nc, Kc, block_k, block_n / 2}, weight.options()); + auto weight_ptr = weight_reordered.data_ptr(); + auto blocked_weight_ptr = blocked_weight.data_ptr(); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n / 2; + + // Reorder weight block to VNNI4 and pack two lanes along N + // N=16 viewed as two lanes: a0, ...a7, b0, ...b7 + // pack two lanes: [a0, b0], ..., [a7, b7] + // plain shape = [block_k, block_n] + // packed shape = [block_k / 4, block_n / 2, 4] viewed as [block_k, block_n / 2] + constexpr int n_group_size = 8; + constexpr int vnni_size = 4; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; nb += 2) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n; + int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size + k * block_n / 2 + ki; + uint8_t src_1 = *(in_ptr + src_idx_1); + uint8_t src_2 = *(in_ptr + src_idx_2); + uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4); + *(out_ptr + dst_idx) = dst; + } + } + } + } + } + }); + } else { + // Pack weight: two int4 -> one int8 + using namespace at::indexing; + at::Tensor even_columns = + weight_reordered.index({Slice(), Slice(), Slice(), Slice(1, None, 2)}); + even_columns = even_columns.bitwise_left_shift(4); + at::Tensor odd_columns = + weight_reordered.index({Slice(), Slice(), Slice(), Slice(None, None, 2)}); + blocked_weight = even_columns.bitwise_or(odd_columns); + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales), std::move(blocked_qzeros), std::move(compensation)); +} + +template +struct ActDtype; +template<> +struct ActDtype { + using type = int8_t; +}; + +template<> +struct ActDtype { + using type = uint8_t; +}; + + +#if defined(CPU_CAPABILITY_AVX512) +inline std::array<__m256i, 2> load_zps_4vnni(const int8_t* __restrict__ zps) { + // broadcast 01234567 to + // 01234567012345670123456701234567 + __m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast(zps)); + __m256i vzps_high = _mm256_set1_epi64x(*reinterpret_cast(zps + 8)); + // shuffle from + // 01234567012345670123456701234567 + // to + // 00001111222233334444555566667777 + 
__m256i shuffle_mask = _mm256_set_epi8( + 7, + 7, + 7, + 7, + 6, + 6, + 6, + 6, + 5, + 5, + 5, + 5, + 4, + 4, + 4, + 4, + 3, + 3, + 3, + 3, + 2, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0); + vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask); + vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask); + return {vzps_low, vzps_high}; +} + +inline std::array<__m256i, 2> load_uint4_as_int8(const uint8_t* __restrict__ qB) { + __m256i packed = _mm256_loadu_si256(reinterpret_cast(qB)); + const __m256i low_mask = _mm256_set1_epi8(0x0f); + __m256i high = _mm256_srli_epi16(packed, 4); + high = _mm256_and_si256(high, low_mask); + __m256i low = _mm256_and_si256(packed, low_mask); + return {low, high}; +} + +template +void _dequant_weight_zp_only( + const uint8_t* __restrict__ B, + int8_t* dqB, + const int8_t* __restrict__ qzeros, + int64_t K) { + // unpack weight int8 -> two int4 + // subtract zero point + // B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4] + // dqB shape = [K, N], actual shape = [K / 4, N, 4] +#pragma GCC unroll 2 + for (int n = 0; n < N; n += 16) { + auto [zps_low, zps_high] = load_zps_4vnni(&qzeros[n]); + for (int k = 0; k < K; k += 4) { + auto [vb_low, vb_high] = load_uint4_as_int8(B + ldb * k + n / 2 * 4); + vb_high = _mm256_sub_epi8(vb_high, zps_high); + vb_low = _mm256_sub_epi8(vb_low, zps_low); + // store vb to B + _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4), vb_low); + _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high); + } + } +} + +template +void _dequant_and_store( + float* __restrict__ output, + const int32_t* __restrict__ input, + const float* __restrict__ scale_a, + const int32_t* __restrict__ zp_a, + const float* __restrict__ scale_b, + const int32_t* __restrict__ comp_b, + int M, + int ldi, + int ldo, + int ldsa = 1) { + for (int m = 0; m < M; ++m) { + float a_scale = *(scale_a + m * ldsa); + __m512 va_scale = _mm512_set1_ps(a_scale); + int32_t a_zp; + __m512i va_zp; + if constexpr (!sym_quant_a) { + a_zp = *(zp_a + m * ldsa); + va_zp = _mm512_set1_epi32(a_zp); + } + int n = 0; +#pragma GCC unroll 2 + for (; n < N; n += 16) { + __m512i vc = _mm512_loadu_si512(input + m * ldi + n); + if constexpr (!sym_quant_a) { + __m512i vb_comp = _mm512_loadu_si512(comp_b + n); + vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp)); + } + __m512 vc_f = _mm512_cvtepi32_ps(vc); + __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); + __m512 vb_s = _mm512_loadu_ps(scale_b + n); + vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s); + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f_mul); + } + } + for (; n < N; ++n) { + float dq_val; + if constexpr (sym_quant_a) { + dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n]; + } else { + dq_val = + (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale * scale_b[n]; + } + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +#else +template +void _dequant_weight_zp_only( + const uint8_t* B, + int8_t* dqB, + const int8_t* qzeros, + int64_t K) { + // B shape = [K, N / 2] + // dqB shape = [K, N] + for (int k = 0; k < K; ++k) { + for (int n = 0; n < N / 2; ++n) { + int32_t b = (int32_t)B[k * ldb + n]; + dqB[k * N + n * 2] = (b & 0xf) - qzeros[n]; + dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n]; + } + } +} +#endif + +#if 
defined(CPU_CAPABILITY_AVX512_VNNI) +inline __m512i combine_m256i(__m256i a, __m256i b) { + __m512i c = _mm512_castsi256_si512(a); + return _mm512_inserti64x4(c, b, 1); +} + +inline __m512i combine_m256i(std::array<__m256i, 2> two_256) { + return combine_m256i(two_256[0], two_256[1]); +} + +// negate elements in a according to b's sign +static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) { + __m512i zero = _mm512_setzero_si512(); + __mmask64 blt0 = _mm512_movepi8_mask(b); + return _mm512_mask_sub_epi8(a, blt0, zero, a); +} + +template +void _dequant_gemm_accum_small_M( + float* __restrict__ C, + const uint8_t* A, + const float* scales_a, + const int32_t* qzeros_a, + const uint8_t* B, + const float* scales_b, + const int8_t* qzeros_b, + int64_t K, + int64_t lda, + int64_t ldc) { + // if sym_quant_a is true, A pointer type is passed in as uint8_t* but actually int8_t*. + + constexpr int COLS = N / 16; + // Computing compensation is faster than loading it for small M + // because it's memory bound. + __m512i ones = _mm512_set1_epi8(1); // used for computing compensation + __m512i va; + __m512i vb[COLS]; + __m512i vc[M * COLS]; + __m512 vscales[COLS]; + __m512i vzps[COLS]; + __m512i vcompensate[COLS]; + + // Load scales and zps + c10::ForcedUnroll{}([&](auto i) { + vscales[i] = _mm512_loadu_ps(scales_b + i * 16); + vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16)); + if constexpr (!sym_quant_a) { + vcompensate[i] = _mm512_setzero_epi32(); + } + }); + c10::ForcedUnroll{}( + [&](auto i) { vc[i] = _mm512_setzero_epi32(); }); + + auto compute = [&](auto i, int k) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k)); + } + + if constexpr (row == 0) { + int B_offset = k * ldb + col * 16 * 2; + vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset)); + vb[col] = _mm512_sub_epi8(vb[col], vzps[col]); + if constexpr (!sym_quant_a) { + vcompensate[col] = + _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]); + } + _mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0); + } + if constexpr (sym_quant_a) { + auto vsb = _mm512_sign_epi8(vb[col], va); + auto vabsa = _mm512_sign_epi8(va, va); + vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb); + } else { + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + } + }; + + // Accumulate along k + constexpr const int unroll = 4; + int k = 0; + for (; k < K / 4 / unroll; k++) { + c10::ForcedUnroll{}([&](auto i) { + c10::ForcedUnroll{}(compute, 4 * (k * unroll + i)); + }); + } + k *= 4 * unroll; + for (; k < K; k += 4) { + c10::ForcedUnroll{}(compute, k); + } + + // Store to C + auto store = [&](auto i) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + // compute (qC - compensate * zp_a) * scale_a * scale_b + __m512 vc_float; + if constexpr (!sym_quant_a) { + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32( + vcompensate[col], _mm512_set1_epi32(*(qzeros_a + row)))); + } + vc_float = _mm512_cvtepi32_ps(vc[i]); + vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row))); + + vc_float = _mm512_mul_ps(vc_float, vscales[col]); + auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16); + vc_float = _mm512_add_ps(vc_float, vc_old); + _mm512_storeu_ps(C + row * ldc + col * 16, vc_float); + }; + c10::ForcedUnroll{}(store); + +} + +#define call_dequant_gemm_accum_small_M(M) \ + _dequant_gemm_accum_small_M( \ + C, \ + A, \ + scales_a, \ + qzeros_a, \ + B, \ + scales_b, \ + qzeros_b, \ + K, \ + lda, \ + 
ldc); +#endif + +template +void _dequant_gemm_accum( + float* C, + const uint8_t* A, + const float* scales_a, + const int32_t* qzeros_a, + const uint8_t* B, + const float* scales_b, + const int8_t* qzeros_b, + const int32_t* compensation, + int64_t M, + int64_t K, + int64_t lda, + int64_t ldc) { + // Compute GEMM int8 * int8 -> int32 + // dequant result to float by applying scales/qzeros +#if defined(CPU_CAPABILITY_AVX512_VNNI) + if (M <= 4 && cpublas_can_pack) { + switch (M) { + case 1: + call_dequant_gemm_accum_small_M(1); + return; + case 2: + call_dequant_gemm_accum_small_M(2); + return; + case 3: + call_dequant_gemm_accum_small_M(3); + return; + case 4: + call_dequant_gemm_accum_small_M(4); + return; + } + } +#endif + + int8_t dqB[K * N]; + _dequant_weight_zp_only(B, dqB, qzeros_b, K); + using Tin = typename ActDtype::type; + Tin* A_ptr = (Tin*)A; +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (cpublas_can_pack) { + int32_t C_i32[M * N]; + at::native::cpublas::brgemm( + M, + N, + K, + lda, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + A_ptr, + dqB, + C_i32, + true /* is_vnni */); + _mm_prefetch(B + N * K / 2, _MM_HINT_T0); + _mm_prefetch(A + K, _MM_HINT_T0); + _dequant_and_store( + C, + C_i32, + scales_a, + qzeros_a, + scales_b, + compensation, + M, + N /*ldi*/, + ldc, + 1 /*ldsa*/); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + if constexpr (sym_quant_a) { + sum += ((int32_t)A_ptr[i * lda + k] * dqB[k * N + j]); + } else { + sum += ((int32_t)A_ptr[i * lda + k] - qzeros_a[i]) * (int32_t)dqB[k * N + j]; + } + } + C[i * ldc + j] += sum * scales_a[i] * scales_b[j]; + } + } + } +} + +template +inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) { + if (bias_ptr) { + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); + _mm512_storeu_ps(y_buf + i * N + j, bias_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = bias_ptr[j]; + } + } + } else { // initialize to zero + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 zero_vec = _mm512_setzero_ps(); + _mm512_storeu_ps(y_buf + i * N + j, zero_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = 0; + } + } + } +} + +template +inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_t n, */ int64_t lda) { + for (int i = 0; i < m; ++i) { + int j = 0; + if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = y_buf[i * N + j]; + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]); + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_fp16_vec = 
at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]); + } + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + } +} + +template +void _da8w4_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& input_qzeros, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const at::Tensor& weight_qzeros, + const at::Tensor& compensation, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = input.view({-1, K}); + int64_t M = input_view.size(0); + TORCH_CHECK(input_scales.numel() == M, "DA8W4: unexpected input scales shape"); + TORCH_CHECK(input_scales.sizes() == input_qzeros.sizes(), "DA8W4: unexpected input qzeros shape"); + + // weight shape = [Nc, Kc, block_k, block_n/2] + // scales/qzeros shape = [Nc, G, block_n] + // compensation shape = [Nc, Kc, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + constexpr int64_t block_n = BLOCK_N; + TORCH_CHECK(weight.size(3) * 2 == block_n, "DA8W4: unexpected weight shape"); + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "DA8W4: weight and input shapes mismatch"); + int64_t block_m = [&]() -> long { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 48; + } else { + return 64; + } + }(); + int64_t Mc = (M + block_m - 1) / block_m; + bool parallel_on_M = M > 128; + int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; + + // scales/qzeros shape = [Nc, G, block_n] + int64_t num_groups = weight_scales.size(1); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "DA8W4 CPU: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + + using Tin = typename ActDtype::type; + const Tin* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const int32_t* a_qzeros_ptr = sym_quant_a ? nullptr : input_qzeros.data_ptr(); + const uint8_t* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr(); + const int32_t* compensation_ptr = sym_quant_a ? nullptr : compensation.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + int64_t mc = parallel_on_M ? i / Nc : 0; + int64_t nc = parallel_on_M ? i % Nc : i; + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + alignas(64) float y_buf[m_size][block_n]; + // copy bias to y_buf if bias is not None + auto bias_data = bias_ptr ? 
bias_ptr + nc * block_n : nullptr; + copy_bias(bias_data, y_buf[0], m_size); + for (int kci = 0; kci < Kc; ++kci) { + _dequant_gemm_accum( + y_buf[0] /*C*/, + (uint8_t*)a_ptr + mci * block_m * K + kci * block_k /*A*/, + a_scales_ptr + mci * block_m /*scales_a*/, + a_qzeros_ptr + mci * block_m /*qzeros_a*/, + b_ptr + (nc * Kc + kci) * block_n * block_k / 2 /*B*/, + b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*scales_b*/, + b_qzeros_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*qzeros_b*/, + compensation_ptr + nc * block_n * Kc + kci * block_n /*compensation*/, + m_size /*M*/, + block_k /*K*/, + K /*lda*/, + block_n /*ldc*/); + } + // store y_buf to output with dtype conversion + store_out( + y_buf[0], + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/); + } + } + if constexpr (cpublas_can_pack) { + at::native::cpublas::brgemm_release(); + } + }); +} + +at::Tensor da8w4_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& input_qzeros, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const at::Tensor& weight_qzeros, + const at::Tensor& compensation, + const std::optional& bias, + at::ScalarType output_dtype) { + static bool cpublas_can_pack = cpublas_could_pack(); + bool sym_quant_a = input.scalar_type() == c10::kChar; + auto out_sizes = input.sizes().vec(); + int64_t N = weight.size(0) * weight.size(-1) * 2; + out_sizes.back() = N; + auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); + +#define call__da8w4_linear_impl(cpublas_can_pack, sym_quant_act) \ + AT_DISPATCH_FLOATING_TYPES_AND2( \ + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { \ + _da8w4_linear_impl( \ + input, \ + input_scales, \ + input_qzeros, \ + weight, \ + weight_scales, \ + weight_qzeros, \ + compensation, \ + bias, \ + output); \ + }); + + if (cpublas_can_pack) { + if (sym_quant_a) { + call__da8w4_linear_impl(true, true); + } else { + call__da8w4_linear_impl(true, false); + } + } else { + if (sym_quant_a) { + call__da8w4_linear_impl(false, true); + } else { + call__da8w4_linear_impl(false, false); + } + } + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::da8w4_linear_prepack_cpu", &da8w4_linear_prepack_impl); + m.impl("torchao::da8w4_linear_cpu", &da8w4_linear_impl); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 581c3e4ecb..b0dde2cf10 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -20,6 +20,7 @@ CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -67,4 +68,5 @@ "FbgemmInt4Tensor", "to_fbgemm_fp8", "FbgemmFp8Tensor", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 02a2d3004a..8b028352e4 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -35,6 +35,10 @@ _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) +from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, +) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -247,6 
+251,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_bf16_act_uint4_weight_float_zero_check, _linear_bf16_act_uint4_weight_float_zero_impl, ), + ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index fee6141164..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -4,6 +4,9 @@ from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) +from .dyn_int8_act_int4_wei_cpu_layout import ( + Int8DynamicActInt4WeightCPULayout, +) from .int4_cpu_layout import ( Int4CPULayout, ) @@ -48,4 +51,5 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "QDQLayout", "Int4XPULayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py new file mode 100644 index 0000000000..ced7ec0dd8 --- /dev/null +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, +) + +from .int4_cpu_layout import ( + Int4CPUAQTTensorImpl, + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + 
self.compensation = compensation + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros, compensation = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + fn(self.compensation), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + args[0].compensation, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int32) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.int8) + + if self.scales.dim() == 2: + assert self.qzeros.dim() == 2 + plain_scales = self.scales + plain_qzeros = self.qzeros + else: + assert self.scales.dim() == 3 and self.qzeros.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + 
self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + plain_qzeros = ( + self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, plain_qzeros + + +def _aqt_is_uint8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 255 + ) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and aqt.quant_max == 127 + ) + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + TORCH_VERSION_AT_LEAST_2_7 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) + ) + + +def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert TORCH_VERSION_AT_LEAST_2_7, ( + f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" + ) + if _aqt_is_int8(input_tensor): + assert TORCH_VERSION_AT_LEAST_2_8, ( + f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + act_qzeros = act_mat.tensor_impl.zero_point + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + wei_qzeros = weight_tensor.tensor_impl.qzeros + compensation = weight_tensor.tensor_impl.compensation + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.da8w4_linear_cpu.default( + act.contiguous(), + act_scales, + act_qzeros, + packed_weight, + wei_scales, + wei_qzeros, + compensation, + bias.float() if bias is not None else bias, # requires bias to be float + orig_dtype, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index bf9446d265..da19bbc259 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -150,7 +150,7 @@ def to(self, *args, **kwargs): device = 
kwargs["device"] if not is_device(torch.device(self.device).type, device): raise ValueError( - f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" + f"{self.__class__.__name__} does not support conversion from {self.device} to {device}" ) return self.__class__( self.packed_weight.to(device), @@ -181,18 +181,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = Int4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scale_and_zero, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim in [0, 1]: @@ -217,11 +205,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, sliced) else: raise NotImplementedError( - f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported" ) raise NotImplementedError( - f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"{cls.__name__} dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/ops.py b/torchao/ops.py index cda3746624..babe5506c0 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -64,6 +64,12 @@ lib.define( "qscaled_dot_product(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float? scale=None, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor" ) +lib.define( + "da8w4_linear_prepack_cpu(Tensor weight, Tensor scales, Tensor qzeros) -> (Tensor, Tensor, Tensor, Tensor)" +) +lib.define( + "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor" +) def register_custom_op(name): @@ -1022,3 +1028,81 @@ def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): """Meta impl for mx_fp4_bf16""" # Assume that the contraction happens in the K dim thus M,N are perserved post bit pack return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device) + + +def da8w4_linear_prepack_cpu( + weight: Tensor, + scales: Tensor, + qzeros: Tensor, +) -> Tensor: + """ + Prepack weights for DA8W4 linear operator on CPU. + Args: + weight: weight tensor. + scales: scales for weight tensor. + qzeros: zero points for weight tensor. + Returns: + packed weight, scales, and zero points. 
+ """ + return torch.ops.torchao.da8w4_linear_prepack_cpu.default(weight, scales, qzeros) + + +@register_custom_op("torchao::da8w4_linear_prepack_cpu") +def _(weight: Tensor, scales: Tensor, qzeros: Tensor) -> Tensor: + return weight, scales, qzeros, torch.Tensor() + + +def da8w4_linear_cpu( + input: Tensor, + input_scales: Tensor, + input_qzeros: Tensor, + weight: Tensor, + weight_scales: Tensor, + weight_qzeros: Tensor, + compensation: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +): + """ + DA8W4 linear operator on CPU. + Args: + input: input tensor. + input_scales: scales for input tensor. + input_qzeros: zero points for input tensor. + weight: weight tensor. + weight_scales: scales for weight tensor. + weight_qzeros: zero points for weight tensor. + compensation: compensation tensor for weight. + bias: optional bias tensor. + out_dtype: output data type. + Returns: + output tensor in out_dtype. + """ + return torch.ops.torchao.da8w4_linear_cpu.default( + input, + input_scales, + input_qzeros, + weight, + weight_scales, + weight_qzeros, + compensation, + bias, + out_dtype, + ) + + +@register_custom_op("torchao::da8w4_linear_cpu") +def _( + input: Tensor, + input_scales: Tensor, + input_qzeros: Tensor, + weight: Tensor, + weight_scales: Tensor, + weight_qzeros: Tensor, + compensation: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +) -> Tensor: + assert weight.dim() == 4 + N = weight.size(0) * weight.size(3) * 2 + return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8b66ac84ce..7287ae2bc0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -35,6 +35,7 @@ Float8Layout, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, @@ -660,6 +661,38 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) +def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.uint8 + scale_dtype = torch.float32 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + quant_min = 0 + quant_max = 255 + if TORCH_VERSION_AT_LEAST_2_6: + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) + else: + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + return out + + def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -731,7 +764,10 @@ def _int8_dynamic_activation_int4_weight_transform( # input settings if act_mapping_type == MappingType.ASYMMETRIC: - input_quant_func = _int8_asymm_per_token_quant + if isinstance(layout, Int8DynamicActInt4WeightCPULayout): + input_quant_func = _uint8_asymm_per_token_quant + else: + input_quant_func = _int8_asymm_per_token_quant elif act_mapping_type == MappingType.SYMMETRIC: if isinstance(layout, MarlinQQQLayout): input_quant_func = _int8_symm_per_token_quant @@ -748,6 +784,16 @@ def _int8_dynamic_activation_int4_weight_transform( ) elif isinstance(layout, CutlassInt4PackedLayout): weight = _int4_symm_cutlass_quant(weight) + elif isinstance(layout, 
Int8DynamicActInt4WeightCPULayout): + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype=torch.uint8, + quant_min=0, + quant_max=15, + _layout=layout, + ) else: weight = to_affine_quantized_intx( weight,
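As background for the asymmetric-activation path wired up above (uint8 activations produced by _uint8_asymm_per_token_quant), the kernel removes the activation zero point from the integer GEMM using the compensation tensor produced at prepack time, i.e. the per-block sum of (weight - weight_zero_point). Below is a minimal sketch of the identity it relies on; scales are omitted because they factor out, and the names are illustrative rather than taken from the kernel.

    import torch

    torch.manual_seed(0)
    K = 8
    a_q = torch.randint(0, 256, (1, K))   # uint8 activation codes, per-token quantized
    z_a = torch.randint(0, 256, (1, 1))   # activation zero point for this token
    w_q = torch.randint(0, 16, (K, 1))    # uint4 weight codes for one output channel
    z_w = torch.randint(0, 16, (1, 1))    # weight zero point for this group

    # Reference: dequantize both operands first, then multiply.
    ref = (a_q - z_a) @ (w_q - z_w)

    # Kernel-style: integer GEMM on the raw activation codes against the
    # zero-point-subtracted weights, then subtract z_a times the precomputed
    # compensation (what da8w4_linear_prepack_cpu stores per K-block).
    compensation = (w_q - z_w).sum(dim=0, keepdim=True)
    fused = a_q @ (w_q - z_w) - z_a * compensation

    assert torch.equal(ref, fused)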