
API for disabling SPMD? #9578

@jameszianxuTT

Description

The side effects of use_spmd() do not seem reversible through any obvious APIs.

xla/torch_xla/runtime.py, lines 191 to 231 in 6b6ef5c:

def use_spmd(auto: Optional[bool] = False):
  """API to enable SPMD mode. This is a recommended way to enable SPMD.

  This forces SPMD mode if some tensors are already initialized on non-SPMD
  devices. This means that those tensors would be replicated across the devices.

  Args:
    auto (bool): Whether to enable the auto-sharding. Read
      https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
      for more detail
  """
  if os.environ.get("XLA_USE_SPMD") is not None:
    warnings.warn("XLA_USE_SPMD is being deprecated. "
                  "Use torch_xla.runtime.use_spmd() "
                  "without setting XLA_USE_SPMD env-var.")

  if torch_xla._XLAC._xla_get_spmd_config_is_locked(
  ) and not xu.check_env_flag("XLA_USE_SPMD"):
    warnings.warn(
        "Replicating tensors already initialized on non-virtual XLA device for SPMD "
        "to force SPMD mode. This is one-time overhead to setup, and to minimize such, "
        "please set SPMD mode before initializting tensors "
        "(i.e., call use_spmd() in the beginning of the program).")
    torch_xla._XLAC._xla_force_spmd_device()
    xm.wait_device_ops()

  # TODO(yeounoh) we can drop envvar in the future
  os.environ["XLA_USE_SPMD"] = "1"
  if auto:
    torch_xla._XLAC._xla_set_auto_sharding()
    os.environ["XLA_AUTO_SPMD"] = "1"

  if device_type() == 'NEURON':
    # In case of Neuron, reset the initialization environment to accommodate SPMD.
    try:
      from torch_neuronx.initialization import initialize
      initialize()
    except ImportError:
      pass

Is there some mechanism to do this?
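
For reference, a minimal sketch of the asymmetry (assuming torch_xla.runtime.is_spmd() is available to check the flag; the env-var cleanup at the end is only illustrative, not a supported way to leave SPMD mode):

import os
import torch_xla.runtime as xr

# Entering SPMD mode: sets XLA_USE_SPMD=1 and forces the virtual SPMD device.
xr.use_spmd()
print(xr.is_spmd())  # True

# Clearing the environment variables that use_spmd() sets is the only obvious
# "undo", but it does not revert the runtime-side state created by
# torch_xla._XLAC._xla_force_spmd_device() (nor the locked SPMD config checked
# via _xla_get_spmd_config_is_locked()), so already-initialized tensors stay
# on the virtual SPMD device.
os.environ.pop("XLA_USE_SPMD", None)
os.environ.pop("XLA_AUTO_SPMD", None)

Something like a disable_spmd() counterpart (name hypothetical) that also reverses the forced-device step would close this gap.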


    Labels

    distributed (SPMD and other distributed things), enhancement (New feature or request)
