[FSDP2, Do not merge] Refactor #3585
Draft
+179
−73
A somewhat big PR. It started as enabling fp8 with torchao, but I ended up realising that FSDP2 is very picky about the order of operations, so I went for a bigger refactor plus a few other changes. Going over them from biggest to smallest:
1. Move the whole FSDP2 preparation into `accelerator._prepare_fsdp2`, as is already done for other backends such as DeepSpeed. This gives us more freedom in the future when composing FSDP2 with other features, and it has already paid off: this is what enabled torchao fp8 support.
2. FSDP2-specific compile: we now compile after the model converters (AC/FP8) but before FSDP2. This fixes compile issues and is actually how torchtitan does it.
3. Move the optimiser parameter switch to be FSDP2-specific. It can now live in the method mentioned above, which lets us do simpler FSDP2-specific things. We now canonicalise names for old and new params (not 100% sure about the reach of this, but it proved to be OK with compile + AC + fp8).
4. Closely related to 2, the order of operations is now slightly different: activation checkpointing is applied before compile, which in turn happens before `fully_shard`.
5. By accident, this PR also fixes FP8 with torchao, so those changes are included here (should be minor).
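
For reviewers, here is a rough, framework-free sketch of what "canonicalise names for old and new params" means for the optimiser parameter switch. It is not the code in this PR. The helper names (`canonicalize_name`, `build_param_mapping`, `swap_optimizer_params`) are hypothetical, and the two wrapper fragments are an assumption based on the attribute names that `torch.compile` and PyTorch's activation checkpoint wrapper use, as far as I know. Plain objects stand in for tensors so the sketch runs without torch.

```python
# Name fragments assumed to be introduced by wrappers (torch.compile and the
# activation-checkpoint wrapper). The real list in the PR may differ.
WRAPPER_FRAGMENTS = ("_orig_mod", "_checkpoint_wrapped_module")


def canonicalize_name(name: str) -> str:
    """Drop wrapper fragments from a dotted parameter name."""
    parts = name.split(".")
    return ".".join(p for p in parts if p not in WRAPPER_FRAGMENTS)


def build_param_mapping(old_named_params: dict, new_named_params: dict) -> dict:
    """Map id(old_param) -> new_param by matching canonical names."""
    new_by_name = {canonicalize_name(n): p for n, p in new_named_params.items()}
    mapping = {}
    for name, old_param in old_named_params.items():
        key = canonicalize_name(name)
        if key not in new_by_name:
            raise KeyError(f"no parameter named {key!r} after wrapping")
        mapping[id(old_param)] = new_by_name[key]
    return mapping


def swap_optimizer_params(param_groups: list, mapping: dict) -> None:
    """Replace every parameter in optimizer-style param groups, in place."""
    for group in param_groups:
        group["params"] = [mapping[id(p)] for p in group["params"]]


# Usage: names before wrapping vs. names after compile + AC wrapping.
old_w, old_b = object(), object()
new_w, new_b = object(), object()
old = {"layers.0.weight": old_w, "layers.0.bias": old_b}
new = {
    "_orig_mod.layers.0._checkpoint_wrapped_module.weight": new_w,
    "_orig_mod.layers.0._checkpoint_wrapped_module.bias": new_b,
}
groups = [{"params": [old_w, old_b], "lr": 0.1}]
swap_optimizer_params(groups, build_param_mapping(old, new))
```

Matching on canonical names rather than on position means the switch does not depend on the wrappers preserving parameter order, which is the property that matters once compile, AC and fp8 conversion all touch the module tree.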
I have been testing this extensively for about a week, so I'm quite confident, though you never know. @winglian, if you have any time to check that this doesn't break anything downstream, I would very much appreciate it.