Commit cff7273

Ensure normalization does not promote to fp32 (#951)
1 parent: fc7d844

1 file changed: mlx_lm/models/qwen3_next.py (11 additions, 2 deletions)

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -53,6 +54,13 @@ class ModelArgs(BaseModelArgs):
     full_attention_interval: int = 4
 
 
+@partial(mx.compile, shapeless=True)
+def _precise_swiglu(h, gate, x):
+    gate = nn.silu(gate.astype(mx.float32))
+    x = x.astype(mx.float32)
+    return (gate * x).astype(h.dtype)
+
+
 class Qwen3NextRMSNormGated(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6):
         super().__init__()
@@ -64,8 +72,9 @@ def __call__(
     ) -> mx.array:
         x = mx.fast.rms_norm(hidden_states, self.weight, self.eps)
         if gate is not None:
-            x = swiglu(gate, x)
-        return x
+            return _precise_swiglu(hidden_states, gate, x)
+        else:
+            return x.astype(hidden_states.dtype)
 
 
 class Qwen3NextAttention(nn.Module):
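For context, here is a minimal runnable sketch (not part of the commit) of the dtype contract the new helper enforces, assuming mlx is installed; the input shapes and random test values are illustrative. The SiLU gating is computed in float32 for numerical accuracy, and the product is cast back to the dtype of the normalized input, so half-precision activations are never promoted to fp32 in the output:

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn


    # Mirrors the _precise_swiglu helper added in this commit.
    @partial(mx.compile, shapeless=True)
    def _precise_swiglu(h, gate, x):
        # Upcast to float32 so the SiLU and multiply are computed precisely...
        gate = nn.silu(gate.astype(mx.float32))
        x = x.astype(mx.float32)
        # ...then cast back to the caller's dtype so nothing is promoted.
        return (gate * x).astype(h.dtype)


    # Illustrative half-precision inputs (shapes and values are arbitrary).
    h = mx.random.normal((2, 4)).astype(mx.bfloat16)
    gate = mx.random.normal((2, 4)).astype(mx.bfloat16)
    x = mx.random.normal((2, 4)).astype(mx.bfloat16)

    print(_precise_swiglu(h, gate, x).dtype)  # bfloat16, not float32

The shapeless=True argument to mx.compile keeps the compiled graph shape-agnostic, so the helper is not retraced for every new sequence length.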
