
Hybrid sharding#3194

Open
nastya236 wants to merge 4 commits into ml-explore:main from nastya236:fsdp_ddp

Conversation

@nastya236 (Collaborator)

Added hybrid sharding to fsdp.
I refactored the code because I thought it would be a bit more consistent, but I can revert it and simply call mx.distributed.all_sum.
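For context, hybrid sharding typically splits the world into two orthogonal groups: parameters are sharded within one axis (the FSDP group) and replicated across the other (the DDP group), so gradients only need an all_sum over the replica axis. A minimal sketch of the rank-to-group bookkeeping in plain Python (the `shard_size` parameter and the helper name are hypothetical, not taken from this PR):

```python
def hybrid_groups(rank, world_size, shard_size):
    """Assign a rank to its FSDP (shard) and DDP (replica) group ids.

    Ranks are laid out on a (replicas x shard_size) grid: ranks in the
    same row shard the parameters between them (FSDP), while ranks in
    the same column hold replicas of the same shard (DDP).
    """
    assert world_size % shard_size == 0
    fsdp_color = rank // shard_size  # which shard group this rank is in
    ddp_color = rank % shard_size    # which replica group this rank is in
    return fsdp_color, ddp_color

# 8 ranks with shards of size 4 gives a 2 x 4 grid of ranks.
groups = [hybrid_groups(r, 8, 4) for r in range(8)]
```

In MLX this grid would typically be realized with `Group.split` on the world group (e.g. `fsdp_group = world.split(fsdp_color)`), and the gradient reduction done with `mx.distributed.all_sum` over the replica group.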

@angeloskath (Member) left a comment


Looks good!

I left a few comments, lmk what you think, especially about removing the communication type 🤔. I think it's just unnecessary complexity at this point.

rank = group.rank()
world = mx.distributed.init()
N = world.size()
fsdp_group = fsdp_group or world
Member

That is interestingly incorrect.

mx.distributed.init() doesn't return the world communicator but the first one instantiated. There is no guarantee that it would even be the same type, i.e. world could be a TCP ring while fsdp_group could be NCCL.
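A small self-contained mock (not MLX code) of why the snippet above is unsafe: if `init()` is implemented as "return the first communicator ever instantiated", then whatever group happens to be created first is what it hands back, so it cannot be relied on to be the world group. All names here are illustrative:

```python
# Toy stand-in for a distributed module whose init() caches the first
# communicator instantiated, mirroring the behavior described above.
_first_group = None

class Group:
    def __init__(self, name, size):
        global _first_group
        self.name = name
        self._size = size
        if _first_group is None:
            _first_group = self  # init() will return this from now on

    def size(self):
        return self._size

def init():
    # Returns the first-instantiated communicator, NOT necessarily "world".
    return _first_group

# Suppose a 4-rank sub-group (e.g. an NCCL fsdp_group) is created first...
fsdp_group = Group("nccl-fsdp", 4)
# ...then init() now yields that sub-group, not an 8-rank TCP-ring world,
# so `world = init(); N = world.size()` would silently compute N = 4.
```

The safer pattern is therefore to take the world (or replica) group explicitly as an argument instead of reconstructing it from `init()` inside the sharding code.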
