Remove excessive warn message in maybe_get_jax as it creates too many log lines during training #9569

@rajkthakur

Description

🐛 Bug

The maybe_get_jax() function in torch_xla/_internal/jax_workarounds.py, merged in #9521, emits a warning every time it is called while JAX is not installed. The message is useful the first time, but repeated on every invocation it floods training logs and makes genuinely important debug messages hard to spot.
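To illustrate how a per-call warning scales with the number of calls, here is a small self-contained sketch. It does not use torch_xla: `maybe_get_module_noisy`, `CountingHandler`, `count_warnings` and the module name `_hypothetical_missing_pkg_9569` are hypothetical stand-ins for the reported pattern, not the actual `jax_workarounds.py` source.

```python
import importlib
import logging


def maybe_get_module_noisy(name):
    """Hypothetical stand-in (not the torch_xla source): try to import an
    optional dependency and log a warning on *every* failed call."""
    try:
        return importlib.import_module(name)
    except ImportError:
        logging.warning("You are trying to use a feature that requires %s.", name)
        return None


class CountingHandler(logging.Handler):
    """Logging handler that counts records instead of printing them."""

    def __init__(self):
        super().__init__()
        self.count = 0

    def emit(self, record):
        self.count += 1


def count_warnings(fn, calls):
    """Call `fn` `calls` times and return how many log records it produced."""
    handler = CountingHandler()
    root = logging.getLogger()
    root.addHandler(handler)
    try:
        for _ in range(calls):
            fn()
    finally:
        root.removeHandler(handler)
    return handler.count


if __name__ == "__main__":
    missing = "_hypothetical_missing_pkg_9569"  # assumed not to be installed
    n = count_warnings(lambda: maybe_get_module_noisy(missing), 1000)
    print(n)  # 1000: one log record per call
```

Because failed imports are not cached in `sys.modules`, every call retries the import and logs again, so a training loop that calls the helper once per step produces one warning line per step.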

To Reproduce

Steps to reproduce the behavior:

  1. Create a Python virtual environment (python3 -m venv ptxla_28) on Ubuntu 22.04.
  2. Install the packages: pip install torch==2.8.0 torchvision; pip install torch_xla==2.8.0
  3. Create a small Python script (let's call it trigger_warning.py):
     ```python
     import sys
     # Make the venv's packages importable (assumes Python 3.10 and running
     # from the directory that contains ptxla_28).
     sys.path.insert(0, 'ptxla_28/lib/python3.10/site-packages')
     from torch_xla._internal.jax_workarounds import maybe_get_jax
     maybe_get_jax()
     ```
  4. Execute the script: bash -c "source ptxla_28/bin/activate && python trigger_warning.py"
  5. You should see warning messages like the following:
     ```
     WARNING:root:Defaulting to PJRT_DEVICE=CPU
     WARNING:root:You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]
     ```

Expected behavior

The warning should be removed or suppressed, or limited to being displayed once per process rather than on every invocation.
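The "once per process" option could be sketched with `functools.lru_cache` wrapped around the logging call. This is a minimal sketch under assumptions, not a proposed patch: `maybe_get_module` and `_warn_once` are hypothetical helpers, and the real `maybe_get_jax()` may need to keep its existing import guards and environment handling.

```python
import functools
import importlib
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _warn_once(message):
    # lru_cache memoizes on `message`: the body (and so the warning) runs only
    # on the first call with a given message; repeats hit the cache silently.
    logger.warning(message)


def maybe_get_module(name, install_hint):
    """Hypothetical helper (not the torch_xla API): import an optional
    dependency, or return None and warn once per process per message."""
    try:
        return importlib.import_module(name)
    except ImportError:
        _warn_once(f"You are trying to use a feature that requires {name}. {install_hint}")
        return None


if __name__ == "__main__":
    jax = maybe_get_module(
        "jax", "You can install Jax/Pallas via pip install torch_xla[pallas]"
    )
```

The import is still attempted on every call, so the return value behaves as before; only the log output is deduplicated. An alternative is `warnings.warn(...)`, whose default filter action already shows a given warning only once per call site, at the cost of moving the message from `logging` to the `warnings` machinery.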

Environment

Additional context

The current behavior produces thousands of lines of repeated warnings in workloads that do not need JAX at all, which hurts the developer experience. Removing the warning, or emitting it only once, would clean up the logs of long-running and large-scale training jobs and improve usability without sacrificing relevant error reporting.
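Until the warning itself changes, a user-side stopgap is to filter that one message out of the logs. The reported lines carry the `WARNING:root:` prefix, which indicates they are logged through the root logger, so a `logging.Filter` attached to the root logger should see them. This is a sketch under that assumption; `DropMessageFilter` and `install_jax_warning_filter` are hypothetical names, and the prefix string is copied from the reported output.

```python
import logging


class DropMessageFilter(logging.Filter):
    """Drops log records whose formatted message starts with `prefix`."""

    def __init__(self, prefix):
        super().__init__()
        self.prefix = prefix

    def filter(self, record):
        # Returning False discards the record; getMessage() applies %-style args.
        return not record.getMessage().startswith(self.prefix)


def install_jax_warning_filter():
    # Assumes the warning is logged via the root logger ("WARNING:root:...").
    logging.getLogger().addFilter(
        DropMessageFilter("You are trying to use a feature that requires jax/pallas")
    )


if __name__ == "__main__":
    install_jax_warning_filter()  # call once, early in the training script
```

Logger-level filters are only consulted for records logged directly on that logger, not for records propagated from child loggers, so if a future version logs through a named logger, the filter would have to move to that logger or to a handler.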

Metadata

Labels

2.8 release, performance, usability (Bugs/features related to improving the usability of PyTorch/XLA)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
