|
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 | from dataclasses import dataclass
|
7 |
| -from typing import Optional, Tuple, Union |
| 7 | +from typing import Any, Dict, List, Optional, Tuple, Union |
8 | 8 |
|
9 | 9 | import torch
|
10 | 10 | from torch.utils._python_dispatch import (
|
|
26 | 26 | from torchao.utils import _is_float8_type, fill_defaults
|
27 | 27 |
|
28 | 28 | aten = torch.ops.aten
|
# Registry mapping an aten op overload to the handler function that
# ``__torch_dispatch__`` looks up and calls for that op.
FLOAT8_IMPL_OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops: Union[Any, List[Any]]):
    """Register aten ops to the float8 op table.

    Args:
        aten_ops: The op overloads handled by the decorated function. Either
            an iterable of ops (the usual form) or a single op on its own.

    Returns:
        A decorator that stores the decorated function in
        ``FLOAT8_IMPL_OPS_TABLE`` under every given op and returns the
        function unchanged. Registering an op twice overwrites the earlier
        handler.
    """
    # Ops are callables while the containers holding them are not, so a bare
    # op can be told apart from a list/tuple/set of ops without ambiguity.
    if callable(aten_ops):
        aten_ops = [aten_ops]

    def decorator(func):
        for op in aten_ops:
            FLOAT8_IMPL_OPS_TABLE[op] = func
        return func

    return decorator
29 | 41 |
|
30 | 42 |
|
31 | 43 | def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
|
32 | 44 | # Special handling for transposed attribute
|
33 | 45 | transposed_match = (self.transposed == src.transposed) or (
|
34 | 46 | self.transposed is False and src.transposed is None
|
35 | 47 | )
|
36 |
| - |
37 | 48 | return (
|
38 | 49 | isinstance(self, Float8AQTTensorImpl)
|
39 | 50 | and isinstance(src, Float8AQTTensorImpl)
|
@@ -160,90 +171,23 @@ def __tensor_unflatten__(
|
160 | 171 | def __torch_dispatch__(cls, func, types, args, kwargs):
|
161 | 172 | kwargs = {} if kwargs is None else kwargs
|
162 | 173 |
|
163 |
| - if func is aten.detach.default: |
164 |
| - return return_and_correct_aliasing( |
165 |
| - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
166 |
| - ) |
167 |
| - elif func is aten.clone.default: |
168 |
| - return return_and_correct_aliasing( |
169 |
| - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
170 |
| - ) |
171 |
| - elif func is aten.t.default: |
172 |
| - """we don't need to repack the weight and just rely on external |
173 |
| - shape being changed and record the status of transpose/no-transpose |
174 |
| - """ |
175 |
| - args[0].transposed = not args[0].transposed |
176 |
| - return return_and_correct_aliasing(func, args, kwargs, args[0]) |
177 |
| - elif func is aten.copy_.default: |
178 |
| - self = args[0] |
179 |
| - src = args[1] |
180 |
| - if _same_metadata(self, src): |
181 |
| - self_tensors = self.__tensor_flatten__()[0] |
182 |
| - for tensor_name in self_tensors: |
183 |
| - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) |
184 |
| - return |
185 |
| - raise ValueError( |
186 |
| - f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" |
187 |
| - ) |
188 |
| - elif func in [aten.select.int, aten.index.Tensor]: |
189 |
| - return return_and_correct_aliasing( |
190 |
| - func, |
191 |
| - args, |
192 |
| - kwargs, |
193 |
| - args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), |
194 |
| - ) |
195 |
| - elif func is aten.slice.Tensor: |
196 |
| - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) |
197 |
| - if dim == 0: |
198 |
| - # TODO: scale replecation should be dependent on block size |
199 |
| - if self.scale.ndim == 1: |
200 |
| - return return_and_correct_aliasing( |
201 |
| - func, |
202 |
| - args, |
203 |
| - kwargs, |
204 |
| - args[0]._apply_fn_to_data( |
205 |
| - lambda x: aten.slice.Tensor(x, dim, start, end, step) |
206 |
| - ), |
207 |
| - ) |
208 |
| - elif self.scale.ndim == 0: |
209 |
| - return return_and_correct_aliasing( |
210 |
| - func, |
211 |
| - args, |
212 |
| - kwargs, |
213 |
| - Float8AQTTensorImpl( |
214 |
| - aten.slice.Tensor(self.float8_data, dim, start, end, step), |
215 |
| - self.scale, |
216 |
| - None, |
217 |
| - self._layout, |
218 |
| - ), |
219 |
| - ) |
220 |
| - else: |
221 |
| - raise NotImplementedError( |
222 |
| - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" |
223 |
| - ) |
224 |
| - elif dim == 1: |
225 |
| - return return_and_correct_aliasing( |
226 |
| - func, |
227 |
| - args, |
228 |
| - kwargs, |
229 |
| - Float8AQTTensorImpl( |
230 |
| - aten.slice.Tensor( |
231 |
| - self.float8_data, dim, start, end, step |
232 |
| - ).contiguous(), |
233 |
| - self.scale, |
234 |
| - None, |
235 |
| - self._layout, |
236 |
| - ), |
237 |
| - ) |
238 |
| - else: |
239 |
| - raise NotImplementedError( |
240 |
| - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" |
| 174 | + def allowed_subclasses(type): |
| 175 | + return ( |
| 176 | + issubclass(cls, type) |
| 177 | + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) |
| 178 | + or issubclass( |
| 179 | + torch._subclasses.functional_tensor.FunctionalTensor, type |
241 | 180 | )
|
242 |
| - else: |
243 |
| - raise NotImplementedError( |
244 |
| - f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" |
245 | 181 | )
|
246 | 182 |
|
| 183 | + if not all(allowed_subclasses(t) for t in types): |
| 184 | + return NotImplemented |
| 185 | + |
| 186 | + if func in FLOAT8_IMPL_OPS_TABLE: |
| 187 | + return FLOAT8_IMPL_OPS_TABLE[func](func, types, args, kwargs) |
| 188 | + |
| 189 | + raise NotImplementedError(f"attempting to run {func}, this is not supported") |
| 190 | + |
247 | 191 | __torch_function__ = torch._C._disabled_torch_function_impl
|
248 | 192 |
|
249 | 193 | def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
@@ -281,6 +225,100 @@ def __repr__(self):
|
281 | 225 | )
|
282 | 226 |
|
283 | 227 |
|
| 228 | +########################## |
| 229 | +# Register FP8 Ops |
| 230 | +########################## |
| 231 | + |
| 232 | + |
@implements([aten.detach.default, aten.alias.default, aten.clone.default])
def _(func, types, args, kwargs):
    """Run ``func`` on the first argument's data through ``_apply_fn_to_data``.

    The result is passed through ``return_and_correct_aliasing`` together
    with the original ``func``, ``args`` and ``kwargs``.
    """
    tensor_impl = args[0]
    result = tensor_impl._apply_fn_to_data(func)
    return return_and_correct_aliasing(func, args, kwargs, result)
| 238 | + |
| 239 | + |
@implements([aten.t.default])
def _(func, types, args, kwargs):
    """Handle transpose by toggling the ``transposed`` flag.

    The weight is not repacked: the externally visible shape changes and the
    transpose/no-transpose state is recorded on the tensor impl.
    """
    tensor_impl = args[0]
    # NOTE(review): the flag is flipped on the input object itself, so the
    # caller's tensor impl is modified in place -- confirm this is intended.
    tensor_impl.transposed = not tensor_impl.transposed
    return return_and_correct_aliasing(func, args, kwargs, tensor_impl)
| 247 | + |
| 248 | + |
@implements([aten.copy_.default])
def _(func, types, args, kwargs):
    """Copy every inner tensor of ``args[1]`` into ``args[0]`` in place.

    Raises:
        ValueError: If ``_same_metadata`` reports that the two arguments do
            not match.
    """
    dst, src = args[0], args[1]
    if not _same_metadata(dst, src):
        raise ValueError(
            f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
        )

    # The first element of __tensor_flatten__() holds the attribute names of
    # the inner tensors; each one is copied attribute-by-attribute.
    inner_tensor_names = dst.__tensor_flatten__()[0]
    for attr_name in inner_tensor_names:
        target = getattr(dst, attr_name)
        target.copy_(getattr(src, attr_name))
    return None
| 261 | + |
| 262 | + |
@implements([aten.select.int, aten.index.Tensor])
def _(func, types, args, kwargs):
    """Apply ``func`` with the remaining arguments to the tensor impl's data.

    ``_apply_fn_to_data`` is given a callable that invokes ``func`` on one
    inner tensor with ``args[1:]`` and ``kwargs``.
    """
    tensor_impl = args[0]
    op_args = args[1:]

    def run_op(inner):
        return func(inner, *op_args, **kwargs)

    indexed = tensor_impl._apply_fn_to_data(run_op)
    return return_and_correct_aliasing(func, args, kwargs, indexed)
| 271 | + |
| 272 | + |
@implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
    """Slice a ``Float8AQTTensorImpl`` along dim 0 or dim 1.

    Missing positional arguments default to ``dim=0``, ``start=None``,
    ``end=None`` and ``step=1``.

    * ``dim == 0`` with a single-element scale: only ``float8_data`` is
      sliced and the scale is reused.
    * ``dim == 0`` with a 2-D scale: the slice is applied through
      ``_apply_fn_to_data``.
    * ``dim == 1``: only ``float8_data`` is sliced (and made contiguous); the
      scale is reused unchanged.

    Raises:
        NotImplementedError: For any other ``dim``, or for ``dim == 0`` with
            a scale that is neither single-element nor 2-D.
    """
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    if dim == 0:
        if self.scale.numel() == 1:
            # Per Tensor
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                Float8AQTTensorImpl(
                    aten.slice.Tensor(self.float8_data, dim, start, end, step),
                    self.scale,
                    self.transposed,
                    self._layout,
                ),
            )
        elif self.scale.ndim == 2:
            # TODO: scale replication should be dependent on block size
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0]._apply_fn_to_data(
                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
                ),
            )
        else:
            # Report the scale's rank; ``dim`` is always 0 in this branch.
            raise NotImplementedError(
                f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={self.scale.ndim}, that is not supported"
            )
    elif dim == 1:
        return return_and_correct_aliasing(
            func,
            args,
            kwargs,
            Float8AQTTensorImpl(
                aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(),
                self.scale,
                self.transposed,
                self._layout,
            ),
        )
    else:
        raise NotImplementedError(
            f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
        )
| 320 | + |
| 321 | + |
284 | 322 | ##########################
|
285 | 323 | # Float8 Dispatch Kernels
|
286 | 324 | ##########################
|
@@ -333,13 +371,12 @@ def _linear_fp8_act_fp8_weight_impl(
|
333 | 371 | input_scale = input_tensor.tensor_impl.scale
|
334 | 372 | # Handle case where input tensor is more than 2D
|
335 | 373 | inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
|
336 |
| - |
337 | 374 | # Handle rowwise case
|
338 | 375 | if _is_rowwise_scaled(weight_tensor):
|
339 | 376 | assert _is_rowwise_scaled(input_tensor), (
|
340 | 377 | "Input tensor must be rowwise block size"
|
341 | 378 | )
|
342 |
| - w_scale = w_scale.unsqueeze(-1).T |
| 379 | + w_scale = w_scale.T |
343 | 380 | input_scale = preprocess_scale(input_scale, input_tensor.shape)
|
344 | 381 |
|
345 | 382 | # Preprocess data
|
|
0 commit comments