Hi everyone, thank you so much for your exceptional work!
I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:
import numpy as np
from neural_tangents import stax
from jax import jit
W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu())
layer_fn.append(stax.Dense(1, 1.0, 0.0))
_, _, kernel_fn = stax.serial(*layer_fn)
kernel_fn = jit(kernel_fn, static_argnames="get")
x = np.random.rand(100, 100)
print(kernel_fn(x, x, "ntk"))
The result is all NaNs:
[[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
...
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]]
With float64 precision enabled (see the note after the output below), the values no longer become NaN but instead blow up:
[[2.2293401e+18 9.3420067e+17 9.2034030e+17 ... 8.9008971e+17
9.6801663e+17 9.6436509e+17]
[9.3420067e+17 2.3730658e+18 9.4658846e+17 ... 9.6854199e+17
9.6182735e+17 9.9944418e+17]
[9.2034030e+17 9.4658846e+17 2.3106050e+18 ... 9.1702287e+17
9.5415269e+17 9.9692925e+17]
...
[8.9008971e+17 9.6854199e+17 9.1702300e+17 ... 2.2127619e+18
9.2056034e+17 1.0147568e+18]
[9.6801663e+17 9.6182728e+17 9.5415269e+17 ... 9.2056034e+17
2.3979914e+18 9.9505658e+17]
[9.6436488e+17 9.9944418e+17 9.9692925e+17 ... 1.0147568e+18
9.9505658e+17 2.4954969e+18]]
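For reference, "float64 precision" here refers to JAX's x64 mode, which can be enabled before any JAX computation runs:

# Enable 64-bit (double) precision in JAX; this must happen before any
# JAX arrays or jitted functions are created.
from jax import config
config.update("jax_enable_x64", True)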
What's interesting is that the behavior appears to depend more on the depth of the network than on how large the individual weight standard deviations are. If the order of the standard deviations is reversed (by uncommenting the line above), so that layer 1 has the largest value, roughly $w_{ij} \sim \mathcal{N}(0, 16^2)$, and so on, the results remain unchanged.
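My rough intuition (a back-of-the-envelope guess, not verified against the library internals) is that each Dense layer scales the kernel by roughly $\sigma_w^2$, so the final magnitude is governed by the product of all the squared standard deviations, and a product is unchanged under any reordering of its factors:

import numpy as np

# Squared W_std of the 15 hidden Dense layers built in the loop above
# (W_stds[0..14] == 1..15); the final Dense layer uses W_std = 1.0.
w2 = np.arange(1, 16, dtype=np.float64) ** 2

# The product is permutation-invariant, so reversing W_stds cannot change
# the overall scale of the kernel, only the intermediate per-layer values.
print(np.prod(w2))          # ~1.71e24  (== (15!)**2)
print(np.prod(w2[::-1]))    # identical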
Thank you in advance, and happy new year!