-
Notifications
You must be signed in to change notification settings - Fork 393
Open
Labels
Description
Hi,
I can't train the FLUX model in float8 using torchao. The sanity check (inference) completes successfully: I am able to run a 25-step inference in float8 and generate a good-quality image. However, when it comes to training / the backward pass, the code gives me an error that I can't fix. Could you please identify the issue here?
A small code script regarding how the model is defined:
def LoadPipeline(dtype = torch.bfloat16):
pipeline = FluxKontextPipeline.from_pretrained("/home/cropy/flux_kontext",
local_files_only=True,
# quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16)
pipeline.to("cuda")
# quantize_(
# pipeline.transformer,
# float8_dynamic_activation_float8_weight(),
# )
quantize_(
pipeline.vae,
float8_dynamic_activation_float8_weight(),
)
quantize_(
pipeline.text_encoder,
float8_dynamic_activation_float8_weight(),
)
quantize_(
pipeline.text_encoder_2,
float8_dynamic_activation_float8_weight(),
)
pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
pipeline.text_encoder = torch.compile(pipeline.text_encoder, mode="max-autotune", fullgraph=True)
pipeline.text_encoder_2 = torch.compile(pipeline.text_encoder_2, mode="max-autotune", fullgraph=True)
return pipeline
----- My Model -----
self.pipeline = LoadPipeline(self.target_dtype)
self.model = self.pipeline.transformer.to(self.target_dtype)
train_fp8_config = Float8LinearConfig.from_recipe_name("rowwise") # tried tensorwise as well
convert_to_float8_training(self.model, config=train_fp8_config, module_filter_fn=module_filter_fn)

The error:
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] failed while attempting to run meta for aten._scaled_mm.default
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] Traceback (most recent call last):
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] r = func(*args, **kwargs)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] ^^^^^^^^^^^^^^^^^^^^^
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] return self._op(*args, **kwargs)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] ^^^^^^^^^^^^^^^^^^^^^^^^^
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_meta_registrations.py", line 6448, in meta_scaled_mm
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] torch._check(
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1684, in _check
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] _check_with(RuntimeError, cond, message)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1666, in _check_with
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] raise error_type(message_evaluated)
E1002 02:15:43.590000 278392 site-packages/torch/_subclasses/fake_tensor.py:2721] [2/1] RuntimeError: self must be row_major, got stride (1, 18432)
/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/graph.py:829: UserWarning: Error detected in matmul_with_hp_or_float8_argsBackward. Traceback of forward call that caused the error:
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 537, in forward
return torch.utils.checkpoint.checkpoint(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 481, in _forward
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/normalization.py", line 168, in forward
emb = self.linear(self.silu(emb))
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_linear.py", line 264, in forward
output = matmul_with_hp_or_float8_args.apply(
(Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/main.py", line 155, in <module>
trainer.fit(model, train_dataloader, val_dataloader)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
call._call_and_handle_interrupt(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
self.fit_loop.run()
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
self.advance()
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
self.advance(data_fetcher)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
self._optimizer_step(batch_idx, closure)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
call._call_lightning_module_hook(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 154, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py", line 76, in optimizer_step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 123, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/optim/optimizer.py", line 516, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/bitsandbytes/optim/optimizer.py", line 272, in step
loss = closure()
^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 109, in _wrap_closure
closure_result = closure()
^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
self._result = self.closure(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
step_output = self._step_fn()
^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 391, in training_step
return self.lightning_module.training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/models/models.py", line 71, in training_step
loss = self.flow.loss_fn(self.apply_model, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/modules/diffusion/flows.py", line 46, in loss_fn
pred = model_fn(xt, t=t, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/c/Users/user_1/Desktop/rectifiedflow_main_image_editing/rectifiedflow_main_image_editing_flux/models/models.py", line 440, in apply_model
v_pred = m(hidden_states=hidden_states, timestep=t, return_dict=False, **kwargs2)[0].clone()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 802, in forward
encoder_hidden_states, hidden_states = block(
^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 2380, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1262, in aot_dispatch_autograd
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 318, in aot_dispatch_autograd_graph
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
fx_g = make_fx(
^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2250, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2221, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1254, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 703, in flatten_fn
tree_out = root_fn(*tree_args)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1312, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 720, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 671, in joint_helper
return _functionalized_f_helper(primals, tangents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 419, in _functionalized_f_helper
f_outs = fn(*f_args)
^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 286, in inner_fn_with_anomaly
return inner_fn(*args)
^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 271, in inner_fn
backward_out = torch.autograd.grad(
^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/__init__.py", line 452, in grad
return handle_torch_function(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/overrides.py", line 1725, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1360, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/__init__.py", line 503, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/autograd/function.py", line 311, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_linear.py", line 197, in backward
grad_weight = torch.mm(
^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_training_tensor.py", line 374, in __torch_dispatch__
return FLOAT8_OPS_TABLE[func](func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_ops.py", line 385, in float8_mm
tensor_out = addmm_float8_unwrapped(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torchao/float8/float8_ops.py", line 72, in addmm_float8_unwrapped
output = torch._scaled_mm(
^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py", line 511, in __torch_dispatch__
outs_unwrapped = func._op_dk(
^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1462, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 914, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/_meta_registrations.py", line 6448, in meta_scaled_mm
torch._check(
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1684, in _check
_check_with(RuntimeError, cond, message)
File "/home/user_1/miniconda3/envs/deneme/lib/python3.11/site-packages/torch/__init__.py", line 1666, in _check_with
raise error_type(message_evaluated)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: self must be row_major, got stride (1, 18432)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
I tried to make every input tensor contiguous for nn.Linear layers using the code below, but it didn't solve the issue.
class SafeFloat8Linear(nn.Module):
def __init__(self, float8_linear):
super().__init__()
self.inner = float8_linear
def forward(self, x, *args, **kwargs):
return self.inner(x.contiguous(), *args, **kwargs)
def wrap_float8_linears(module: nn.Module):
"""
Recursively replace all nn.Linear modules with SafeFloat8Linear-wrapped versions.
"""
for name, child in list(module.named_children()):
# If it's a Linear (or a specific float8 Linear class), wrap it
if isinstance(child, nn.Linear):
setattr(module, name, SafeFloat8Linear(child))
else:
# Recurse into children
wrap_float8_linears(child)
return module