@@ -33,9 +33,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
3333
3434 # The HF model (Qwen2RMSNorm) *always* computes LayerNorm in float32.
3535 # By inheriting from `base.Model`, all `layernorm_attrs["cast"]` flags
36- # are `False`. This causes two problems:
37- # 1. Parity Error (FP32 model): The 47% mismatch you saw.
38- # 2. Type Mismatch Error (BF16 model): The `(float)` vs `(bfloat16)` error.
36+ # are `False`. This causes parity loss and a type-mismatch error.
3937 #
4038 # SOLUTION: Manually set all `cast` flags to `True`. This forces the
4139 # builder to cast bf16 inputs -> fp32, compute LN, and cast fp32
@@ -330,10 +328,24 @@ def make_dynamic_rope_caches(self, layer_id, basename):
330328 return cos_final_output , sin_final_output
331329
332330 def rotate_half (self , x_name , x_shape , basename , compute_dtype ):
333- """
334- Builds ONNX nodes for rotate_half(x)
335- x_shape is [B, N, S, H]
336- """
331+ # Make nodes for rotate_half subgraph
332+ #
333+ # x (B, N, S, H)
334+ # |
335+ # Split
336+ # / \
337+ # / \
338+ # x1 (..., H/2) x2 (..., H/2)
339+ # | |
340+ # | Neg
341+ # | |
342+ # | -x2
343+ # \ /
344+ # \ /
345+ # Concat
346+ # |
347+ # output (..., H)
348+
337349 # Split: [B, N, S, H] -> [B, N, S, H/2], [B, N, S, H/2]
338350 split_name = f"{ basename } /rotate_half/Split"
339351 split_output_0 = f"{ split_name } /output_0"
0 commit comments