all_gather.py (73 changes: 62 additions & 11 deletions)
@@ -239,7 +239,7 @@ def all_gather_into_tensor(output_tensor: Tensor, input_tensor: Tensor, pg: dist
# assert input_tensor.numel() % chunk_size == 0
cupy_stream = PyTorchStreamWrapper(torch.cuda.current_stream())
# print(f'chunk size {chunk_size}')

# nvshmem_device_producer_all_gather_2d_get_block_kernel_chunked[(torch.distributed.get_world_size(pg), )](
# input_tensor,
# target_tensor,
@@ -266,7 +266,7 @@ def all_gather_into_tensor(output_tensor: Tensor, input_tensor: Tensor, pg: dist
assert target_buf_size <= buf_size
target_tensor_split = target_tensor[:target_buf_size].view(world_size, size)
signal_ptr.fill_(0)
assert world_size % 8 == 0
assert world_size % 8 == 0 or world_size < 8
grid_size = 8 if world_size == 32 else world_size
nvshmem_device_producer_all_gather_2d_get_block_kernel_chunked_synced[(grid_size, )](
sub_input_tensor,
@@ -312,8 +312,56 @@ def all_gather_into_tensor(output_tensor: Tensor, input_tensor: Tensor, pg: dist
# )
# nvshmem.core.quiet(stream=cupy_stream)


def all_gather_into_tensor_nccl_p2p(output_tensor: Tensor, input_tensor: Tensor, pg: dist.ProcessGroup):
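# All-gather assembled from NCCL point-to-point ops: the output is viewed as one
# slot per rank, each rank copies its own shard locally and exchanges the remaining
# shards pairwise (peer = i ^ rank), chunked by the splitter's local buffer size.
# By default the isend/irecv ops are issued through a single dist.batch_isend_irecv
# call; setting DISABLE_BATCH_P2P=1 falls back to individual isend/irecv calls.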
buf_size = buffer_splitter.get_global_buffer_size(output_tensor.shape)

world_size = torch.distributed.get_world_size(pg)


assert buf_size % world_size == 0
local_buf_size = buf_size // world_size
target_tensor_split = output_tensor.view(world_size, -1)

batch_p2p = os.environ.get("DISABLE_BATCH_P2P", "0") != "1"
handles = []
rank = torch.distributed.get_rank(pg)
ops = []
for start in range(0, input_tensor.numel(), local_buf_size):
size = min(local_buf_size, input_tensor.numel() - start)
sub_input_tensor = input_tensor.view(-1)[start:start+size]
for i in range(world_size):
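# XOR pairing: at step i every rank is matched with exactly one partner, so both
# sides of a pair exchange their shards in the same step.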
peer = i ^ rank
buf = target_tensor_split[peer, start:start+size]
if peer == rank:
buf.copy_(sub_input_tensor)
continue
def send():
with torch.cuda.nvtx.range(f"send_{rank}->{peer}"):
if batch_p2p:
ops.append(dist.P2POp(dist.isend, sub_input_tensor, peer, group=pg))
else:
torch.distributed.isend(sub_input_tensor, peer, group=pg)
def recv():
with torch.cuda.nvtx.range(f"recv_{peer}->{rank}"):
if batch_p2p:
ops.append(dist.P2POp(dist.irecv, buf, peer, group=pg))
else:
handles.append(torch.distributed.irecv(buf, peer, group=pg))
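# The lower rank of each pair posts its send first and the higher rank its recv
# first, presumably to keep the issue order symmetric within the pair; this only
# matters on the non-batched (DISABLE_BATCH_P2P=1) path.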
if rank < peer:
send()
recv()
else:
recv()
send()
if batch_p2p:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
else:
for handle in handles:
handle.wait()


# for chunk in range(total_chunks):
# nvshmem_device_producer_all_gather_2d_get_block_kernel[grid](
# input_tensor,
@@ -420,6 +468,7 @@ def some_compute(x):
x = x @ compute_param
return x

sync_input_tensor = torch.empty(world_size, dtype=torch.int32, device="cuda")
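# Scratch tensor whose all_reduce below appears to serve as a lightweight
# per-iteration sync across ranks in the timing loop.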

src_tensors = [torch.empty(size, dtype=dtype, device="cuda") for _ in range(cnt)]
for i in range(cnt):
@@ -429,43 +478,45 @@ def some_compute(x):

# for all_gather_func in [all_gather_into_tensor, all_gather_into_tensor_nccl_comm, all_gather_into_tensor_nccl, all_gather_into_tensor_multinode]:
comp_stream = torch.cuda.Stream()
for all_gather_func in [all_gather_into_tensor_nccl, all_gather_into_tensor]:
for all_gather_func in [all_gather_into_tensor_nccl, all_gather_into_tensor_nccl_p2p, all_gather_into_tensor]:
with torch.cuda.nvtx.range(all_gather_func.__name__):
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(cnt)]
comm_events = [torch.cuda.Event(enable_timing=True) for _ in range(cnt)]
compute_events = [torch.cuda.Event(enable_timing=True) for _ in range(cnt)]
start = torch.cuda.Event(enable_timing=True)

dsts = [torch.empty(size * group_size, dtype=dtype, device="cuda") for _ in range(cnt)]
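# Destination buffers are pre-allocated before the loop, so allocation cost does
# not show up inside the timed region.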
for i in range(cnt):
torch.distributed.all_reduce(sync_input_tensor, group=group)
if i == 1:
torch.distributed.barrier(group)
start.record()
dst = torch.empty(size * group_size, dtype=dtype, device="cuda")
dst = dsts[i]
# dst_arr = [
# dst[r * size:(r + 1) * size]
# for r in range(world_size)
# ]
start_events[i].record()
comp_stream.wait_stream(torch.cuda.current_stream())
# comp_stream.wait_stream(torch.cuda.current_stream())
all_gather_func(dst, src_tensors[i], group)
with torch.cuda.stream(comp_stream):
some_compute(compute_buffer[0])
torch.cuda.current_stream().wait_stream(comp_stream)
# with torch.cuda.stream(comp_stream):
# some_compute(compute_buffer[0])
# torch.cuda.current_stream().wait_stream(comp_stream)
comm_events[i].record()
# compute_buffer[i] @ compute_param
compute_events[i].record()

# print(dst)
for r in range(group_size):
expected = group_ranks[r] * 2 + i
assert torch.eq(dst[r * size:(r + 1) * size], expected).all(), f"Rank {rank} cnt {i} r {r} dst: {dst[r * size:(r + 1) * size]}, expected: {expected}"
# assert torch.eq(dst[r * size:(r + 1) * size], expected).all(), f"Rank {rank} cnt {i} r {r} dst: {dst[r * size:(r + 1) * size]}, expected: {expected}"
end = torch.cuda.Event(enable_timing=True)
end.record()
dist.barrier()
torch.cuda.synchronize()
# print(f"Rank {rank} comm time: {[start_events[i].elapsed_time(comm_events[i]) for i in range(cnt)]}, compute time: {[comm_events[i].elapsed_time(compute_events[i]) for i in range(cnt)]}")
all_gather_payload = size * (group_size - 1) * dtype.itemsize
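# Reported figure is payload MiB per elapsed ms (roughly GB/s); the first
# iteration is treated as warm-up and excluded via the (cnt - 1) factor.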
print(f"Rank {rank} all_gather bw: {all_gather_payload / 1024 ** 2 * (cnt - 1) / start.elapsed_time(end)}")
print(f"Rank {rank} {all_gather_func.__name__} bw: {all_gather_payload / 1024 ** 2 * (cnt - 1) / start.elapsed_time(end)}")
print(f"Total time: {start.elapsed_time(end)}")
# print(f"Rank {rank} dst: {dst}")

reduce_scatter.py (103 changes: 93 additions & 10 deletions)
@@ -482,7 +482,6 @@ def __init__(self, accumulation_dtype=None):
self.input_buffer = {}
self.dispatched_tasks = 0
self.accumulation_dtype = accumulation_dtype
self.rank_streams = defaultdict(lambda: torch.cuda.Stream())
self.buffer_shape_cache = {}
self.buffer_splitter = BufferSplitter()

@@ -607,7 +606,7 @@ def reduce_scatter_accumulation(self, key, input_tensor, pg: dist.ProcessGroup):
buf = buffer[start:start+size]
signal_ptr.fill_(0)
need_accumulation = split_trans_buffer or start + size == output_size
assert world_size % 8 == 0
assert world_size % 8 == 0 or world_size < 8
grid_size = 8 if world_size == 32 else world_size
nvshmem_reduce_scatter_kernel[(grid_size, )](
input_tensor_ptr = input_tensor_symm_split.view(-1),
@@ -693,6 +692,79 @@ def stop(self):
torch.distributed.barrier()
torch.cuda.synchronize()

def reduce_scatter_accumulation_nccl_p2p(self, key, input_tensor, pg: dist.ProcessGroup):
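# Reduce-scatter via NCCL point-to-point ops: the input is split into one slice per
# destination rank, each slice is sent directly into the destination's per-peer
# staging buffer (peers visited in XOR order), and the receiving rank folds every
# staged slice into its accumulation tensor once the chunk's transfers complete.
# DISABLE_BATCH_P2P=1 switches from batch_isend_irecv to individual isend/irecv.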
output_tensor_shape = self.infer_output_shape(input_tensor, pg)
accum_dtype = self.accumulation_dtype if self.accumulation_dtype is not None else input_tensor.dtype
if key not in self.accumulation_indices:
self.register(key, output_tensor_shape, input_tensor.dtype, accum_dtype)

world_size = torch.distributed.get_world_size(pg)
local_buf_size = self.buffer_splitter.get_local_buffer_size(output_tensor_shape, world_size)
output_size = reduce(lambda x, y: x * y, output_tensor_shape)
assert local_buf_size <= output_size

acc = self.accumulations[self.accumulation_indices[key]]
rank = torch.distributed.get_rank(pg)
buffer = self.buffers[self.buffer_indices[key]][rank]
buffer_id = self.buffer_indices[key]

split_trans_buffer = buffer.numel() < output_size
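# If the staging buffer is smaller than the full output, every chunk is staged at
# offset 0, so each chunk must be accumulated before the next one overwrites it.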
input_tensor_split = input_tensor.view(-1).view(world_size, -1)

batch_p2p = os.environ.get("DISABLE_BATCH_P2P", "0") != "1"
ops = []
handles = defaultdict(list)
for start in range(0, output_size, local_buf_size):
size = min(local_buf_size, output_size - start)
need_accumulation = split_trans_buffer or start + size == output_size
assert world_size % 8 == 0 or world_size < 8
bufs = []
for i in range(world_size):
peer = i ^ rank
peer_buffer = self.buffers[buffer_id][peer]
assert local_buf_size <= peer_buffer.numel()
if split_trans_buffer:
buf = peer_buffer[:size]
else:
buf = peer_buffer[start:start+size]
bufs.append(buf)
send_data = input_tensor_split[peer, start:start+size]
if peer == rank:
buf.copy_(send_data)
continue
def send():
with torch.cuda.nvtx.range(f"send_{rank}->{peer}"):
if batch_p2p:
ops.append(dist.P2POp(dist.isend, send_data, peer, group=pg))
else:
torch.distributed.isend(send_data, peer, group=pg)
def recv():
with torch.cuda.nvtx.range(f"recv_{peer}->{rank}"):
if batch_p2p:
ops.append(dist.P2POp(dist.irecv, buf, peer, group=pg))
else:
handle = torch.distributed.irecv(buf, peer, group=pg)
handles[peer].append(handle)
if rank < peer:
send()
recv()
else:
recv()
send()
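# Once this chunk's transfers have been posted, drain the outstanding requests and
# add each peer's staged slice into the accumulator before the buffers are reused.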
if need_accumulation:
if batch_p2p:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
ops.clear()
for peer in range(world_size):
if peer != rank:
if not batch_p2p:
for handle in handles[peer]:
handle.wait()
handles[peer].clear()
acc[start:start+size].add_(bufs[peer])



if __name__ == "__main__":
@@ -793,10 +865,20 @@ def reduce_scatter_accumulation(src_tensor, dest_idx, pg: dist.ProcessGroup):
def reduce_scatter_accumulation_nccl_comm(src_tensor, dest_idx, pg: dist.ProcessGroup):
reduction_service.reduce_scatter_accumulation_nccl_comm(dest_idx, src_tensor, pg)

def reduce_scatter_accumulation_nccl_p2p(src_tensor, dest_idx, pg: dist.ProcessGroup):
reduction_service.reduce_scatter_accumulation_nccl_p2p(dest_idx, src_tensor, pg)

sync_input_tensor = torch.empty(world_size, dtype=torch.int32, device="cuda")

def wait_kernel_launch():
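# A few throwaway matmuls; presumably this lets already-queued kernel launches
# drain before the timed region starts.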
with torch.cuda.nvtx.range("wait_kernel_launch"):
for _ in range(5):
compute_buffer[0] @ compute_param

dist.barrier()
torch.cuda.synchronize()
comp_stream = torch.cuda.Stream()
for reduce_scatter_func in [reduce_scatter_accumulation_nccl, reduce_scatter_accumulation]:
for reduce_scatter_func in [reduce_scatter_accumulation_nccl, reduce_scatter_accumulation_nccl_p2p, reduce_scatter_accumulation]:
reduction_service.clear_accumulations()

with torch.cuda.nvtx.range(reduce_scatter_func.__name__):
@@ -807,21 +889,23 @@ def reduce_scatter_accumulation_nccl_comm(src_tensor, dest_idx, pg: dist.Process

for i in range(cnt * times):
dst_idx = i % cnt
torch.distributed.barrier(group)
# torch.distributed.barrier(group)
torch.distributed.all_reduce(sync_input_tensor, group=group)

if i == cnt:
wait_kernel_launch()
start.record()
# dst_arr = [
# dst[r * size:(r + 1) * size]
# for r in range(world_size)
# ]
start_events[i].record()
comp_stream.wait_stream(torch.cuda.current_stream())
# comp_stream.wait_stream(torch.cuda.current_stream())

reduce_scatter_func(data[i], dst_idx, group)
with torch.cuda.stream(comp_stream):
some_compute(compute_buffer[0])
torch.cuda.current_stream().wait_stream(comp_stream)
# with torch.cuda.stream(comp_stream):
# some_compute(compute_buffer[0])
# torch.cuda.current_stream().wait_stream(comp_stream)
comm_events[i].record()
# compute_buffer[dst_idx] @ compute_param
compute_events[i].record()
@@ -846,10 +930,9 @@ def reduce_scatter_accumulation_nccl_comm(src_tensor, dest_idx, pg: dist.Process
torch.cuda.synchronize()
# print(f"Rank {rank} comm time: {[start_events[i].elapsed_time(comm_events[i]) for i in range(cnt * times)]}, compute time: {[comm_events[i].elapsed_time(compute_events[i]) for i in range(cnt * times)]}")
reduce_scatter_payload = size // group_size * (group_size - 1) * data[0].dtype.itemsize
print(f"Rank {rank} reduce_scatter bw: {reduce_scatter_payload / 1024 ** 2 * (cnt * (times - 1)) / start.elapsed_time(end)}")
print(f"Rank {rank} {reduce_scatter_func.__name__} bw: {reduce_scatter_payload / 1024 ** 2 * (cnt * (times - 1)) / start.elapsed_time(end)}")
print(f"Rank {rank} Total time: {start.elapsed_time(end)}")
# print(f"Rank {rank} dst: {dst}")
# torch.cuda.current_stream().wait_stream(comp_stream)

reduction_service.stop()
