Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217

Open
SudarkinV wants to merge 3 commits into ml-explore:main from SudarkinV:feat/gated-delta-vjp-narrow

Conversation

@SudarkinV

Extends #496 with a Metal backward kernel for gated_delta_update. The
existing use_kernel=False fallback into gated_delta_ops works around the
tracing error but unrolls an O(T)-node autodiff graph, which OOMs at
T ≥ 2048 on a 36 GB M-series Mac — the common LoRA / full-parameter
fine-tuning setup for Qwen3.5-9B and Qwen3-Next-80B.

In merged PR #997, @angeloskath noted "This does affect finetuning
fairly heavily but I think we need a kernel for that to be an enjoyable
experience anyway."
This PR is that kernel.

Related context: #1206 reports Qwen3.5-9B LoRA crashing on the first
backward pass. I do not claim this PR fully fixes that hardware-specific
report, but it targets the same training-time bottleneck: avoiding an
O(T)-node autodiff graph through the gated-delta recurrence by providing
a custom VJP path.

What this PR adds

Two new modules and a small integration change:

  • mlx_lm/models/gated_delta_vjp.py — pure-Python reference VJP
    using mx.checkpoint on fixed-size chunks. O(T/CHUNK_SIZE) autodiff
    graph, verifiable against gated_delta_ops. Used as the import
    fallback (chunking pattern sketched after this list).
  • mlx_lm/models/gated_delta_vjp_metal.py — Metal backward kernel
    registered via mx.custom_function. Forward-with-save + reverse
    sweep in the same chunked layout; threadgroup-local reduction (no
    atomics, deterministic). 8–11× faster than the Python VJP at
    Qwen3.5-9B shapes.
  • mlx_lm/models/gated_delta.py — gated_delta_update() gains a
    training: bool = False argument that routes to the VJP path. The
    Metal backward is selected only when GPU/Metal is available, mask
    is None, Dk % 32 == 0, and Dv % 4 == 0; otherwise the call falls
    back to the Python VJP (which has no shape constraints and runs
    on CPU). The existing use_kernel and mask/inference behaviour is
    unchanged. The routing is sketched below the list.
  • mlx_lm/models/qwen3_5.py / mlx_lm/models/qwen3_next.py — the
    linear_attn call sites pass training=self.training.
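
For orientation, a minimal sketch of the chunked-checkpoint pattern
follows. CHUNK_SIZE, scan_chunk, and reference_scan are illustrative
names, and the inner update is a simplified stand-in for the actual
gated-delta rule, not the module's real code:

```python
import mlx.core as mx

CHUNK_SIZE = 64  # fixed chunk length; autodiff graph is O(T / CHUNK_SIZE)

def scan_chunk(state, q_c, k_c, v_c):
    # Unrolled loop over one chunk. Wrapping the whole function in
    # mx.checkpoint means its intermediates are recomputed during the
    # backward pass rather than stored for the full sequence.
    outs = []
    for t in range(q_c.shape[0]):
        # Placeholder update: the real module applies the gated delta
        # rule here, not a plain outer-product accumulation.
        state = state + mx.outer(k_c[t], v_c[t])
        outs.append(q_c[t] @ state)
    return state, mx.stack(outs)

scan_chunk = mx.checkpoint(scan_chunk)

def reference_scan(q, k, v, state):
    # Only one checkpoint node per chunk survives in the traced graph.
    outs = []
    for s in range(0, q.shape[0], CHUNK_SIZE):
        state, o = scan_chunk(
            state, q[s:s + CHUNK_SIZE], k[s:s + CHUNK_SIZE], v[s:s + CHUNK_SIZE]
        )
        outs.append(o)
    return mx.concatenate(outs), state
```

Because only the chunk boundaries stay in the trace, backward memory
scales with CHUNK_SIZE rather than with T.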

Inference and KV-cache paths are untouched.
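
And a hedged sketch of the mx.custom_function registration plus the
routing described in the last two bullets; the stub bodies and the
names _metal_fwd, _metal_bwd, and python_vjp_update stand in for the
new modules' real entry points:

```python
import mlx.core as mx

def _metal_fwd(q, k, v, state):
    return q  # stand-in for the Metal forward-with-save

def _metal_bwd(primals, cotangents):
    return [mx.zeros_like(p) for p in primals]  # stand-in reverse sweep

def python_vjp_update(q, k, v, state, mask):
    return q  # stand-in for the chunked-checkpoint reference VJP

@mx.custom_function
def _gated_delta_train(q, k, v, state):
    return _metal_fwd(q, k, v, state)

@_gated_delta_train.vjp
def _gated_delta_train_vjp(primals, cotangents, outputs):
    # One kernel dispatch for the whole reverse sweep instead of an
    # O(T)-node traced autodiff graph.
    return _metal_bwd(primals, cotangents)

def gated_delta_update(q, k, v, state=None, mask=None, training=False):
    Dk, Dv = k.shape[-1], v.shape[-1]
    if training:
        if (mx.metal.is_available() and mask is None
                and Dk % 32 == 0 and Dv % 4 == 0):
            return _gated_delta_train(q, k, v, state)
        # No shape constraints; also runs on CPU.
        return python_vjp_update(q, k, v, state, mask)
    ...  # existing use_kernel / mask / inference routing, unchanged
```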

Correctness

New tests appended to tests/test_models.py (TestModels class,
unittest.TestCase style, matching the existing test_gated_delta*
block):

  • test_gated_delta_vjp_forward_equivalence — Python VJP forward is
    bit-exact against gated_delta_update(use_kernel=False).
  • test_gated_delta_vjp_fd_gradient — central-difference check of
    mx.grad output vs. analytic backward on a toy shape (B=1, T=4,
    Hk=2, Hv=4, Dk=8, Dv=8, fp32), tolerance 1e-3 (sketch after this
    list).
  • test_gated_delta_vjp_metal_matches_python — Metal VJP forward and
    state agree with the Python reference within fp32 SIMD-reduction
    noise (atol=1e-3).
  • test_gated_delta_vjp_metal_gradients_match_python — mx.grad
    output of the Metal VJP matches the Python VJP across all eight
    trainable inputs (q, k, v, a, b, A_log, dt_bias, state) within
    atol=1e-3, rtol=1e-3.

All eight test_gated_delta* tests pass locally (python -m unittest tests.test_models.TestModels -k gated_delta).
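
In the spirit of that check, a hedged sketch of a central-difference
gradient comparison (fd_grad and the scalar loss closure are
illustrative, not the test's actual code):

```python
import mlx.core as mx

def fd_grad(f, x, eps=1e-3):
    # Central differences: df/dx_i ~= (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    flat = x.reshape(-1)
    g = mx.zeros(flat.shape)
    for i in range(flat.size):
        xp, xm = mx.array(flat), mx.array(flat)
        xp[i] = xp[i] + eps
        xm[i] = xm[i] - eps
        g[i] = (f(xp.reshape(x.shape)) - f(xm.reshape(x.shape))) / (2 * eps)
    return g.reshape(x.shape)

# Compared against the analytic backward on the toy fp32 shape:
#   analytic = mx.grad(loss)(x)
#   assert mx.allclose(analytic, fd_grad(loss, x), atol=1e-3)
```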

Performance (Qwen3.5-9B linear_attn shape: `B=1, Hk=16, Hv=64, Dk=192, Dv=128`, bf16)

T     use_kernel=False (ops) fwd+bwd  Python VJP fwd+bwd  Metal VJP fwd+bwd  Peak mem (Metal)
256   152 ms (graph-bound)            145.3 ms            13.4 ms            1.8 GB
512   304 ms                          296.3 ms            28.2 ms            3.0 GB
1024  617 ms                          599.6 ms            62.2 ms            4.7 GB
2048  OOM on 36 GB                    1233.5 ms           149.8 ms           8.1 GB
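
For context on how numbers of this kind can be measured, a hedged
harness sketch (loss_fn is a placeholder closure over the linear_attn
shapes; mx.reset_peak_memory / mx.get_peak_memory are the top-level
memory APIs in recent MLX releases):

```python
import time
import mlx.core as mx

def time_fwd_bwd(loss_fn, params, warmup=2, iters=10):
    grad_fn = mx.grad(loss_fn)
    for _ in range(warmup):
        mx.eval(grad_fn(params))  # warm up / compile kernels
    mx.reset_peak_memory()
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(grad_fn(params))  # force the lazy graph to execute
    ms_per_iter = 1e3 * (time.perf_counter() - tic) / iters
    peak_gb = mx.get_peak_memory() / 2**30
    return ms_per_iter, peak_gb
```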

End-to-end training (Qwen3.5-9B LoRA on 36 GB M4 Max, max_seq=4096)

500-iteration full LoRA run on the unfiltered training set, batch=1,
grad_checkpoint=true, 4 LoRA keys (q_proj, v_proj, in_proj_qkv, out_proj):

Iter  Val loss  Peak mem
1     0.524     9.1 GB
50    0.248     9.2 GB
100   0.121     9.2 GB
200   0.143     9.2 GB
300   0.246     13.8 GB
500   0.155     13.8 GB

Converges in this configuration; peak memory rises once (9.2 GB →
13.8 GB around iteration 300) and is otherwise flat. Total time
≈ 84 minutes (≈10 s/iter). This is one observation on one shape and
is not a generic stability guarantee for all training configurations.

Relationship to #496

PR #496 added the use_kernel: bool = True routing so that training
mode (.training) falls through to gated_delta_ops. This PR reuses
that signal: the new training=True path selects the VJP module;
anything else goes through the existing use_kernel branch. No
behaviour change for inference or for the use_kernel=False eval path.

Scope

This PR covers only the training-time VJP/backward path for
gated_delta_update.

Out of scope and not included:

  • inference kernel changes
  • speculative decoding
  • prefix-scan prototypes
  • broader training or quantization changes

Files

  • mlx_lm/models/gated_delta_vjp.py (new, ~180 LoC)
  • mlx_lm/models/gated_delta_vjp_metal.py (new, ~770 LoC)
  • mlx_lm/models/gated_delta.py (+~30 LoC)
  • mlx_lm/models/qwen3_5.py (+1 LoC)
  • mlx_lm/models/qwen3_next.py (+1 LoC)
  • tests/test_models.py (+~160 LoC)

Viktor Sudarkin added 3 commits April 27, 2026 22:59

Commit 1:

Provides a training-time backward path for gated_delta_update when
the current use_kernel=False fallback (gated_delta_ops) exceeds the
36 GB unified-memory budget on Apple Silicon at T >= 2048.

New files:
- gated_delta_vjp.py: pure-Python chunked reference VJP with
  mx.checkpoint; O(T/chunk) autodiff graph.
- gated_delta_vjp_metal.py: Metal backward kernel registered as
  mx.custom_function; reverse sweep over saved state history with
  threadgroup reduction. 8-11x faster than the Python reference
  and bit-identical gradients up to fp32 SIMD-reduction noise.

Integration:
- gated_delta_update() gets a new training=False argument. When
  set (by Qwen3.5 / Qwen3-Next self_attn at .train()), routing
  picks the Metal VJP first and the Python VJP as import fallback.
- qwen3_5.py / qwen3_next.py linear_attn call sites set
  training=self.training.

Tests (appended to tests/test_models.py TestModels):
- test_gated_delta_vjp_forward_equivalence: Python VJP forward
  matches gated_delta_update(use_kernel=False).
- test_gated_delta_vjp_fd_gradient: central-difference check of
  mx.grad vs. analytic backward on a toy shape.
- test_gated_delta_vjp_metal_matches_python: Metal backward matches
  the Python reference up to fp32 SIMD-reduction noise.

Inference path and kv-cache behaviour are unchanged.

Commit 2:

- gated_delta_vjp / gated_delta_vjp_metal: initialize default state as
  fp32 to match the existing gated_delta_update path. Previously the VJP
  modules used q.dtype, silently downgrading the recurrent state to
  bf16 during Qwen3.5 / Qwen3-Next training when state=None.

- gated_delta_update: training=True now selects the Metal VJP only when
  GPU/Metal is available, mask is None, Dk%32==0 and Dv%4==0. Otherwise
  it falls back to the Python VJP, which has no shape constraints and
  also runs on CPU. Previously training=True unconditionally invoked
  the Metal kernel and crashed for non-GPU runs and for shapes the
  kernel does not handle.

- tests/test_models.py: add test_gated_delta_vjp_metal_gradients_match_python
  to exercise mx.grad through both the Metal and Python VJP and compare
  gradients for q, k, v, a, b, A_log, dt_bias, state. The previous
  Metal-vs-Python test only compared forward outputs, so a broken Metal
  backward could have passed.
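
The first bullet's dtype default, as a tiny sketch (the state shape
(B, Hv, Dk, Dv) and the helper name are assumptions for illustration):

```python
import mlx.core as mx

def _default_state(B, Hv, Dk, Dv, state=None):
    if state is None:
        # Match the existing gated_delta_update path: keep the recurrent
        # state in fp32 even for bf16 q/k/v, rather than inheriting
        # q.dtype and silently downgrading the state during training.
        state = mx.zeros((B, Hv, Dk, Dv), dtype=mx.float32)
    return state
```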

Commit 3:

The previous fix initialised the default recurrent state as fp32 to
match the existing gated_delta_update path, but the Metal forward and
backward kernels typed all state buffers as InT (the input dtype).
For bf16 inputs with state=None this produced a Metal compile error:

    error: incompatible pointer types assigning to
    'const device bfloat *' from 'const device float *'

caused by `S_prev_row = s_initial + ...` where s_initial points to
the fp32 state_initial while S_prev_row was declared `const device InT*`.

Fix: introduce a separate StT template parameter for state-typed
buffers and route it explicitly through both the scalar and the
vectorised forward/backward kernels. State writes (state_history,
state_out, dS_initial) are cast to StT instead of InT, and the Python
wrappers _fwd_save / _bwd publish the state dtype on the kernel
template and on the corresponding output_dtypes slots.

Also adds tests/test_gated_delta_vjp_bf16_default_state to cover the
public training route with bf16 inputs and state=None — the configuration
the previous combination would have crashed on.
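
A cut-down sketch of that dtype plumbing through mx.fast.metal_kernel;
the kernel body, buffer names, and the fwd_save wrapper are
placeholders rather than the PR's actual source:

```python
import mlx.core as mx

_source = """
    uint i = thread_position_in_grid.x;
    InT val = q[i];                        // activations keep the input dtype
    state_out[i] = static_cast<StT>(val);  // state buffers are StT, not InT
"""

_kernel = mx.fast.metal_kernel(
    name="gated_delta_state_dtype_sketch",
    input_names=["q"],
    output_names=["state_out"],
    source=_source,
)

def fwd_save(q, state_dtype=mx.float32):
    (state_out,) = _kernel(
        inputs=[q],
        # Separate template parameters for input and state dtypes ...
        template=[("InT", q.dtype), ("StT", state_dtype)],
        grid=(q.size, 1, 1),
        threadgroup=(min(q.size, 256), 1, 1),
        output_shapes=[q.shape],
        # ... with the same state dtype published on the output slot.
        output_dtypes=[state_dtype],
    )
    return state_out
```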