Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217
Open
SudarkinV wants to merge 3 commits into ml-explore:main from
Conversation
added 3 commits
April 27, 2026 22:59
Provides a training-time backward path for gated_delta_update when the current use_kernel=False fallback (gated_delta_ops) exceeds the 36 GB unified-memory budget on Apple Silicon at T >= 2048.

New files:
- gated_delta_vjp.py: pure-Python chunked reference VJP with mx.checkpoint; O(T/chunk) autodiff graph.
- gated_delta_vjp_metal.py: Metal backward kernel registered as mx.custom_function; reverse sweep over saved state history with threadgroup reduction. 8-11x faster than the Python reference, with gradients identical up to fp32 SIMD-reduction noise.

Integration:
- gated_delta_update() gets a new training=False argument. When set (by Qwen3.5 / Qwen3-Next self_attn at .train()), routing picks the Metal VJP first and the Python VJP as import fallback.
- qwen3_5.py / qwen3_next.py linear_attn call sites set training=self.training.

Tests (appended to tests/test_models.py TestModels):
- test_gated_delta_vjp_forward_equivalence: Python VJP forward matches gated_delta_update(use_kernel=False).
- test_gated_delta_vjp_fd_gradient: central-difference check of mx.grad vs. analytic backward on a toy shape.
- test_gated_delta_vjp_metal_matches_python: Metal backward matches the Python reference up to fp32 SIMD-reduction noise.

Inference path and kv-cache behaviour are unchanged.
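The chunked-checkpoint idea in miniature, as a hedged sketch: a toy linear recurrence stands in for the gated-delta step, and `CHUNK`, `_chunk_step`, and `scan` are illustrative names, not the module's API.

```python
import mlx.core as mx

CHUNK = 64

def _chunk_step(state, xs):
    # unrolled steps inside one chunk; recomputed on the backward pass
    # instead of being stored in the autodiff graph
    for t in range(xs.shape[0]):
        state = 0.9 * state + xs[t]  # toy stand-in for the gated-delta step
    return state

_chunk_step_ckpt = mx.checkpoint(_chunk_step)

def scan(state, xs):
    # the graph holds O(T / CHUNK) checkpoint nodes rather than O(T) steps
    for s in range(0, xs.shape[0], CHUNK):
        state = _chunk_step_ckpt(state, xs[s:s + CHUNK])
    return state

xs = mx.random.normal((256, 8))
grad = mx.grad(lambda s0: scan(s0, xs).sum())(mx.zeros((8,)))
```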
- gated_delta_vjp / gated_delta_vjp_metal: initialize default state as fp32 to match the existing gated_delta_update path. Previously the VJP modules used q.dtype, silently downgrading the recurrent state to bf16 during Qwen3.5 / Qwen3-Next training when state=None.
- gated_delta_update: training=True now selects the Metal VJP only when GPU/Metal is available, mask is None, Dk % 32 == 0 and Dv % 4 == 0. Otherwise it falls back to the Python VJP, which has no shape constraints and also runs on CPU. Previously training=True unconditionally invoked the Metal kernel and crashed for non-GPU runs and for shapes the kernel does not handle.
- tests/test_models.py: add test_gated_delta_vjp_metal_gradients_match_python to exercise mx.grad through both the Metal and Python VJP and compare gradients for q, k, v, a, b, A_log, dt_bias, state. The previous Metal-vs-Python test only compared forward outputs, so a broken Metal backward could have passed.
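In sketch form (a hypothetical helper; the real code lives inside the VJP modules, and the state shape here is illustrative):

```python
import mlx.core as mx

def init_state(B, Hv, Dk, Dv, state=None):
    # Recurrent state defaults to fp32, not q.dtype, so bf16 training no
    # longer silently downgrades the state when state=None.
    if state is None:
        return mx.zeros((B, Hv, Dk, Dv), dtype=mx.float32)  # was q.dtype
    return state
```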
The previous fix initialised the default recurrent state as fp32 to
match the existing gated_delta_update path, but the Metal forward and
backward kernels typed all state buffers as InT (the input dtype).
For bf16 inputs with state=None this produced a Metal compile error:
error: incompatible pointer types assigning to
'const device bfloat *' from 'const device float *'
caused by `S_prev_row = s_initial + ...` where s_initial points to
the fp32 state_initial while S_prev_row was declared `const device InT*`.
Fix: introduce a separate StT template parameter for state-typed
buffers and route it explicitly through both the scalar and the
vectorised forward/backward kernels. State writes (state_history,
state_out, dS_initial) are cast to StT instead of InT, and the Python
wrappers _fwd_save / _bwd publish the state dtype on the kernel
template and on the corresponding output_dtypes slots.
Also adds tests/test_gated_delta_vjp_bf16_default_state to cover the
public training route with bf16 inputs and state=None — the configuration
the previous revision would have crashed on.
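A minimal sketch of the StT mechanics, assuming the kernels are built through `mx.fast.metal_kernel`; the kernel body and names below are illustrative, not the PR's actual source. The input buffer keeps the activation dtype while the state-typed output is published as fp32 through the separate `StT` template slot, which is what the `_fwd_save` / `_bwd` wrappers described above do for their state buffers.

```python
import mlx.core as mx

_SRC = """
    uint i = thread_position_in_grid.x;
    // read at the activation dtype, store at the state dtype
    state_out[i] = static_cast<StT>(x[i]);
"""

_save_state = mx.fast.metal_kernel(
    name="save_state_fp32",
    input_names=["x"],
    output_names=["state_out"],
    source=_SRC,
)

def save_state(x):
    return _save_state(
        inputs=[x],
        template=[("InT", x.dtype), ("StT", mx.float32)],
        grid=(x.size, 1, 1),
        threadgroup=(min(x.size, 256), 1, 1),
        output_shapes=[x.shape],
        output_dtypes=[mx.float32],  # the StT slot, independent of x.dtype
    )[0]

# bf16 input, fp32 state out: the combination that previously failed to compile
print(save_state(mx.ones((8,), dtype=mx.bfloat16)).dtype)  # mlx.core.float32
```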
Extends #496 with a Metal backward kernel for `gated_delta_update`. The existing `use_kernel=False` fallback into `gated_delta_ops` unblocks the tracing error but unrolls an O(T)-node autodiff graph, which OOMs at T ≥ 2048 on a 36 GB M-series Mac — the common full-parameter / LoRA fine-tuning setup for Qwen3.5-9B and Qwen3-Next-80B.
In merged PR #997, @angeloskath noted "This does affect finetuning
fairly heavily but I think we need a kernel for that to be an enjoyable
experience anyway." This PR is that kernel.
Related context: #1206 reports Qwen3.5-9B LoRA crashing on the first
backward pass. I do not claim this PR fully fixes that hardware-specific
report, but it targets the same training-time bottleneck: avoiding an
O(T)-node autodiff graph through the gated-delta recurrence by providing
a custom VJP path.
What this PR adds
Two new modules and a small integration change:
- `mlx_lm/models/gated_delta_vjp.py` — pure-Python reference VJP using `mx.checkpoint` on fixed-size chunks. O(T/CHUNK_SIZE) autodiff graph, verifiable against `gated_delta_ops`. Used as the import fallback.
- `mlx_lm/models/gated_delta_vjp_metal.py` — Metal backward kernel registered via `mx.custom_function`. Forward-with-save + reverse sweep in the same chunked layout; threadgroup-local reduction (no atomics, deterministic). 8–11× faster than the Python VJP at Qwen3.5-9B shapes.
- `mlx_lm/models/gated_delta.py` — `gated_delta_update()` gains a `training: bool = False` argument that routes to the VJP path. The Metal backward is selected only when GPU/Metal is available, mask is None, `Dk % 32 == 0` and `Dv % 4 == 0`; otherwise the call falls back to the Python VJP (which has no shape constraints and runs on CPU). The existing `use_kernel` and mask/inference behaviour is unchanged. A routing sketch follows this list.
- `mlx_lm/models/qwen3_5.py` / `mlx_lm/models/qwen3_next.py` — the `linear_attn` call site passes `training=self.training`. Inference and KV-cache paths are untouched.
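A hedged sketch of this routing; `_metal_vjp_ok` and the stub VJP entry points are illustrative stand-ins, not the PR's actual names.

```python
import mlx.core as mx

def _gated_delta_vjp_metal(q, k, v, *rest, **kw):
    raise NotImplementedError  # mx.custom_function Metal path (stub)

def _gated_delta_vjp_python(q, k, v, *rest, **kw):
    raise NotImplementedError  # chunked mx.checkpoint path (stub)

def _metal_vjp_ok(q, v, mask):
    return (
        mx.metal.is_available()     # GPU/Metal present
        and mask is None            # kernel does not handle masks
        and q.shape[-1] % 32 == 0   # Dk % 32 == 0
        and v.shape[-1] % 4 == 0    # Dv % 4 == 0
    )

def gated_delta_update(q, k, v, *rest, mask=None, use_kernel=True,
                       training=False, **kw):
    if training:
        if _metal_vjp_ok(q, v, mask):
            return _gated_delta_vjp_metal(q, k, v, *rest, mask=mask, **kw)
        return _gated_delta_vjp_python(q, k, v, *rest, mask=mask, **kw)
    # inference: existing use_kernel / gated_delta_ops routing, unchanged
    ...
```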
Correctness
New tests appended to `tests/test_models.py` (`TestModels` class, `unittest.TestCase` style, matching the existing `test_gated_delta*` block):

- `test_gated_delta_vjp_forward_equivalence` — Python VJP forward is bit-exact against `gated_delta_update(use_kernel=False)`.
- `test_gated_delta_vjp_fd_gradient` — central-difference check of `mx.grad` output vs. analytic backward on a toy shape (B=1, T=4, Hk=2, Hv=4, Dk=8, Dv=8, fp32), tolerance 1e-3. A sketch of the check follows this list.
- `test_gated_delta_vjp_metal_matches_python` — Metal VJP forward and state agree with the Python reference within fp32 SIMD-reduction noise (`atol=1e-3`).
- `test_gated_delta_vjp_metal_gradients_match_python` — `mx.grad` output of the Metal VJP matches the Python VJP across all eight trainable inputs (`q, k, v, a, b, A_log, dt_bias, state`) within `atol=1e-3, rtol=1e-3`.
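The central-difference check in sketch form, against a toy loss rather than the gated-delta recurrence; `central_diff_grad` and `loss` are illustrative, not the test's actual code.

```python
import mlx.core as mx

def central_diff_grad(f, x, eps=1e-3):
    # numeric gradient via (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    flat = x.reshape(-1)
    grads = []
    for i in range(flat.size):
        e = eps * (mx.arange(flat.size) == i).astype(x.dtype)  # one-hot bump
        plus = f((flat + e).reshape(x.shape))
        minus = f((flat - e).reshape(x.shape))
        grads.append((plus - minus) / (2 * eps))
    return mx.stack(grads).reshape(x.shape)

def loss(x):
    return (x * x).sum()  # toy scalar loss

x = mx.random.normal((2, 3))
analytic = mx.grad(loss)(x)
numeric = central_diff_grad(loss, x)
assert mx.allclose(analytic, numeric, atol=1e-3).item()
```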
All eight `test_gated_delta*` tests pass locally (`python -m unittest tests.test_models.TestModels -k gated_delta`).

Performance (Qwen3.5-9B linear_attn shape: `B=1, Hk=16, Hv=64, Dk=192, Dv=128`, bf16)

(Benchmark table: `use_kernel=False` (ops) fwd+bwd baseline vs. the Metal VJP.)

End-to-end training (Qwen3.5-9B LoRA on 36 GB M4 Max, max_seq=4096)

500-iteration full LoRA run on the unfiltered training set, batch=1, `grad_checkpoint=true`, 4 LoRA keys (q_proj, v_proj, in_proj_qkv, out_proj):
Converges in this configuration; peak memory stable across the run.
Total time ≈ 84 minutes (10 s/iter). This is one observation on one
shape and is not a generic stability guarantee for all training
configurations.
Relationship to #496
PR #496 added the `use_kernel: bool = True` routing so that `.training` falls through to `gated_delta_ops`. This PR reuses that signal — the new `training=True` path selects the VJP module; anything else goes through the existing `use_kernel` branch. No behaviour change for inference or for the `use_kernel=False` eval path.

Scope
This PR covers only the training-time VJP/backward path for `gated_delta_update`. Out of scope and not included:
Files
- `mlx_lm/models/gated_delta_vjp.py` (new, ~180 LoC)
- `mlx_lm/models/gated_delta_vjp_metal.py` (new, ~770 LoC)
- `mlx_lm/models/gated_delta.py` (+~30 LoC)
- `mlx_lm/models/qwen3_5.py` (+1 LoC)
- `mlx_lm/models/qwen3_next.py` (+1 LoC)
- `tests/test_models.py` (+~160 LoC)