
[Feature]: Support for vLLM integration with GRPO finetuning #1958

@AdityaKulshrestha


🚀 The feature, motivation and pitch

Hi team,

Thank you for the amazing work.

I am working on GRPO finetuning with vLLM. I ran into errors while following the original Hugging Face documentation, so I suspect this integration is not currently supported with vLLM here.

I started the vLLM server through TRL: trl vllm-serve --model Qwen/Qwen2.5-7B --enforce-eager 1
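For context, a minimal sketch of the trainer side that connects to such a server might look like the following. This assumes a recent TRL release; the exact GRPOConfig field names (use_vllm, vllm_server_host, vllm_server_port) and defaults can differ between versions, and my_reward_fn / train_dataset are hypothetical placeholders, so treat this as a configuration fragment rather than a runnable script:

```python
# Hypothetical client-side setup for GRPO with an external vLLM server.
# Field names follow recent TRL releases but may differ by version.
from trl import GRPOConfig, GRPOTrainer

args = GRPOConfig(
    output_dir="qwen2.5-7b-grpo",
    use_vllm=True,                 # generate completions via the external vLLM server
    vllm_server_host="127.0.0.1",
    vllm_server_port=8000,         # default port used by `trl vllm-serve`
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B",
    args=args,
    reward_funcs=my_reward_fn,     # hypothetical reward function
    train_dataset=train_dataset,   # hypothetical dataset
)
trainer.train()  # after each optimizer step, updated weights are pushed to the server
```

With this setup, the weight push after backprop is exactly the step that fails in the traceback below.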

The server threw an error when trying to update the weights after backprop:

INFO:     127.0.0.1:52792 - "POST /update_named_param/ HTTP/1.1" 200 OK
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/trl/scripts/vllm_serve.py", line 303, in llm_worker
    result = method(*args, **kwargs)
  File "/root/aditya/vllm-fork/vllm/entrypoints/llm.py", line 503, in collective_rpc
    return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
  File "/root/aditya/vllm-fork/vllm/engine/llm_engine.py", line 2175, in collective_rpc
    return self.model_executor.collective_rpc(method, timeout, args,
  File "/root/aditya/vllm-fork/vllm/executor/uniproc_executor.py", line 59, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/root/aditya/vllm-fork/vllm/utils.py", line 2784, in run_method
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/trl/scripts/vllm_serve.py", line 136, in update_named_param
    self.model_runner.model.load_weights(weights=[(name, weight)])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1943, in __getattr__
    raise AttributeError(
AttributeError: 'HpuModelAdapter' object has no attribute 'load_weights'
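The error suggests that on the HPU backend the model is wrapped in an adapter object that does not forward the inner model's load_weights method, which TRL's update_named_param path calls. The classes below are pure stand-ins named after the traceback, not the real vLLM/HPU code; they just illustrate the failure mode and one possible delegation-based fix pattern:

```python
class TinyModel:
    """Hypothetical stand-in for a vLLM model exposing load_weights()."""
    def __init__(self):
        self.weights = {"linear.weight": [[0.0, 0.0], [0.0, 0.0]]}

    def load_weights(self, weights):
        # Mimics vLLM's per-tensor weight-loading interface: (name, tensor) pairs.
        for name, tensor in weights:
            self.weights[name] = tensor

class HpuModelAdapterLike:
    """Wrapper that, like the adapter in the traceback, does not forward
    the wrapped model's methods -- hence the AttributeError."""
    def __init__(self, model):
        self.model = model

class DelegatingAdapter(HpuModelAdapterLike):
    """One possible fix pattern: delegate unknown attributes to the wrapped model."""
    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so self.model
        # (found in the instance __dict__) does not recurse.
        return getattr(self.model, name)

broken = HpuModelAdapterLike(TinyModel())
try:
    broken.load_weights(weights=[])     # same failure shape as in the report
except AttributeError as e:
    print(type(e).__name__)             # AttributeError

fixed = DelegatingAdapter(TinyModel())
fixed.load_weights(weights=[("linear.weight", [[1.0, 1.0], [1.0, 1.0]])])
print(fixed.model.weights["linear.weight"])
```

Whether the real fix belongs in the HPU adapter, in vllm_serve.py (e.g., unwrapping before calling load_weights), or elsewhere is up to the maintainers; this only shows why the call fails.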

Alternatives

No response

Additional context

The original training script then failed with the chained error below; the client-side barrier broadcast received 0 bytes, presumably because the server-side worker had already crashed on the AttributeError above:

Traceback (most recent call last):
  File "/root/aditya/vllm-fork/vllm/distributed/utils.py", line 260, in barrier
    barrier_id = self.broadcast_obj(None, src=0)
  File "/root/aditya/vllm-fork/vllm/distributed/utils.py", line 216, in broadcast_obj
    recv_obj = pickle.loads(self.store.get(key))
torch.distributed.DistNetworkError: failed to recv, got 0 bytes

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/aditya/physics_wallah/finetuning/grpo.py", line 244, in <module>
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/transformers/trainer.py", line 613, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/transformers/trainer.py", line 1036, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/transformers/trainer.py", line 1632, in training_step
    inputs = self._prepare_inputs(inputs)
  File "/usr/local/lib/python3.10/dist-packages/trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/grpo_trainer.py", line 899, in _prepare_inputs
    accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/trl/trainer/grpo_trainer.py", line 493, in _generate_and_score_completions
    self._move_model_to_vllm()
  File "/usr/local/lib/python3.10/dist-packages/trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/grpo_trainer.py", line 872, in _move_model_to_vllm
    self.vllm_client.update_named_param(name, param.data)
  File "/usr/local/lib/python3.10/dist-packages/trl/extras/vllm_client.py", line 245, in update_named_param
    self.pynccl_comm.group.barrier()
  File "/root/aditya/vllm-fork/vllm/distributed/utils.py", line 262, in barrier
    raise RuntimeError("Failed to broadcast barrier_id") from e
RuntimeError: Failed to broadcast barrier_id

