Skip to content

Commit 79ef786

Browse files
authored
Adds torch.cuda.set_device calls to DDP examples (#1142)
Add set_device calls to DDP examples
1 parent 6a64939 commit 79ef786

File tree

4 files changed

+4
-0
lines changed

4 files changed

+4
-0
lines changed

distributed/ddp-tutorial-series/multigpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def ddp_setup(rank, world_size):
1919
os.environ["MASTER_ADDR"] = "localhost"
2020
os.environ["MASTER_PORT"] = "12355"
2121
init_process_group(backend="nccl", rank=rank, world_size=world_size)
22+
torch.cuda.set_device(rank)
2223

2324
class Trainer:
2425
def __init__(

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
def ddp_setup():
1414
init_process_group(backend="nccl")
15+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
1516

1617
class Trainer:
1718
def __init__(

distributed/ddp-tutorial-series/multinode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
def ddp_setup():
1414
init_process_group(backend="nccl")
15+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
1516

1617
class Trainer:
1718
def __init__(

distributed/minGPT-ddp/mingpt/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
def ddp_setup():
1010
init_process_group(backend="nccl")
11+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
1112

1213
def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
1314
dataset = CharDataset(data_cfg)

0 commit comments

Comments (0)