Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 6766bff

Browse files
committed
Copybara import of #157
PiperOrigin-RevId: 461633967
1 parent ff03ce9 commit 6766bff

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

neural_tangents/_src/batching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def batch(kernel_fn: _KernelFn,
110110
input_req = getattr(kernel_fn, 'input_req', {})
111111
dropout_in_analytic_kernel = input_req.get('use_dropout', False)
112112
use_multidevice = device_count > 0 or (device_count == -1 and
113-
jax.device_count() > 1)
113+
jax.local_device_count() > 1)
114114
use_serial = bool(batch_size)
115115
if use_multidevice:
116116
kernel_fn = _parallel(kernel_fn, use_serial,
@@ -522,7 +522,7 @@ def _parallel(kernel_fn: _KernelFn,
522522
"""
523523

524524
if device_count == -1:
525-
device_count = jax.device_count()
525+
device_count = jax.local_device_count()
526526

527527
def _check_dropout(n1, n2, kwargs):
528528
dropout_in_empirical_kernel = getattr(kwargs, 'rng', None) is not None
@@ -700,7 +700,7 @@ def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable:
700700
key = (f, device_count)
701701

702702
if device_count == -1:
703-
device_count = jax.device_count()
703+
device_count = jax.local_device_count()
704704

705705
# TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
706706
def broadcast(arg: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)