[Bugfix][ROCm] Fix ROCm FP8 Quantization Padding Issue #18606

Open · vllmellm wants to merge 1 commit into base: main

Conversation

@vllmellm (Contributor) commented on May 23, 2025

Problem

When running Qwen2.5-VL-7B-Instruct with dynamic FP8 quantization on ROCm, the following error occurs:

INFO 05-22 03:06:17 [__init__.py:248] Automatically detected platform rocm.
INFO 05-22 03:06:20 [__init__.py:30] Available plugins for group vllm.general_plugins:
INFO 05-22 03:06:20 [__init__.py:32] name=lora_filesystem_resolver, value=vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 05-22 03:06:20 [__init__.py:34] all available plugins for group vllm.general_plugins will be loaded.
INFO 05-22 03:06:20 [__init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load.
INFO 05-22 03:06:20 [__init__.py:44] plugin lora_filesystem_resolver loaded.
INFO 05-22 03:06:22 [api_server.py:1289] vLLM API server version 0.9.1.dev17+g3b17ea26e.d20250521
INFO 05-22 03:06:22 [cli_args.py:300] non-default args: {'host': '0.0.0.0', 'port': 19999, 'quantization': 'fp8', 'tensor_parallel_size': 2}
INFO 05-22 03:06:39 [config.py:787] This model supports multiple tasks: {'generate', 'score', 'classify', 'reward', 'embed'}. Defaulting to 'generate'.
WARNING 05-22 03:06:39 [arg_utils.py:1595] Detected VLLM_USE_V1=1 with rocm. Usage should be considered experimental. Please report any issues on Github.
INFO 05-22 03:06:39 [config.py:1869] Defaulting to use mp for distributed inference
INFO 05-22 03:06:39 [config.py:2112] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 05-22 03:06:52 [__init__.py:248] Automatically detected platform rocm.
INFO 05-22 03:06:54 [core.py:427] Waiting for init message from front-end.
INFO 05-22 03:06:54 [core.py:61] Initializing a V1 LLM engine (v0.9.1.dev17+g3b17ea26e.d20250521) with config: model='Qwen/Qwen2.5-VL-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-VL-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=128000, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2.5-VL-7B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level": 3, "custom_ops": ["none"], "splitting_ops": ["vllm.unified_attention", "vllm.unified_attention_with_output"], "compile_sizes": [], "inductor_compile_config": {"enable_auto_functionalized_v2": false}, "use_cudagraph": true, "cudagraph_num_of_warmups": 1, "cudagraph_capture_sizes": [512, 504, 496, 488, 480, 472, 464, 456, 448, 440, 432, 424, 416, 408, 400, 392, 384, 376, 368, 360, 352, 344, 336, 328, 320, 312, 304, 296, 288, 280, 272, 264, 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4, 2, 1], "max_capture_size": 512}
WARNING 05-22 03:06:54 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 192 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 05-22 03:06:54 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 10485760, 10, 'psm_21a94d62'), local_subscribe_addr='ipc:///tmp/47d1bf25-8a71-47f9-b888-8a017f4f5e9d', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 05-22 03:07:06 [__init__.py:248] Automatically detected platform rocm.
INFO 05-22 03:07:06 [__init__.py:248] Automatically detected platform rocm.
INFO 05-22 03:07:09 [__init__.py:30] Available plugins for group vllm.general_plugins:
INFO 05-22 03:07:09 [__init__.py:32] name=lora_filesystem_resolver, value=vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 05-22 03:07:09 [__init__.py:34] all available plugins for group vllm.general_plugins will be loaded.
INFO 05-22 03:07:09 [__init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load.
INFO 05-22 03:07:09 [__init__.py:44] plugin lora_filesystem_resolver loaded.
INFO 05-22 03:07:09 [__init__.py:30] Available plugins for group vllm.general_plugins:
INFO 05-22 03:07:09 [__init__.py:32] name=lora_filesystem_resolver, value=vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 05-22 03:07:09 [__init__.py:34] all available plugins for group vllm.general_plugins will be loaded.
INFO 05-22 03:07:09 [__init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load.
INFO 05-22 03:07:09 [__init__.py:44] plugin lora_filesystem_resolver loaded.
WARNING 05-22 03:07:09 [utils.py:2664] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7feddfe09db0>
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:09 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_fc320423'), local_subscribe_addr='ipc:///tmp/88841b13-543b-4ea3-8f5f-c7d60f9416a9', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 05-22 03:07:09 [utils.py:2664] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7fea6a601de0>
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:09 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_6fd5988b'), local_subscribe_addr='ipc:///tmp/7950fce1-9cb6-4ff9-8872-82060608855d', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:09 [utils.py:1071] Found nccl from library librccl.so.1
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:09 [pynccl.py:69] vLLM is using nccl==2.22.3
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:09 [utils.py:1071] Found nccl from library librccl.so.1
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:09 [pynccl.py:69] vLLM is using nccl==2.22.3
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:11 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_2f38ec55'), local_subscribe_addr='ipc:///tmp/afcceb26-fb89-44d5-a4ba-4af74b08c92a', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:11 [parallel_state.py:1079] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:11 [parallel_state.py:1079] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:29 [gpu_model_runner.py:1503] Starting to load model Qwen/Qwen2.5-VL-7B-Instruct...
(VllmWorker rank=0 pid=72391) WARNING 05-22 03:07:29 [rocm.py:298] Model architecture 'Qwen2ForCausalLM' is partially supported by ROCm: Sliding window attention (SWA) is not yet supported in Triton flash attention. For half-precision SWA support, please use CK flash attention by setting `VLLM_USE_TRITON_FLASH_ATTN=0`
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:29 [rocm.py:184] Using Triton Attention backend on V1 engine.
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:29 [backends.py:37] Using InductorAdaptor
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:29 [weight_utils.py:291] Using model weights format ['*.safetensors']
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:34 [gpu_model_runner.py:1503] Starting to load model Qwen/Qwen2.5-VL-7B-Instruct...
(VllmWorker rank=1 pid=72392) WARNING 05-22 03:07:35 [rocm.py:298] Model architecture 'Qwen2ForCausalLM' is partially supported by ROCm: Sliding window attention (SWA) is not yet supported in Triton flash attention. For half-precision SWA support, please use CK flash attention by setting `VLLM_USE_TRITON_FLASH_ATTN=0`
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:35 [rocm.py:184] Using Triton Attention backend on V1 engine.
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:35 [backends.py:37] Using InductorAdaptor
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:35 [default_loader.py:279] Loading weights took 5.17 seconds
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:35 [weight_utils.py:291] Using model weights format ['*.safetensors']
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:35 [gpu_model_runner.py:1521] Model loading took 8.1172 GiB and 5.852300 seconds
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:40 [default_loader.py:279] Loading weights took 5.17 seconds
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:40 [gpu_model_runner.py:1521] Model loading took 8.1172 GiB and 5.602293 seconds
(VllmWorker rank=1 pid=72392) INFO 05-22 03:07:48 [gpu_model_runner.py:1823] Encoder cache will be initialized with a budget of 98304 tokens, and profiled with 1 video items of the maximum feature size.
(VllmWorker rank=0 pid=72391) INFO 05-22 03:07:48 [gpu_model_runner.py:1823] Encoder cache will be initialized with a budget of 98304 tokens, and profiled with 1 video items of the maximum feature size.
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522] WorkerProc hit an exception.
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522] Traceback (most recent call last):
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/executor/multiproc_executor.py", line 517, in worker_busy_loop
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output = func(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return func(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/worker/gpu_worker.py", line 185, in determine_available_memory
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     self.model_runner.profile_run()
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/worker/gpu_model_runner.py", line 1843, in profile_run
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     dummy_encoder_outputs = self.model.get_multimodal_embeddings(
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1045, in get_multimodal_embeddings
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     video_embeddings = self._process_video_input(multimodal_input)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 997, in _process_video_input
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     video_embeds = self.visual(pixel_values_videos,
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 728, in forward
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     hidden_states = blk(
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 401, in forward
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     x = x + self.mlp(self.norm2(x))
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 191, in forward
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     x_gate, _ = self.gate_proj(x)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/linear.py", line 485, in forward
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output_parallel = self.quant_method.apply(self, input_, bias)
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/fp8.py", line 417, in apply
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self.fp8_linear.apply(input=x,
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 418, in apply
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return w8a8_scaled_mm_func(qinput=qinput,
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 207, in rocm_per_tensor_w8a8_scaled_mm
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output = torch._scaled_mm(qinput,
(VllmWorker rank=1 pid=72392) ERROR 05-22 03:07:53 [multiproc_executor.py:522] RuntimeError: mat2 shape (1280x1710) must be divisible by 16
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522] WorkerProc hit an exception.
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522] Traceback (most recent call last):
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/executor/multiproc_executor.py", line 517, in worker_busy_loop
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output = func(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return func(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/worker/gpu_worker.py", line 185, in determine_available_memory
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     self.model_runner.profile_run()
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/v1/worker/gpu_model_runner.py", line 1843, in profile_run
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     dummy_encoder_outputs = self.model.get_multimodal_embeddings(
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1045, in get_multimodal_embeddings
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     video_embeddings = self._process_video_input(multimodal_input)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 997, in _process_video_input
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     video_embeds = self.visual(pixel_values_videos,
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 728, in forward
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     hidden_states = blk(
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 401, in forward
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     x = x + self.mlp(self.norm2(x))
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 191, in forward
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     x_gate, _ = self.gate_proj(x)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return forward_call(*args, **kwargs)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/linear.py", line 485, in forward
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output_parallel = self.quant_method.apply(self, input_, bias)
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/fp8.py", line 417, in apply
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return self.fp8_linear.apply(input=x,
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 418, in apply
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     return w8a8_scaled_mm_func(qinput=qinput,
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]   File "/app/lmcache/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 207, in rocm_per_tensor_w8a8_scaled_mm
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522]     output = torch._scaled_mm(qinput,
(VllmWorker rank=0 pid=72391) ERROR 05-22 03:07:53 [multiproc_executor.py:522] RuntimeError: mat2 shape (1280x1710) must be divisible by 16
ERROR 05-22 03:07:53 [core.py:489] EngineCore failed to start.
ERROR 05-22 03:07:53 [core.py:489] Traceback (most recent call last):
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/engine/core.py", line 480, in run_engine_core
ERROR 05-22 03:07:53 [core.py:489]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/engine/core.py", line 379, in __init__
ERROR 05-22 03:07:53 [core.py:489]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/engine/core.py", line 74, in __init__
ERROR 05-22 03:07:53 [core.py:489]     self._initialize_kv_caches(vllm_config)
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/engine/core.py", line 133, in _initialize_kv_caches
ERROR 05-22 03:07:53 [core.py:489]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/executor/abstract.py", line 75, in determine_available_memory
ERROR 05-22 03:07:53 [core.py:489]     output = self.collective_rpc("determine_available_memory")
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/executor/multiproc_executor.py", line 215, in collective_rpc
ERROR 05-22 03:07:53 [core.py:489]     result = get_response(w, dequeue_timeout)
ERROR 05-22 03:07:53 [core.py:489]   File "/app/lmcache/vllm/vllm/v1/executor/multiproc_executor.py", line 202, in get_response
ERROR 05-22 03:07:53 [core.py:489]     raise RuntimeError(
ERROR 05-22 03:07:53 [core.py:489] RuntimeError: Worker failed with error 'mat2 shape (1280x1710) must be divisible by 16', please check the stack trace above for the root cause
ERROR 05-22 03:07:56 [multiproc_executor.py:135] Worker proc VllmWorker-1 died unexpectedly, shutting down executor.

This happens because torch._scaled_mm on ROCm requires the matrix dimensions to be divisible by 16, which the FP8 linear path does not currently handle.
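
For context, here is a minimal stand-alone sketch of the constraint, not the actual vLLM call path. The shapes are taken from the traceback above, and it assumes a ROCm PyTorch build where the FP8 dtype is torch.float8_e4m3fnuz and the private torch._scaled_mm API behaves as in the log:

```python
# Sketch only: reproduces the "must be divisible by 16" constraint from the log.
# Assumes a ROCm PyTorch build; dtype and shapes are illustrative.
import torch

m, k, n = 32, 1280, 1710  # 1710 matches the failing mat2 shape; it is not a multiple of 16

qinput = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fnuz)
# torch._scaled_mm expects mat2 in column-major layout, hence the transpose.
weight = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fnuz).t()
scale = torch.ones((), device="cuda")  # per-tensor scales

# Expected to raise: RuntimeError: mat2 shape (1280x1710) must be divisible by 16
out = torch._scaled_mm(qinput, weight, scale_a=scale, scale_b=scale,
                       out_dtype=torch.bfloat16)
```

With tensor_parallel_size=2, the 1710 here appears to correspond to half of the vision MLP's intermediate size, which is why this model hits the constraint while many language-only models do not.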

Solution

Added padding logic to ensure all tensors meet torch._scaled_mm's alignment requirements (see the sketch below):

  1. Input tensor padding: pad qinput at dimension 1 to be divisible by 16.
  2. Bias padding: pad the bias tensor at dimension 0 when present.
  3. Weight padding: use _maybe_per_tensor_padding to pad both dimensions of the weight tensor.
  4. Output trimming: remove the excess padding from the output to restore the correct dimensions.
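
A rough, self-contained sketch of this padding scheme follows. The helper names (_pad_to_multiple, padded_scaled_mm) are illustrative rather than the identifiers used in this PR, per-tensor scales are assumed, and it relies on torch.nn.functional.pad accepting the FP8 dtype on the target PyTorch build. Zero padding leaves the matmul result unchanged because the extra rows and columns contribute nothing:

```python
import torch
import torch.nn.functional as F


def _pad_to_multiple(t: torch.Tensor, dim: int, multiple: int = 16) -> torch.Tensor:
    """Zero-pad `t` along `dim` so that its size becomes a multiple of `multiple`."""
    pad = (-t.shape[dim]) % multiple
    if pad == 0:
        return t
    # F.pad's spec runs from the last dimension backwards: (last_lo, last_hi, ...).
    spec = [0, 0] * (t.ndim - dim - 1) + [0, pad]
    return F.pad(t, spec)


def padded_scaled_mm(qinput, weight, scale_a, scale_b, bias=None,
                     out_dtype=torch.bfloat16):
    """FP8 GEMM with padding to satisfy torch._scaled_mm's divisible-by-16 rule.

    `weight` is expected in (out_features, in_features) layout, as vLLM stores it.
    """
    n = weight.shape[0]
    qinput_p = _pad_to_multiple(qinput, dim=1)                           # pad K of the activation
    weight_p = _pad_to_multiple(_pad_to_multiple(weight, dim=0), dim=1)  # pad N and K of the weight
    bias_p = _pad_to_multiple(bias, dim=0) if bias is not None else None
    out = torch._scaled_mm(qinput_p, weight_p.t(), scale_a=scale_a,
                           scale_b=scale_b, bias=bias_p, out_dtype=out_dtype)
    # Trim the extra output columns introduced by padding the weight's N dimension.
    return out[:, :n]
```

The bias padding and output trimming only matter when the weight's N dimension actually needed padding; when every dimension is already a multiple of 16, the helpers are no-ops.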

Testing

The change has been validated using lm_eval on three models. The results are as follows:

| Model | Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|---|
| Qwen/Qwen3-30B-A3B | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8287 | ± 0.0104 |
| | gsm8k | 3 | strict-match | 5 | exact_match | 0.8931 | ± 0.0085 |
| Llama-3.3-70B-Instruct | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8234 | ± 0.0105 |
| | gsm8k | 3 | strict-match | 5 | exact_match | 0.7710 | ± 0.0116 |
| Qwen/Qwen2.5-VL-7B-Instruct | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.4124 | ± 0.0136 |
| | gsm8k | 3 | strict-match | 5 | exact_match | 0.7202 | ± 0.0124 |


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@DarkLight1337 (Member) left a comment


Do you think it's worth it to add a regression test for this?
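
One possible shape for such a regression test, purely as an illustration: the test name is invented, it exercises the illustrative padded_scaled_mm helper sketched in the description above rather than vLLM's actual w8a8 entry point, and the tolerances are arbitrary.

```python
import pytest
import torch


@pytest.mark.skipif(torch.version.hip is None,
                    reason="exercises the ROCm FP8 padding path")
def test_fp8_scaled_mm_shape_not_divisible_by_16():
    torch.manual_seed(0)
    m, k, n = 32, 1280, 1710                     # n is deliberately not a multiple of 16
    a = torch.randn(m, k, device="cuda")
    w = torch.randn(n, k, device="cuda")
    scale = torch.ones((), device="cuda")

    qa = a.to(torch.float8_e4m3fnuz)
    qw = w.to(torch.float8_e4m3fnuz)

    # padded_scaled_mm: the illustrative helper from the PR description sketch.
    out = padded_scaled_mm(qa, qw, scale, scale)             # must not raise
    ref = qa.to(torch.bfloat16) @ qw.to(torch.bfloat16).t()  # unpadded reference
    assert out.shape == (m, n)
    # Loose tolerances: FP8 GEMM accumulation differs from the bf16 reference.
    torch.testing.assert_close(out, ref, atol=1.0, rtol=0.1)
```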

@@ -384,7 +414,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 weight: torch.Tensor,
 weight_scale: torch.Tensor,
 input_scale: Optional[torch.Tensor] = None
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
Member

This change looks unnecessary. Won't it fail pre-commit?
