Skip to content

Commit 56c1891

Browse files
fix torch 2.9 deepseek-0528 error (ROCm#1577)
1 parent 6af8b68 commit 56c1891

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

aiter/ops/quant.py

Lines changed: 20 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33

4+
from aiter.jit.utils.torch_guard import torch_compile_guard
45
import torch
56
from torch import Tensor
6-
from typing import Optional
7+
from typing import Optional, Tuple
78
from ..jit.core import compile_ops
89
import torch.nn.functional as F
910
import functools
@@ -181,13 +182,14 @@ def raise_NotImplementedError(*a, **k):
181182
return tmp.get(qType, raise_NotImplementedError)
182183

183184

185+
@torch_compile_guard()
184186
def per_token_quant_hip(
185-
x,
186-
scale=None,
187-
quant_dtype=dtypes.i8,
188-
num_rows: Optional[torch.tensor] = None,
189-
num_rows_factor=1,
190-
):
187+
x: Tensor,
188+
scale: Optional[Tensor] = None,
189+
quant_dtype: torch.dtype = dtypes.i8,
190+
num_rows: Optional[Tensor] = None,
191+
num_rows_factor: int = 1,
192+
) -> Tuple[Tensor, Tensor]:
191193
shape = x.shape
192194
device = x.device
193195
if scale is None:
@@ -213,15 +215,16 @@ def per_token_quant_hip(
213215
return y, scale
214216

215217

218+
@torch_compile_guard()
216219
def per_group_quant_hip(
217-
x,
218-
scale=None,
219-
quant_dtype=dtypes.i8,
220-
group_size=128,
221-
transpose_scale=False,
222-
num_rows: Optional[torch.tensor] = None,
223-
num_rows_factor=1,
224-
):
220+
x: Tensor,
221+
scale: Optional[Tensor] = None,
222+
quant_dtype: torch.dtype = dtypes.i8,
223+
group_size: int = 128,
224+
transpose_scale: bool = False,
225+
num_rows: Optional[torch.Tensor] = None,
226+
num_rows_factor: int = 1,
227+
) -> Tuple[Tensor, Tensor]:
225228
shape = x.shape
226229
device = x.device
227230
if scale is None:
@@ -252,7 +255,7 @@ def per_1x32_f4_quant_hip(
252255
scale=None,
253256
quant_dtype=dtypes.fp4x2,
254257
shuffle=False,
255-
num_rows: Optional[torch.tensor] = None,
258+
num_rows: Optional[torch.Tensor] = None,
256259
num_rows_factor=1,
257260
):
258261
m, n = x.shape
@@ -302,7 +305,7 @@ def per_tensor_quant_hip(
302305
x,
303306
scale=None,
304307
quant_dtype=dtypes.i8,
305-
num_rows: Optional[torch.tensor] = None,
308+
num_rows: Optional[torch.Tensor] = None,
306309
num_rows_factor=1,
307310
):
308311
assert num_rows is None, "num_rows is not supported for per_tensor_quant_hip"

0 commit comments

Comments (0)