This repository was archived by the owner on May 6, 2025. It is now read-only.
Releases: google/neural-tangents
v0.6.5
Maintenance release:
- refactoring and minor improvements
- Support and require
v0.6.4
Improvements:
- Support Python 3.11
- Use modern generic type annotations
- Various bugfixes, compatibility and documentation improvements
Breaking changes:
v0.6.2
New features:
- nt.stax.repeat layer allowing fast compilation of very deep networks (see #168 and thanks @jglaser!)
- Add a Colab notebook accompanying Precise Learning Curves and Higher-Order Scaling Limits for Dot Product Kernel Regression
Improvements:
Breaking changes:
v0.6.1
New features:
- nt.stax:
- nt.empirical:
  - An efficient NTK-vector product function nt.empirical_ntk_vp_fn (without instantiating the NTK).
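The idea behind an NTK-vector product can be illustrated without the library. The sketch below is a toy under stated assumptions: a hypothetical two-parameter model f(theta, x) = theta[0] * tanh(theta[1] * x), with hand-written derivatives standing in for automatic differentiation. It does not call nt.empirical_ntk_vp_fn. Because the empirical NTK equals J J^T for the Jacobian J of the outputs with respect to the parameters, its product with a vector v can be formed as J (J^T v), so the n-by-n kernel is never built.

```python
import math

# Hypothetical toy model (not part of neural-tangents):
# f(theta, x) = theta[0] * tanh(theta[1] * x), evaluated at each input x.

def jacobian_row(theta, x):
    """Derivatives of f(theta, x) with respect to theta[0] and theta[1]."""
    a, b = theta
    t = math.tanh(b * x)
    return [t, a * x * (1.0 - t * t)]

def vjp(theta, xs, v):
    """J^T v: a vector of parameter size, accumulated input by input."""
    out = [0.0, 0.0]
    for x, v_i in zip(xs, v):
        row = jacobian_row(theta, x)
        out[0] += row[0] * v_i
        out[1] += row[1] * v_i
    return out

def jvp(theta, xs, u):
    """J u: a vector of output size."""
    return [sum(r * u_k for r, u_k in zip(jacobian_row(theta, x), u)) for x in xs]

def ntk_vp(theta, xs, v):
    """(J J^T) v computed as J (J^T v); the n-by-n NTK is never stored."""
    return jvp(theta, xs, vjp(theta, xs, v))

def ntk_dense(theta, xs):
    """Explicit NTK, built here only to check the matrix-free result."""
    rows = [jacobian_row(theta, x) for x in xs]
    return [[sum(p * q for p, q in zip(r1, r2)) for r2 in rows] for r1 in rows]

theta = [0.7, -1.3]
xs = [-1.0, 0.2, 0.5, 2.0]
v = [1.0, -2.0, 0.5, 3.0]

fast = ntk_vp(theta, xs, v)
dense = [sum(k * v_j for k, v_j in zip(row, v)) for row in ntk_dense(theta, xs)]
assert all(abs(f - d) < 1e-9 for f, d in zip(fast, dense))
```

The matrix-free route costs two passes over the Jacobian rows, while the dense route needs storage quadratic in the number of inputs. That difference is what makes a dedicated NTK-vector product attractive for large batches.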
Improvements:
v0.6.0
New features:
- nt.empirical:
  - New implementation=3 for nt.empirical, which can often speed up the empirical NTK or reduce its memory use by orders of magnitude. Please see our ICML2022 paper Fast Finite Width Neural Tangent Kernel, new empirical NTK examples, and visit us on Thursday at ICML in-person!
  - New experimental prototype of using our empirical NTK implementations in TensorFlow via nt.experimental.empirical_ntk_fn_tf.
  - Make nt.empirical work with arbitrary pytrees.
- nt.stax:
Improvements:
- Slightly lower memory usage in batching.
- Many improvements to documentation and type annotations.
- Simplify test specifications and avoid relying on JAX testing utilities.
Bugfixes:
Breaking changes:
v0.5.0
Potentially breaking changes:
- Significant internal refactoring, notably splitting stax into multiple sub-modules, and moving implementations into an _src folder. This could break your code if you use internal functions like nt.utils.typing, nt.utils.utils, nt.utils.Kernel etc. (public API will remain unchanged). This should be easily fixed by updating the imports, e.g. nt.utils -> nt._src.utils.
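Code that has to run against both the old and the new layout could try the candidate import paths in order. The following is a minimal sketch, not something shipped with the library: import_first is a hypothetical helper, and the demonstration uses standard-library module names so that it runs without neural-tangents installed.

```python
import importlib

def import_first(*module_names):
    """Return the first module in module_names that can be imported."""
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    raise ImportError(f"none of {module_names} could be imported")

# Demonstration with standard-library names so the sketch runs anywhere:
# the first name does not exist, so the helper falls back to the second.
mod = import_first("json_module_that_does_not_exist", "json")
assert mod.__name__ == "json"
```

Applied to this release, a call might look like import_first("neural_tangents._src.utils.typing", "neural_tangents.utils.typing"), following the nt.utils -> nt._src.utils mapping described above. The exact submodule paths are an assumption and depend on the installed version.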
New features:
v0.4.0
WARNING:
Our next major release (v0.5.0) will include significant refactoring, and could break your code if you use internal functions like nt.utils.typing, nt.utils.utils, nt.utils.Kernel etc. (public API will remain unchanged). This should be easily fixed by updating the imports, e.g. nt.utils -> nt._src.utils.
This release (v0.4.0):
New feature:
Improvements:
- Various internal refactoring and tighter tests.
Bugfixes:
- Fix values and gradients of non-differentiable kernel_fn at zero inputs to be consistent with finite-width kernels, and how JAX defines gradients of non-differentiable functions to be the mean sub-gradient, see also #123.
- Fix wrong treatment of b_std=None in the infinite-width limit with parameterization='standard', see also #123.
- Fix a bug in nt.batch when x2 = None and inputs are PyTrees.
Breaking changes:
- Bump requirements to jax==0.3 and frozendict==2.3.
v0.3.9
v0.3.8
New Features:
- stax.Elementwise - a layer for generic elementwise functions requiring the user to specify only scalar-valued nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]. The NTK computation (thanks to @SiuMath) and vectorization over the underlying Kernel happen automatically under the hood. If you can't derive the nngp_fn for your function, use stax.ElementwiseNumerical. See docs for more details.
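To make the nngp_fn signature concrete, here is a sketch for fn = sin that uses only the Python standard library and does not call stax.Elementwise. It assumes the expectation is taken over a zero-mean jointly Gaussian pair (x_1, x_2) with variances var1, var2 and covariance cov12. Under that assumption E[sin(x_1) * sin(x_2)] = exp(-(var1 + var2) / 2) * sinh(cov12). The helper monte_carlo_nngp is a hypothetical name used only to check the closed form by sampling.

```python
import math
import random

def nngp_fn_sin(cov12, var1, var2):
    """Closed form of E[sin(x_1) * sin(x_2)] for a zero-mean Gaussian pair."""
    return math.exp(-(var1 + var2) / 2.0) * math.sinh(cov12)

def monte_carlo_nngp(fn, cov12, var1, var2, n_samples=200_000, seed=0):
    """Estimate E[fn(x_1) * fn(x_2)] by sampling the correlated Gaussian pair."""
    rng = random.Random(seed)
    coef = cov12 / math.sqrt(var1)          # makes Cov(x_1, x_2) equal cov12
    resid = math.sqrt(var2 - coef * coef)   # makes Var(x_2) equal var2
    total = 0.0
    for _ in range(n_samples):
        z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x1 = math.sqrt(var1) * z1
        x2 = coef * z1 + resid * z2
        total += fn(x1) * fn(x2)
    return total / n_samples

cov12, var1, var2 = 0.3, 1.0, 0.8
exact = nngp_fn_sin(cov12, var1, var2)
estimate = monte_carlo_nngp(math.sin, cov12, var1, var2)
assert abs(exact - estimate) < 0.01
```

When no closed form like this is available, the numerical route mentioned above (stax.ElementwiseNumerical) is the alternative the release notes point to.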
Bugfixes:
- Compatibility with JAX 0.2.21.
Full Changelog: v0.3.7...v0.3.8
v0.3.7
New Features:
- nt.stax.Cos
- nt.stax.ImageResize
- New implementation, implementation="SPARSE", in nt.stax.Aggregate for efficient handling of sparse graphs (see #86, #9)
- Support approximate=True in nt.stax.Gelu
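As a rough illustration of what an approximate GELU means, the sketch below compares the exact erf-based GELU with the common tanh-based approximation. It assumes approximate=True refers to that tanh form, as the same-named flag of jax.nn.gelu does. It does not call nt.stax.Gelu, and the function names are illustrative.

```python
import math

def gelu_exact(x):
    """GELU defined through the Gaussian CDF: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """Widely used tanh-based approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

# Compare the two on a grid covering [-6, 6] in steps of 0.01.
grid = [i / 100.0 for i in range(-600, 601)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
assert max_gap < 5e-3
```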
Bugfixes:
- Fix a bug that might alter Kernel requirements
- Fix nt.batch handling of diagonal_axes (see #87)
- Remove the frequent but redundant warning about type conversion in kernel_fn
- Minor fixes to documentation and code clean-up
Breaking changes: