Skip to content

Commit 9f736d7

Browse files
committed
update from bleepblop
1 parent 8c421ac commit 9f736d7

23 files changed

+5
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,4 @@ The functions $u$, $v$ must be monotonic in $\tau$ and $y_c$ respectively.
9393
- randomize $\tau$ sampling during training in jax instead of grid ...
9494
- regularization/calibration/conformal prediction
9595
- review the translation to jax and make sure arch is actually the same as tf
96+
- make sure data is pure and deterministic across envs

figs/jax/cdfloss_crps.png

6.32 KB
Loading

figs/jax/cdfloss_logistic.png

-129 Bytes
Loading

figs/jax/cdfloss_nox_crps.png

3.17 KB
Loading

figs/jax/cdfloss_nox_logistic.png

2.36 KB
Loading

figs/jax/p_crps.png

3.27 KB
Loading

figs/jax/p_logistic.png

7.69 KB
Loading

figs/jax/p_nox_crps.png

-2.02 KB
Loading

figs/jax/p_nox_logistic.png

-669 Bytes
Loading

figs/jax/q.png

-7.16 KB
Loading

figs/jax/q_nox.png

380 Bytes
Loading

figs/jax/qloss.png

-1.71 KB
Loading

figs/jax/qloss_nox.png

-226 Bytes
Loading

figs/tf/cdfloss.png

513 Bytes
Loading

figs/tf/cdfloss_nox.png

-29 Bytes
Loading

figs/tf/p.png

1.21 KB
Loading

figs/tf/p_nox.png

-3.18 KB
Loading

figs/tf/q.png

3.56 KB
Loading

figs/tf/q_nox.png

1.55 KB
Loading

figs/tf/qloss.png

651 Bytes
Loading

figs/tf/qloss_nox.png

-76 Bytes
Loading

quantile_regression_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def get_data(seed=1, m=250, n_x=1, n_tau=11, L=2):
1313
y ~ N(mu(x), sigma(x))
1414
"""
1515
random.seed(seed)
16+
np.random.seed(seed)
1617
x = (2 * np.random.rand(m, n_x).astype(np.float64) - 1) * 2
1718
i = np.argsort(x[:, 0])
1819
x = x[i] # to make plotting nicer

quantile_regression_jax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(self, tau_dims, x_dims, final_dims, *, key):
161161
# Tau network layers
162162
tau_keys = jax.random.split(keys[0], len(tau_dims) - 1)
163163
self.tau_layers = [
164-
Dense(
164+
NonNegDense(
165165
tau_dims[i],
166166
tau_dims[i + 1],
167167
key=tau_keys[i],
@@ -184,7 +184,8 @@ def __init__(self, tau_dims, x_dims, final_dims, *, key):
184184
final_dims[i],
185185
final_dims[i + 1],
186186
key=final_keys[i],
187-
activation=(jax.nn.tanh if i < len(final_dims) - 2 else None),
187+
activation=None,
188+
# activation=(jax.nn.tanh if i < len(final_dims) - 2 else None),
188189
)
189190
for i in range(len(final_dims) - 1)
190191
]

0 commit comments

Comments
 (0)