TP SP examples improvement #1354
Conversation
output.sum().backward()
optimizer.step()
inp = torch.rand(4, 10, device=device_type)
comm_mode = CommDebugMode()
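For context, CommDebugMode from torch.distributed.tensor.debug is used as a context manager around a training step; a minimal sketch of how the lines above typically fit together (tp_model, optimizer, and device_type are assumed from the example scripts, and the import path may vary by PyTorch version):

```python
import torch
from torch.distributed.tensor.debug import CommDebugMode  # path may differ in older releases

# Sketch: count the collectives DTensor issues during one iteration.
comm_mode = CommDebugMode()
with comm_mode:
    inp = torch.rand(4, 10, device=device_type)  # device_type assumed from the example
    output = tp_model(inp)                       # tp_model: the parallelized ToyModel
    output.sum().backward()
    optimizer.step()

# These are the calls whose output appears in the logs below.
print(comm_mode.get_comm_counts())
print(comm_mode.get_sharding_info())
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))
```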
Does this work on non-CUDA devices? It would be great to share some local logs of your tests.
Gladly. Please see attached logs for H100.
Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 0.
06/16/2025 05:55:00 PM Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 1.
model ToyModel(
(in_proj): Linear(in_features=10, out_features=32, bias=True)
(relu): ReLU()
(out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:55:03 PM Tensor Parallel training starting...
06/16/2025 05:55:03 PM Tensor Parallel iter 0 completed
rank3 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:55:03 PM Tensor Parallel iter 1 completed
rank0 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank2 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank1 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_reduce: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32]), torch.Size([32])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([32, 10]), torch.Size([32, 10])]
sharding: [(Shard(dim=0),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.all_reduce: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5]), torch.Size([5])]
sharding: [(Replicate(),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.add_.Tensor
shape: [torch.Size([5, 32]), torch.Size([5, 32])]
sharding: [(Shard(dim=1),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:55:03 PM Tensor Parallel iter 2 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 3 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 4 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 5 completed
06/16/2025 05:55:03 PM Tensor Parallel iter 6 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 7 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 8 completed
06/16/2025 05:55:04 PM Tensor Parallel iter 9 completed
06/16/2025 05:55:04 PM Tensor Parallel training completed!
[rank0]:[W616 17:55:04.791527408 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Starting PyTorch Sequence Parallel example on rank 0.
06/16/2025 05:53:21 PM Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch Sequence Parallel example on rank 3.
Starting PyTorch Sequence Parallel example on rank 2.
Starting PyTorch Sequence Parallel example on rank 1.
model ToyModel(
(in_proj): Linear(in_features=10, out_features=32, bias=True)
(relu): ReLU()
(out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:53:24 PM Sequence Parallel training starting...
rank2 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:53:25 PM Sequence Parallel iter 0 completed
rank0 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank1 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
rank3 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel
*module type: class '__main__.ToyModel'
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
*c10d_functional.reduce_scatter_tensor: 1
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
ToyModel.in_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=0),)
*bias: (Shard(dim=0),)
FORWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.addmm.default
shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32, 10])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([32])]
sharding: [(Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.zeros_like.default
shape: [torch.Size([5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
**aten.mm.default
shape: [torch.Size([32, 4]), torch.Size([4, 10])]
sharding: [(Shard(dim=0),), (Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 32])]
sharding: [(Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
ToyModel.relu
*module type: class 'torch.nn.modules.activation.ReLU'
FORWARD PASS
BACKWARD PASS
ToyModel.out_proj
*module type: class 'torch.nn.modules.linear.Linear'
*Parameter List
*weight: (Shard(dim=1),)
*bias: (Replicate(),)
FORWARD PASS
*c10d_functional.reduce_scatter_tensor: 1
**aten.addmm.default
shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
BACKWARD PASS
*c10d_functional.all_gather_into_tensor: 1
**aten.mm.default
shape: [torch.Size([4, 5]), torch.Size([5, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.mm.default
shape: [torch.Size([5, 4]), torch.Size([4, 32])]
sharding: [(Replicate(),), (Shard(dim=1),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
**aten.sum.dim_IntList
shape: [torch.Size([4, 5])]
sharding: [(Replicate(),)]
device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
06/16/2025 05:53:25 PM Sequence Parallel iter 1 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 2 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 3 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 4 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 5 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 6 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 7 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 8 completed
06/16/2025 05:53:25 PM Sequence Parallel iter 9 completed
06/16/2025 05:53:25 PM Sequence Parallel training completed!
[rank0]:[W616 17:53:25.948217933 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
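The two logs make the communication difference visible: the TP run issues a single all_reduce in the forward pass, while the SP run replaces it with an all_gather_into_tensor before in_proj and a reduce_scatter_tensor after out_proj, plus an all_gather in the backward pass. A hedged sketch of the parallelization plans that typically produce these patterns (module names follow the ToyModel above; the exact plan in this PR may differ):

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Tensor Parallel: colwise in_proj + rowwise out_proj; the forward pass
# needs one all_reduce to replicate the output of out_proj.
tp_model = parallelize_module(
    model,
    device_mesh,
    {
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

# Sequence Parallel variant: inputs and outputs stay sharded on the
# sequence dimension, so the forward uses all_gather + reduce_scatter
# instead of all_reduce.
sp_model = parallelize_module(
    model,
    device_mesh,
    {
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)
```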
Sorry, I meant on non-CUDA devices: does this API work if you use MPS or CPU?
torch.accelerator works for CUDA and non-CUDA GPUs and accelerators. CommDebugMode is also a PyTorch feature, so it should work on all devices; if not, that would be a bug.
Changing cuda to accelerator, and adding CommDebugMode to tensor_parallel_example.py, sequence_parallel_example.py, and log_utils.py.
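For reference, a minimal sketch of the device-agnostic selection described above, assuming the torch.accelerator API available in recent PyTorch releases (the CPU/gloo fallback here is an illustration, not necessarily the PR's final code):

```python
import torch
import torch.distributed as dist

# Prefer the current accelerator (CUDA, XPU, MPS, ...) and fall back to CPU.
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
else:
    device_type = "cpu"

# Pick a matching distributed backend: NCCL for CUDA, gloo otherwise.
backend = "nccl" if device_type == "cuda" else "gloo"
dist.init_process_group(backend=backend)
```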