
Commit 378f062

Bump transformer version and fix Llava example (#12324)
A lot has changed in the Llava model definition between transformers 4.47 and 4.52. This PR:

* Changes the state dict key mapping to match the new Llava model definition in HF (the remapping pattern is sketched below).
* Uses the `processor.apply_chat_template()` API to get `input_id`s, so that we can be a bit more resilient to `input_id` format changes (see the input-pipeline sketch after the `model.py` diff).
1 parent aec1322 commit 378f062
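
For the first change: `_translate_state_dict_for_text_model` renames HF checkpoint keys into the example's own text-model naming via a regex key map. Below is a minimal sketch of that pattern, assuming a simple `re.sub` loop; the helper name and the truncated map are illustrative, not the exact ExecuTorch code.

```python
import re
from typing import Any, Dict


def remap_keys(state_dict: Dict[str, Any], key_map: Dict[str, str]) -> Dict[str, Any]:
    """Rename checkpoint keys by applying every regex substitution in key_map."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in key_map.items():
            new_key = re.sub(pattern, replacement, new_key)
        remapped[new_key] = value
    return remapped


# Excerpt of the updated mapping: the text decoder now lives under
# `model.language_model.*` instead of `model.layers.*` / `model.norm.*`.
key_map = {
    r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
    r"model.language_model.norm.": r"norm.",
    r"lm_head.": r"output.",
}
```

The full map in `model.py` covers the remaining attention, MLP, and norm projections in the same way.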

4 files changed (+58, −39 lines changed)

examples/models/llava/export_llava.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -186,7 +186,7 @@ def quant_embedding(model):
             packed=False,
         ).quantized_model()
 
-    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
+    quantized_token_embed = quant_embedding(llava.model_.model.language_model)
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():
```
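
The attribute path changes because newer transformers releases wrap the Llava decoder differently. If you want to confirm the layout for your installed version before exporting, a quick check is to print the module tree; this is a debugging aid, not part of the PR, and the expected names are an assumption about the post-4.52 layout.

```python
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", device_map="cpu"
)

# Top-level submodules of the wrapper.
print([name for name, _ in model.named_children()])

# On transformers >= 4.52 the inner `model` is expected to expose the vision
# tower, the multimodal projector, and `language_model`, which is why the
# example now quantizes `llava.model_.model.language_model`.
print([name for name, _ in model.model.named_children()])
```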

examples/models/llava/model.py

Lines changed: 53 additions & 35 deletions
```diff
@@ -31,7 +31,6 @@
 from transformers import (
     AutoProcessor,
     CLIPImageProcessor,
-    LlamaForCausalLM,
     LlavaForConditionalGeneration,
 )
 
@@ -104,19 +103,19 @@ def __init__(
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        state_dict = self.model_.language_model.state_dict()
+        state_dict = self.model_.state_dict()
         key_map = {
             # fmt: off
-            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
-            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
-            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
-            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
-            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
-            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
-            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
-            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
-            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
-            r"model.norm.": r"norm.",
+            r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
+            r"model.language_model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
+            r"model.language_model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
+            r"model.language_model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
+            r"model.language_model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
+            r"model.language_model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
+            r"model.language_model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
+            r"model.language_model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
+            r"model.language_model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
+            r"model.language_model.norm.": r"norm.",
             # r"model.embed_tokens.": r"tok_embeddings.", # load separately
             r"lm_head.": r"output.",
             # fmt: on
@@ -157,7 +156,7 @@ def get_model(self):
 
     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        return self.model_.language_model.model.embed_tokens(tokens)
+        return self.model_.language_model.embed_tokens(tokens)
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
@@ -289,13 +288,8 @@ def prefill_ref(
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
         # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
-        return LlamaForCausalLM.forward(
-            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-            self.model_.language_model,
-            inputs_embeds=embeds,
-            return_dict=False,
-            use_cache=False,
-            output_hidden_states=False,
+        return self.model_.forward(
+            inputs_embeds=embeds, use_cache=False, return_dict=False, logits_to_keep=1
         )
 
     def forward(
@@ -309,25 +303,42 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained(
-            "llava-hf/llava-1.5-7b-hf",
-            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
-        )
-        self.tokenizer = self.processor.tokenizer
-        self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
             revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
-        self.image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+            patch_size=self.model.vision_tower.config.patch_size,  # Required after transformers >= 4.52.0
         )
-        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
-What are the things I should be cautious about when I visit here? ASSISTANT:"""
+        self.tokenizer = self.processor.tokenizer
+        self.image_processor = self.processor.image_processor
+        self.image_url = "https://llava-vl.github.io/static/images/view.jpg"
+        self.image = Image.open(requests.get(self.image_url, stream=True).raw)
+        self.system_prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. """
+        current_template = self.processor.chat_template
+        # Prepend the system prompt to the template
+        new_template = self.system_prompt + current_template
+
+        # Set the modified template back to the tokenizer
+        self.processor.chat_template = new_template
+
         self.model_name = "llava-1.5-7b-hf"
+
+        self.conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": self.image_url},
+                    {
+                        "type": "text",
+                        "text": "What are the things I should be cautious about when I visit here?",
+                    },
+                ],
+            },
+        ]
         # set input to None and initialize them lazily
         self.input = None
         self.resized_image = None
@@ -358,11 +369,18 @@ def get_inputs_for_prefill(self):
         """Returns prompts as well as image."""
         if self.input:
             return self.input
-        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
+        inputs = self.processor.apply_chat_template(
+            self.conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.input_ids = inputs["input_ids"]
         index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
-        self.prompt_before_image = self.input_ids[:, :index]
+        self.prompt_before_image = self.input_ids[:, : index[0]]
         # print(prompt_before_image.shape)
-        self.prompt_after_image = self.input_ids[:, index + 1 :]
+        self.prompt_after_image = self.input_ids[:, index[-1] + 1 :]
         # print(prompt_after_image.shape)
         self.input = (
             self.prompt_before_image,
```
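
The new input pipeline in `model.py` condenses to the pattern below: let the processor's chat template build and tokenize the prompt (inserting the image placeholder for us), then split `input_ids` around the placeholder position(s). This is a standalone sketch using the same model and arguments as the example; it is illustrative, not a drop-in replacement for the class above.

```python
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", device_map="cpu"
)
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    patch_size=model.vision_tower.config.patch_size,  # as in the updated example
)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
            {"type": "text", "text": "What are the things I should be cautious about when I visit here?"},
        ],
    },
]

# apply_chat_template tokenizes the conversation and inserts the <image>
# placeholder token(s), so the example no longer hand-crafts the prompt string.
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
input_ids = inputs["input_ids"]

# Split the token ids around the image placeholder(s).
index = torch.where(input_ids == model.config.image_token_index)[1]
prompt_before_image = input_ids[:, : index[0]]
prompt_after_image = input_ids[:, index[-1] + 1 :]
```

Because newer processors may expand `<image>` into multiple placeholder tokens, the slicing uses `index[0]` and `index[-1]` rather than a single scalar index, matching the change in `get_inputs_for_prefill`.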

examples/models/llava/test/test_llava.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -41,8 +41,9 @@ def test_prefill_logits(self):
         # The reference implementation in HF genetates the full logits. Get the last one.
         prefill_logits_ref = self.llava.prefill_ref(
             self.prompt_before_image, self.resized, self.prompt_after_image
-        )[0][:, -1, :]
-        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
+        )[0]
+
+        torch.testing.assert_close(prefill_logits, prefill_logits_ref.squeeze(0))
 
     def test_generated_output(self):
         # source of truth, using HF llava
```
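
On the test change: the reference logits are now compared in full (after `squeeze(0)`) rather than only at the last position, and `torch.testing.assert_close` replaces `assertTrue(torch.allclose(...))`. The practical difference is the failure output; here is a tiny, self-contained illustration with made-up values.

```python
import torch

a = torch.zeros(4)
b = torch.tensor([0.0, 0.0, 0.0, 0.05])

# torch.allclose only returns a bool, so assertTrue(allclose(...)) fails
# without saying which elements differ or by how much.
print(torch.allclose(a, b, atol=3e-2))  # False

# torch.testing.assert_close raises an AssertionError that reports the number
# of mismatched elements and the largest absolute/relative difference.
try:
    torch.testing.assert_close(a, b)
except AssertionError as err:
    print(err)
```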

requirements-examples.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
 timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
-transformers ==4.47.1
+transformers >= 4.52.1
```
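
Since the example now assumes the post-4.52 module layout and processor behavior, an explicit runtime guard can turn a confusing attribute error into a clear message. This is an optional, hypothetical addition, not something this PR ships.

```python
import transformers
from packaging import version  # packaging ships as a transformers dependency

MIN_TRANSFORMERS = "4.52.1"

if version.parse(transformers.__version__) < version.parse(MIN_TRANSFORMERS):
    raise RuntimeError(
        f"The Llava example requires transformers >= {MIN_TRANSFORMERS}; "
        "the model layout and prompt tokenization changed between 4.47 and 4.52."
    )
```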
