Skip to content

Commit a65cbb7

Browse files
authored
[zero] refactor shard and gather operation (#773)
1 parent 5a1a095 commit a65cbb7

File tree

2 files changed: +10 −10 lines changed

colossalai/zero/shard_utils/commons.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.T
1414
num_to_pad = chunks[0].numel() - chunks[rank].numel()
1515
assert num_to_pad >= 0, num_to_pad
1616

17-
shard = chunks[rank].clone()
18-
if num_to_pad > 0:
19-
shard = F.pad(shard, [0, num_to_pad])
17+
shard = torch.zeros_like(chunks[0])
18+
length = chunks[rank].size(0)
19+
shard_temp = shard[:length]
20+
shard_temp.copy_(chunks[rank])
21+
2022
return shard, num_to_pad

colossalai/zero/shard_utils/tensor_shard_strategy.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,16 @@ def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessG
4343
if not t.is_sharded:
4444
return
4545
target_device = t.device
46-
buffer_list = []
4746
payload_numel = t.payload.numel()
4847
world_size = dist.get_world_size(process_group)
4948
rank = dist.get_rank(process_group)
50-
for i in range(world_size):
51-
if i == rank:
52-
buffer_list.append(t.payload.cuda(get_current_device()))
53-
else:
54-
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
49+
50+
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
51+
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
52+
buffer_list[rank].copy_(t.payload)
5553

5654
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
57-
gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
55+
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
5856
t.reset_payload(gathered_payload)
5957
colo_model_data_tensor_move_inline(t, target_device)
6058
t.is_sharded = False

0 commit comments

Comments (0)