Skip to content

Commit d438559

Browse files
committed
update comment
1 parent 9785d0a commit d438559

File tree

1 file changed

+19
-7
lines changed
  • src/python/py/models/builders

1 file changed

+19
-7
lines changed

src/python/py/models/builders/qwen.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,9 +33,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
3333

3434
# The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32.
3535
# By inheriting from `base.Model`, all `layernorm_attrs["cast"]` flags
36-
# are `False`. This causes two problems:
37-
# 1. Parity Error (FP32 model): The 47% mismatch you saw.
38-
# 2. Type Mismatch Error (BF16 model): The `(float)` vs `(bfloat16)` error.
36+
# are `False`. This causes parity loss and type-mismatch errors.
3937
#
4038
# SOLUTION: Manually set all `cast` flags to `True`. This forces the
4139
# builder to cast bf16 inputs -> fp32, compute LN, and cast fp32
@@ -330,10 +328,24 @@ def make_dynamic_rope_caches(self, layer_id, basename):
330328
return cos_final_output, sin_final_output
331329

332330
def rotate_half(self, x_name, x_shape, basename, compute_dtype):
333-
"""
334-
Builds ONNX nodes for rotate_half(x)
335-
x_shape is [B, N, S, H]
336-
"""
331+
# Make nodes for rotate_half subgraph
332+
#
333+
#         x (B, N, S, H)
334+
#               |
335+
#             Split
336+
#            /     \
337+
#          /         \
338+
# x1 (..., H/2)   x2 (..., H/2)
339+
#       |               |
340+
#       |              Neg
341+
#       |               |
342+
#       |              -x2
343+
#        \             /
344+
#          \         /
345+
#            Concat
346+
#               |
347+
#        output (..., H)
348+
337349
# Split: [B, N, S, H] -> [B, N, S, H/2], [B, N, S, H/2]
338350
split_name = f"{basename}/rotate_half/Split"
339351
split_output_0 = f"{split_name}/output_0"

0 commit comments

Comments
 (0)