
TP SP examples improvement #1354

Open · githubsgi wants to merge 2 commits into main

Conversation

githubsgi

Changing cuda to torch.accelerator, and adding CommDebugMode to tensor_parallel_example.py, sequence_parallel_example.py, and log_utils.py.
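
A minimal sketch of the device-selection part of the change, assuming the structure of the example scripts (the exact wiring in the files may differ):

import os
import torch
from torch.distributed.device_mesh import init_device_mesh

# Resolve the device generically instead of hard-coding "cuda".
device_type = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)
world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
device_mesh = init_device_mesh(device_type, (world_size,))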

netlify bot commented Jun 11, 2025

Deploy Preview for pytorch-examples-preview canceled.
🔨 Latest commit: cf381e0
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-examples-preview/deploys/684a12b06909be0008ad6e09

githubsgi changed the title from "TP SP example improvement" to "TP SP examples improvement" on Jun 11, 2025
output.sum().backward()
optimizer.step()
inp = torch.rand(4, 10, device=device_type)
comm_mode = CommDebugMode()
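
For context, a rough sketch of how a CommDebugMode-instrumented iteration could look; variable names follow the diff context above, the loop details are illustrative, and the reporting calls are the ones that produce the per-rank tables in the logs below:

from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
with comm_mode:                      # record collectives issued during this iteration
    output = tp_model(inp)
    output.sum().backward()
optimizer.step()
optimizer.zero_grad()

# Same calls that generated the tables shown in the attached logs.
print(comm_mode.get_comm_counts())
print(comm_mode.get_sharding_info())
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))
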
Member

Does this work on non-CUDA devices? It would be great to share some local logs of your tests.

Author

Gladly. Please see the attached logs from an H100 run.

Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 0.
06/16/2025 05:55:00 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:55:03 PM  Tensor Parallel training starting...
06/16/2025 05:55:03 PM  Tensor Parallel iter 0 completed
 rank3 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 1 completed
 rank0 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank2 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 2 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 3 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 4 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 5 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 6 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 7 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 8 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 9 completed
06/16/2025 05:55:04 PM  Tensor Parallel training completed!
[rank0]:[W616 17:55:04.791527408 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Starting PyTorch Sequence Parallel example on rank 0.
06/16/2025 05:53:21 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch Sequence Parallel example on rank 3.
Starting PyTorch Sequence Parallel example on rank 2.
Starting PyTorch Sequence Parallel example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:53:24 PM  Sequence Parallel training starting...
 rank2 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 0 completed
 rank0 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank3 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 1 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 2 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 3 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 4 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 5 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 6 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 7 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 8 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 9 completed
06/16/2025 05:53:25 PM  Sequence Parallel training completed!
[rank0]:[W616 17:53:25.948217933 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
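
(Aside: the trailing ProcessGroupNCCL warning in both logs is unrelated to this change; it goes away if the scripts tear down the process group before exiting, e.g.:)

import torch.distributed as dist

# Explicit cleanup before the script exits silences the
# "destroy_process_group() was not called" warning.
dist.destroy_process_group()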

Member

Sorry, I meant on non-CUDA devices: does this API work if you use MPS or CPU?

Author

torch.accelerator works for CUDA and non-CUDA GPUs and accelerators. CommDebugMode is also a core PyTorch feature, so it should work on all devices. If not, that would be a bug.
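
For example, a minimal sketch of device-agnostic process-group setup (backend choices are illustrative; a CPU run would typically use gloo, and whether a particular backend supports a given accelerator is a separate question):

import torch
import torch.distributed as dist

if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda"
    backend = "nccl" if device_type == "cuda" else None  # None lets PyTorch choose
else:
    device_type = "cpu"
    backend = "gloo"

dist.init_process_group(backend=backend)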
