
Commit 1db1c1c

mrshenli, brianjo, holly1238 authored
Add tutorial for ProcessGroup extensions (pytorch#1798)
Co-authored-by: Brian Johnson <[email protected]>
Co-authored-by: Holly Sweeney <[email protected]>
1 parent f769bed commit 1db1c1c

File tree

3 files changed: +299 -3 lines changed


index.rst

Lines changed: 11 additions & 3 deletions

@@ -55,7 +55,7 @@ Welcome to PyTorch Tutorials
    :image: _static/img/thumbnails/cropped/60-min-blitz.png
    :link: beginner/basics/intro.html
    :tags: Getting-Started
-
+
 .. customcarditem::
    :header: Introduction to PyTorch on YouTube
    :card_description: An introduction to building a complete ML workflow with PyTorch. Follows the PyTorch Beginner Series on YouTube.
@@ -120,7 +120,7 @@ Welcome to PyTorch Tutorials
    :image: _static/img/thumbnails/cropped/DCGAN-Tutorial.png
    :link: beginner/dcgan_faces_tutorial.html
    :tags: Image/Video
-
+
 .. customcarditem::
    :header: Spatial Transformer Networks Tutorial
    :card_description: Learn how to augment your network using a visual attention mechanism.
@@ -496,6 +496,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/dist_tuto.html
    :tags: Parallel-and-Distributed-Training

+.. customcarditem::
+   :header: Customize Process Group Backends Using Cpp Extensions
+   :card_description: Extend ProcessGroup with custom collective communication implementations.
+   :image: _static/img/thumbnails/cropped/Customize-Process-Group-Backends-Using-Cpp-Extensions.png
+   :link: intermediate/process_group_cpp_extension_tutorial.html
+   :tags: Parallel-and-Distributed-Training
+
 .. customcarditem::
    :header: Getting Started with Distributed RPC Framework
    :card_description: Learn how to build distributed training using the torch.distributed.rpc package.
@@ -646,7 +653,7 @@ Additional Resources
    beginner/basics/autogradqs_tutorial
    beginner/basics/optimization_tutorial
    beginner/basics/saveloadrun_tutorial
-
+
 .. toctree::
    :maxdepth: 2
    :hidden:
@@ -799,6 +806,7 @@ Additional Resources
    intermediate/model_parallel_tutorial
    intermediate/ddp_tutorial
    intermediate/dist_tuto
+   intermediate/process_group_cpp_extension_tutorial
    intermediate/rpc_tutorial
    intermediate/rpc_param_server_tutorial
    intermediate/dist_pipeline_parallel_tutorial
Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@

Customize Process Group Backends Using Cpp Extensions
=====================================================

**Author**: `Feng Tian <https://github.com/ftian1>`__, `Shen Li <https://mrshenli.github.io/>`__

Prerequisites:

- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
- `PyTorch Collective Communication Package <https://pytorch.org/docs/stable/distributed.html>`__
- `PyTorch Cpp Extension <https://pytorch.org/docs/stable/cpp_extension.html>`__
- `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__

This tutorial demonstrates how to implement a custom ``ProcessGroup``
backend and plug it into the
`PyTorch distributed package <https://pytorch.org/docs/stable/distributed.html>`__ using
`cpp extensions <https://pytorch.org/docs/stable/cpp_extension.html>`__. This is helpful when you need a specialized software
stack for your hardware, or when you would like to experiment with new
collective communication algorithms.

Basics
------

PyTorch collective communications power several widely adopted distributed
training features, including
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__,
`ZeroRedundancyOptimizer <https://pytorch.org/docs/stable/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer>`__, and
`FullyShardedDataParallel <https://github.com/pytorch/pytorch/blob/master/torch/distributed/_fsdp/fully_sharded_data_parallel.py>`__.
In order to make the same collective communication API work with
different communication backends, the distributed package abstracts collective
communication operations into a
`ProcessGroup <https://github.com/pytorch/pytorch/blob/release/1.10/torch/csrc/distributed/c10d/ProcessGroup.hpp>`__
class. Different backends can
then be implemented as subclasses of ``ProcessGroup`` using preferred
third-party libraries. PyTorch distributed comes with three default backends,
``ProcessGroupNCCL``, ``ProcessGroupGloo``, and ``ProcessGroupMPI``. However,
beyond these three backends, there are also other communication libraries
(e.g., `UCC <https://github.com/openucx/ucc>`__,
`OneCCL <https://github.com/oneapi-src/oneCCL>`__), different types of hardware
(e.g., `TPU <https://cloud.google.com/tpu>`__,
`Trainium <https://aws.amazon.com/machine-learning/trainium/>`__), and emerging
communication algorithms (e.g.,
`Herring <https://www.amazon.science/publications/herring-rethinking-the-parameter-server-at-scale-for-the-cloud>`__,
`Reduction Server <https://cloud.google.com/blog/topics/developers-practitioners/optimize-training-performance-reduction-server-vertex-ai>`__).
Therefore, the distributed package exposes extension APIs to allow customizing
collective communication backends.
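
To make this abstraction concrete, the minimal sketch below (using the builtin
Gloo backend) shows that application code only names a backend when creating the
default process group; every collective afterwards is dispatched through
``ProcessGroup``, which is what lets a custom backend be swapped in without
touching training code.

.. code-block:: python

    import torch
    import torch.distributed as dist

    # The backend is chosen once, by name, when the default process group is
    # created. Every collective afterwards goes through the ProcessGroup
    # abstraction, so swapping "gloo" for a custom backend name does not
    # change the code below.
    dist.init_process_group(
        "gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

    t = torch.ones(4)
    dist.all_reduce(t)  # dispatched to ProcessGroupGloo::allreduce
    print(t)            # with world_size=1, the tensor is unchanged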

The 4 steps below show how to implement a dummy ``ProcessGroup`` backend
and use it in Python application code. Please note that this tutorial focuses
on demonstrating the extension APIs, instead of developing a functioning
communication backend. Hence, the ``dummy`` backend just covers a subset of the
APIs (``all_reduce`` and ``all_gather``), and simply sets the values of tensors
to 0.

Step 1: Implement a Subclass of ``ProcessGroup``
------------------------------------------------

The first step is to implement a ``ProcessGroup`` subclass that overrides
target collective communication APIs and runs the custom communication algorithm.
The extension also needs to implement a ``ProcessGroup::Work`` subclass, which
serves as a future of communication results and allows asynchronous execution in
application code. If the extension uses third-party libraries, it can
include the headers and call into the library APIs from the ``ProcessGroupDummy``
subclass. The two code snippets below present the implementation of ``dummy.hpp`` and
``dummy.cpp``. See the `dummy collectives <https://github.com/mrshenli/dummy_collectives>`__
repository for the full implementation.

.. code-block:: cpp

    // file name: dummy.hpp
    #include <torch/python.h>

    #include <c10d/ProcessGroup.hpp>
    #include <c10d/Store.hpp>
    #include <c10d/Types.hpp>
    #include <c10d/Utils.hpp>

    #include <pybind11/chrono.h>

    namespace c10d {

    class ProcessGroupDummy : public ProcessGroup {
      public:

        class WorkDummy : public ProcessGroup::Work {
          public:
            WorkDummy(
                OpType opType,
                c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
                : ProcessGroup::Work(
                      -1, // rank, only used by recvAnySource, irrelevant in this demo
                      opType),
                  future_(std::move(future)) {}
            // There are several additional helper functions that need to be
            // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
            // for the full implementation.

          private:
            c10::intrusive_ptr<c10::ivalue::Future> future_;
        };

        ProcessGroupDummy(int rank, int size);

        c10::intrusive_ptr<ProcessGroup::Work> allgather(
            std::vector<std::vector<at::Tensor>>& outputTensors,
            std::vector<at::Tensor>& inputTensors,
            const AllgatherOptions& opts = AllgatherOptions()) override;

        c10::intrusive_ptr<ProcessGroup::Work> allreduce(
            std::vector<at::Tensor>& tensors,
            const AllreduceOptions& opts = AllreduceOptions()) override;

        // The collective communication APIs without a custom implementation
        // will error out if invoked by application code.
    };
    } // namespace c10d

.. code-block:: cpp

    // file name: dummy.cpp
    #include "dummy.hpp"

    namespace c10d {

    // This is a dummy allgather that sets all output tensors to zero
    // Modify the implementation to conduct real communication asynchronously
    c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allgather(
            std::vector<std::vector<at::Tensor>>& outputTensors,
            std::vector<at::Tensor>& inputTensors,
            const AllgatherOptions& /* unused */) {
        for (auto& outputTensorVec : outputTensors) {
            for (auto& outputTensor : outputTensorVec) {
                outputTensor.zero_();
            }
        }

        auto future = c10::make_intrusive<c10::ivalue::Future>(
            c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
        future->markCompleted(c10::IValue(outputTensors));
        return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
    }

    // This is a dummy allreduce that sets all output tensors to zero
    // Modify the implementation to conduct real communication asynchronously
    c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allreduce(
            std::vector<at::Tensor>& tensors,
            const AllreduceOptions& opts) {
        for (auto& tensor : tensors) {
            tensor.zero_();
        }

        auto future = c10::make_intrusive<c10::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()));
        future->markCompleted(c10::IValue(tensors));
        return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
    }
    } // namespace c10d
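
On the application side, the ``Work`` object returned by a collective acts as a
handle to the (future) result, which is what enables asynchronous execution. The
sketch below is a minimal illustration, assuming the ``dummy`` backend built later
in this tutorial is installed and that the remaining ``WorkDummy`` helpers
(``wait``, ``isCompleted``, and so on) are implemented as in the
`dummy collectives <https://github.com/mrshenli/dummy_collectives>`__ repository.

.. code-block:: python

    import torch
    import torch.distributed as dist

    import dummy_collectives  # registers the `dummy` backend (see Step 2)

    dist.init_process_group(
        "dummy", init_method="tcp://localhost:29500", rank=0, world_size=1)

    x = torch.ones(4)
    # async_op=True returns the Work object instead of blocking until the
    # collective finishes; WorkDummy marks its result completed immediately.
    work = dist.all_reduce(x, async_op=True)
    work.wait()
    print(x)  # the dummy allreduce sets the tensor to zero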

Step 2: Expose The Extension Python APIs
----------------------------------------

The backend constructors are called from the
`Python side <https://github.com/pytorch/pytorch/blob/v1.9.0/torch/distributed/distributed_c10d.py#L643-L650>`__,
so the extension also needs to expose the constructor APIs to Python. This can
be done by adding the following methods. In this example, ``store`` and
``timeout`` are ignored by the ``ProcessGroupDummy`` instantiation method, as
those are not used in this dummy implementation. However, real-world extensions
should consider using the ``store`` to perform rendezvous and supporting the
``timeout`` argument.

.. code-block:: cpp

    class ProcessGroupDummy : public ProcessGroup {
        static c10::intrusive_ptr<ProcessGroup> createProcessGroupDummy(
            const c10::intrusive_ptr<::c10d::Store>& store,
            int rank,
            int size,
            const std::chrono::duration<float>& timeout);

        static void ProcessGroupDummyConstructor() __attribute__((constructor)) {
            py::object module = py::module::import("torch.distributed");
            py::object register_backend =
                module.attr("Backend").attr("register_backend");
            // torch.distributed.Backend.register_backend will add `dummy` as a
            // new valid backend.
            register_backend("dummy", py::cpp_function(createProcessGroupDummy));
        }
    };

.. code-block:: cpp

    c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
            const c10::intrusive_ptr<::c10d::Store>& /* unused */,
            int rank,
            int size,
            const std::chrono::duration<float>& /* unused */) {
        return c10::make_intrusive<ProcessGroupDummy>(rank, size);
    }

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
    }
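
For reference, the sketch below (assuming the extension built in Step 3 is
installed) shows the Python-side call with those arguments supplied explicitly;
``init_process_group`` forwards the store, rank, world size, and timeout to the
creator registered above, and the dummy backend simply ignores the ``store`` and
``timeout`` it receives.

.. code-block:: python

    import datetime

    import torch.distributed as dist

    import dummy_collectives  # registers the `dummy` backend on import

    # An explicitly constructed rendezvous store. A real backend could use it
    # to exchange bootstrap information; the dummy backend ignores it.
    store = dist.TCPStore("localhost", 29500, world_size=1, is_master=True)

    dist.init_process_group(
        "dummy",
        store=store,
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=30),
    )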

Step 3: Build The Custom Extension
----------------------------------

Now, the extension source code files are ready. We can then use
`cpp extensions <https://pytorch.org/docs/stable/cpp_extension.html>`__
to build it. To do that, create a ``setup.py`` file that prepares the paths and
commands. Then call ``python setup.py install`` to install the extension.

If the extension depends on third-party libraries, you can also pass
``library_dirs`` and ``libraries`` to the cpp extension APIs, as sketched after
the ``setup.py`` example below. See the
`torch ucc <https://github.com/openucx/torch-ucc>`__
project as a real-world example.

.. code-block:: python

    # file name: setup.py
    import os
    import sys
    import torch
    from setuptools import setup
    from torch.utils import cpp_extension

    sources = ["src/dummy.cpp"]
    include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

    if torch.cuda.is_available():
        module = cpp_extension.CUDAExtension(
            name = "dummy_collectives",
            sources = sources,
            include_dirs = include_dirs,
        )
    else:
        module = cpp_extension.CppExtension(
            name = "dummy_collectives",
            sources = sources,
            include_dirs = include_dirs,
        )

    setup(
        name = "Dummy-Collectives",
        version = "0.0.1",
        ext_modules = [module],
        cmdclass={'build_ext': cpp_extension.BuildExtension}
    )
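
For example, a hypothetical extension that links against a prebuilt
communication library might extend the ``CppExtension`` call as sketched below;
the library name ``mycomm`` and the ``/opt/mycomm`` install prefix are
placeholders for illustration, not part of this tutorial's code.

.. code-block:: python

    # hypothetical variant for an extension that links a third-party
    # communication library installed under /opt/mycomm
    module = cpp_extension.CppExtension(
        name = "dummy_collectives",
        sources = ["src/dummy.cpp"],
        include_dirs = ["include/", "/opt/mycomm/include"],
        library_dirs = ["/opt/mycomm/lib"],  # where libmycomm.so lives
        libraries = ["mycomm"],              # link with -lmycomm
    )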

Step 4: Use The Extension in Application
----------------------------------------

After installation, you can conveniently use the ``dummy`` backend when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__
as if it were a builtin backend.

.. code-block:: python

    import os

    import torch
    # importing dummy_collectives makes torch.distributed recognize `dummy`
    # as a valid backend.
    import dummy_collectives

    import torch.distributed as dist

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    dist.init_process_group("dummy", rank=0, world_size=1)

    x = torch.ones(6)
    dist.all_reduce(x)
    y = x.cuda()
    dist.all_reduce(y)

    print(f"cpu allreduce: {x}")
    print(f"cuda allreduce: {y}")

    try:
        dist.broadcast(x, 0)
    except RuntimeError:
        print("got RuntimeError as broadcast is not implemented in Dummy ProcessGroup")
