Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 8863a8b

Browse files
committed
Use JAX type annotation for random keys. Fix pypi links to tests.
PiperOrigin-RevId: 430463106
1 parent 6d49d59 commit 8863a8b

3 files changed

Lines changed: 10 additions & 14 deletions

File tree

neural_tangents/_src/monte_carlo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from jax.tree_util import tree_map
3838
from jax.tree_util import tree_multimap
3939
from .utils import utils
40-
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PRNGKey, PyTree, VMapAxes
40+
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PyTree, VMapAxes
4141

4242

4343
def _sample_once_kernel_fn(kernel_fn: EmpiricalGetKernelFn,
@@ -52,7 +52,7 @@ def _sample_once_kernel_fn(kernel_fn: EmpiricalGetKernelFn,
5252
def kernel_fn_sample_once(
5353
x1: NTTree[np.ndarray],
5454
x2: Optional[NTTree[np.ndarray]],
55-
key: PRNGKey,
55+
key: random.KeyArray,
5656
get: Get,
5757
**apply_fn_kwargs):
5858
init_key, dropout_key = random.split(key, 2)
@@ -64,7 +64,7 @@ def kernel_fn_sample_once(
6464

6565
def _sample_many_kernel_fn(
6666
kernel_fn_sample_once,
67-
key: PRNGKey,
67+
key: random.KeyArray,
6868
n_samples: Set[int],
6969
get_generator: bool):
7070
def normalize(sample: PyTree, n: int) -> PyTree:
@@ -115,7 +115,7 @@ def get_sampled_kernel(
115115
def monte_carlo_kernel_fn(
116116
init_fn: InitFn,
117117
apply_fn: ApplyFn,
118-
key: PRNGKey,
118+
key: random.KeyArray,
119119
n_samples: Union[int, Iterable[int]],
120120
batch_size: int = 0,
121121
device_count: int = -1,

neural_tangents/_src/utils/typing.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
1818

1919
import jax.numpy as np
20+
from jax import random
2021
from .kernel import Kernel
2122
from typing_extensions import Protocol
2223

@@ -29,14 +30,6 @@
2930
PyTree = Any
3031

3132

32-
"""A type alias for PRNGKeys.
33-
34-
See https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey
35-
for details.
36-
"""
37-
PRNGKey = np.ndarray
38-
39-
4033
"""A type alias for axes specification.
4134
4235
Axes can be specified as integers (`axis=-1`) or sequences (`axis=(1, 3)`).
@@ -81,7 +74,7 @@ class InitFn(Protocol):
8174

8275
def __call__(
8376
self,
84-
rng: PRNGKey,
77+
rng: random.KeyArray,
8578
input_shape: Shapes,
8679
**kwargs
8780
) -> Tuple[Shapes, PyTree]:

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def _get_version() -> str:
8080
'Bug Tracker': 'https://github.com/google/neural-tangents/issues',
8181
'Release Notes': 'https://github.com/google/neural-tangents/releases',
8282
'PyPi': 'https://pypi.org/project/neural-tangents/',
83-
'Tests': 'https://travis-ci.org/github/google/neural-tangents',
83+
'Linux Tests': 'https://github.com/google/neural-tangents/actions/workflows/linux.yml',
84+
'macOS Tests': 'https://github.com/google/neural-tangents/actions/workflows/macos.yml',
85+
'Pytype': 'https://github.com/google/neural-tangents/actions/workflows/pytype.yml',
86+
'Coverage': 'https://app.codecov.io/gh/google/neural-tangents'
8487
},
8588
packages=setuptools.find_packages(exclude=('presentation',)),
8689
long_description=long_description,

0 commit comments

Comments
 (0)