Skip to content

Commit 20891af

Browse files
committed
review feedback
1 parent 7ccc607 commit 20891af

File tree

3 files changed

+158
-107
lines changed

3 files changed

+158
-107
lines changed

src/python/py/models/builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ def create_model(
304304
elif config.architectures[0] == "SmolLM3ForCausalLM":
305305
onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
306306
elif config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
307+
text_config = config.text_config
308+
for key in text_config:
309+
if not hasattr(config, key):
310+
setattr(config, key, getattr(text_config, key))
307311
print(
308312
"WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default."
309313
)

src/python/py/models/builders/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def make_rope_init(self, config):
466466
"sections": config.rope_scaling["mrope_section"], # Sections for MRoPE
467467
}
468468

469-
def make_attention_init(self):
469+
def is_gqa_supported(self) -> bool:
470470
valid_gqa_configurations = {
471471
("cpu", ir.DataType.FLOAT),
472472
("cuda", ir.DataType.FLOAT16),
@@ -476,7 +476,10 @@ def make_attention_init(self):
476476
("webgpu", ir.DataType.FLOAT),
477477
("trt-rtx", ir.DataType.FLOAT16),
478478
}
479-
if (self.ep, self.io_dtype) in valid_gqa_configurations:
479+
return (self.ep, self.io_dtype) in valid_gqa_configurations
480+
481+
def make_attention_init(self):
482+
if self.is_gqa_supported():
480483
# Change model settings for GroupQueryAttention
481484
self.attention_attrs["op_type"] = "GroupQueryAttention"
482485
print("GroupQueryAttention (GQA) is used in this model.")

0 commit comments

Comments (0)