11# SPDX-License-Identifier: MIT
22# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33
4+ from aiter .jit .utils .torch_guard import torch_compile_guard
45import torch
56from torch import Tensor
6- from typing import Optional
7+ from typing import Optional , Tuple
78from ..jit .core import compile_ops
89import torch .nn .functional as F
910import functools
@@ -181,13 +182,14 @@ def raise_NotImplementedError(*a, **k):
181182 return tmp .get (qType , raise_NotImplementedError )
182183
183184
185+ @torch_compile_guard ()
184186def per_token_quant_hip (
185- x ,
186- scale = None ,
187- quant_dtype = dtypes .i8 ,
188- num_rows : Optional [torch . tensor ] = None ,
189- num_rows_factor = 1 ,
190- ):
187+ x : Tensor ,
188+ scale : Optional [ Tensor ] = None ,
189+ quant_dtype : torch . dtype = dtypes .i8 ,
190+ num_rows : Optional [Tensor ] = None ,
191+ num_rows_factor : int = 1 ,
192+ ) -> Tuple [ Tensor , Tensor ] :
191193 shape = x .shape
192194 device = x .device
193195 if scale is None :
@@ -213,15 +215,16 @@ def per_token_quant_hip(
213215 return y , scale
214216
215217
218+ @torch_compile_guard ()
216219def per_group_quant_hip (
217- x ,
218- scale = None ,
219- quant_dtype = dtypes .i8 ,
220- group_size = 128 ,
221- transpose_scale = False ,
222- num_rows : Optional [torch .tensor ] = None ,
223- num_rows_factor = 1 ,
224- ):
220+ x : Tensor ,
221+ scale : Optional [ Tensor ] = None ,
222+ quant_dtype : torch . dtype = dtypes .i8 ,
223+ group_size : int = 128 ,
224+ transpose_scale : bool = False ,
225+ num_rows : Optional [torch .Tensor ] = None ,
226+ num_rows_factor : int = 1 ,
227+ ) -> Tuple [ Tensor , Tensor ] :
225228 shape = x .shape
226229 device = x .device
227230 if scale is None :
@@ -252,7 +255,7 @@ def per_1x32_f4_quant_hip(
252255 scale = None ,
253256 quant_dtype = dtypes .fp4x2 ,
254257 shuffle = False ,
255- num_rows : Optional [torch .tensor ] = None ,
258+ num_rows : Optional [torch .Tensor ] = None ,
256259 num_rows_factor = 1 ,
257260):
258261 m , n = x .shape
@@ -302,7 +305,7 @@ def per_tensor_quant_hip(
302305 x ,
303306 scale = None ,
304307 quant_dtype = dtypes .i8 ,
305- num_rows : Optional [torch .tensor ] = None ,
308+ num_rows : Optional [torch .Tensor ] = None ,
306309 num_rows_factor = 1 ,
307310):
308311 assert num_rows is None , "num_rows is not supported for per_tensor_quant_hip"