Status: Open
Labels: distributed (SPMD and other distributed things), enhancement (New feature or request)
Description
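The side effects of use_spmd() do not seem to be reversible through any obvious API. The function sets the XLA_USE_SPMD environment variable, forces the SPMD device if the SPMD config is already locked, and, when auto=True, also enables auto-sharding and sets XLA_AUTO_SPMD.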
Lines 191 to 231 in 6b6ef5c
```python
def use_spmd(auto: Optional[bool] = False):
  """API to enable SPMD mode. This is a recommended way to enable SPMD.

  This forces SPMD mode if some tensors are already initialized on non-SPMD
  devices. This means that those tensors would be replicated across the devices.

  Args:
    auto (bool): Whether to enable the auto-sharding. Read
      https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
      for more detail
  """
  if os.environ.get("XLA_USE_SPMD") is not None:
    warnings.warn("XLA_USE_SPMD is being deprecated. "
                  "Use torch_xla.runtime.use_spmd() "
                  "without setting XLA_USE_SPMD env-var.")

  if torch_xla._XLAC._xla_get_spmd_config_is_locked(
  ) and not xu.check_env_flag("XLA_USE_SPMD"):
    warnings.warn(
        "Replicating tensors already initialized on non-virtual XLA device for SPMD "
        "to force SPMD mode. This is one-time overhead to setup, and to minimize such, "
        "please set SPMD mode before initializting tensors "
        "(i.e., call use_spmd() in the beginning of the program).")
    torch_xla._XLAC._xla_force_spmd_device()
    xm.wait_device_ops()

  # TODO(yeounoh) we can drop envvar in the future
  os.environ["XLA_USE_SPMD"] = "1"
  if auto:
    torch_xla._XLAC._xla_set_auto_sharding()
    os.environ["XLA_AUTO_SPMD"] = "1"

  if device_type() == 'NEURON':
    # In case of Neuron, reset the initialization environment to accommodate SPMD.
    try:
      from torch_neuronx.initialization import initialize
      initialize()
    except ImportError:
      pass
```
Is there a mechanism to undo these side effects and return to non-SPMD mode?
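As a partial workaround (for example, in a test suite that runs some cases with SPMD and some without), the environment variables that use_spmd() writes can at least be snapshotted and restored. The sketch below is a hypothetical helper; preserve_spmd_env_vars is not part of torch_xla. It restores only XLA_USE_SPMD and XLA_AUTO_SPMD and does nothing about the runtime state changed through torch_xla._XLAC._xla_force_spmd_device() or _xla_set_auto_sharding(), which is the part this issue reports has no obvious reversal API.

```python
import contextlib
import os
from typing import Dict, Iterator, Optional

# Environment variables written by use_spmd(), per the excerpt in this issue.
SPMD_ENV_VARS = ("XLA_USE_SPMD", "XLA_AUTO_SPMD")


@contextlib.contextmanager
def preserve_spmd_env_vars() -> Iterator[None]:
  """Hypothetical helper: snapshot the SPMD env vars and restore them on exit.

  Only the Python-side environment is restored. Runtime state changed via
  torch_xla._XLAC (forced SPMD device, auto-sharding) is NOT reverted here.
  """
  saved: Dict[str, Optional[str]] = {
      name: os.environ.get(name) for name in SPMD_ENV_VARS
  }
  try:
    yield
  finally:
    for name, value in saved.items():
      if value is None:
        # The variable did not exist before; drop anything set inside the block.
        os.environ.pop(name, None)
      else:
        # Put back the exact value that was present before entering.
        os.environ[name] = value
```

In a test, the call would be wrapped as "with preserve_spmd_env_vars(): torch_xla.runtime.use_spmd()". The restore runs in a finally clause so it also happens when the body raises, but again, this covers only the environment variables, not the device or sharding state held by the runtime.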