Add support for CB with native transformers #3471
Conversation
PR is merged to transformers!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I'm trying CB right now. Currently, when running (from here):

```python
import time

import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

torch.set_float32_matmul_precision("high")

model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

generation_config = GenerationConfig(
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=False,
    num_blocks=2048,  # paged KV cache: number of cache blocks
    block_size=128,   # paged KV cache: tokens per block
    do_sample=True,
    max_batch_tokens=1024,  # Maximum number of tokens to process in a single batch
    scheduler="prefill_first",
)

train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")

# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---")

def tokenize_function(examples):
    return tokenizer(examples["question"])

tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

start_time_simple = time.time()
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)
end_time_simple = time.time()

for request in batch_outputs:
    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
    try:
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
    except Exception as e:
        print(f"Decoding failed for request {request}: {e}")
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
    if len(output_text) > 0:
        print("-" * 20)
        print(f"{request} Input: {input_text}")
        print(f"{request} Output: {output_text}")
    else:
        print("", end="\r\r\r\r")
print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
```

I'm getting
Any idea why?
Yes! That's OpenTelemetry having issues with the cluster. cc @McPatate, let's fix this!
(TLDR if you do
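As a stopgap, here is a minimal sketch for silencing the telemetry errors, assuming the traces are emitted through the standard OpenTelemetry Python SDK (that is an assumption on my part, not something confirmed in this thread); the environment variables below are defined by the OpenTelemetry spec:

```python
import os

# Assumption: the CB telemetry goes through the standard OpenTelemetry SDK,
# which honors these spec-defined environment variables. Set them before
# importing transformers so the SDK sees them at initialization time.
os.environ["OTEL_SDK_DISABLED"] = "true"      # disable the OTel SDK entirely
os.environ["OTEL_TRACES_EXPORTER"] = "none"   # or: only drop the trace exporter

from transformers import AutoModelForCausalLM  # imported after the env vars are set
```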
What does this PR do?
Testing this with: