Skip to content

Commit fb14da7

Browse files
authored
Add cdf method for truncated distributions (#2074)
* implementation init * rtol * improve tol for specific distribution
1 parent 4fa96d7 commit fb14da7

File tree

2 files changed

+293
-1
lines changed

2 files changed

+293
-1
lines changed

numpyro/distributions/truncated.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,20 @@ def icdf(self, q: ArrayLike) -> ArrayLike:
9090
)
9191
return jnp.where(q < 0, jnp.nan, ppf)
9292

93+
def cdf(self, value: ArrayLike) -> ArrayLike:
94+
# For left truncated distribution: CDF(x) = (F(x) - F(low)) / (1 - F(low))
95+
# where F is the base distribution CDF
96+
base_cdf_value = self.base_dist.cdf(value)
97+
base_cdf_low = self.base_dist.cdf(self.low)
98+
99+
# Handle the case where value < low (should be 0)
100+
# and value >= low (should be the truncated CDF)
101+
truncated_cdf = (base_cdf_value - base_cdf_low) / (1.0 - base_cdf_low)
102+
103+
# Clamp to [0, 1] and handle values below the truncation point
104+
result = jnp.where(value < self.low, 0.0, jnp.clip(truncated_cdf, 0.0, 1.0))
105+
return result
106+
93107
@validate_sample
94108
def log_prob(self, value: ArrayLike) -> ArrayLike:
95109
sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
@@ -169,6 +183,20 @@ def icdf(self, q: ArrayLike) -> ArrayLike:
169183
ppf = self.base_dist.icdf(q * self._cdf_at_high)
170184
return jnp.where(q > 1, jnp.nan, ppf)
171185

186+
def cdf(self, value: ArrayLike) -> ArrayLike:
187+
# For right truncated distribution: CDF(x) = F(x) / F(high)
188+
# where F is the base distribution CDF
189+
base_cdf_value = self.base_dist.cdf(value)
190+
base_cdf_high = self._cdf_at_high
191+
192+
# Handle the case where value > high (should be 1)
193+
# and value <= high (should be the truncated CDF)
194+
truncated_cdf = base_cdf_value / base_cdf_high
195+
196+
# Clamp to [0, 1] and handle values above the truncation point
197+
result = jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0))
198+
return result
199+
172200
@validate_sample
173201
def log_prob(self, value: ArrayLike) -> ArrayLike:
174202
return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high)
@@ -290,6 +318,27 @@ def icdf(self, q: ArrayLike) -> ArrayLike:
290318
)
291319
return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)
292320

321+
def cdf(self, value: ArrayLike) -> ArrayLike:
322+
# For two-sided truncated distribution: CDF(x) = (F(x) - F(low)) / (F(high) - F(low))
323+
# where F is the base distribution CDF
324+
base_cdf_value = self.base_dist.cdf(value)
325+
base_cdf_low = self.base_dist.cdf(self.low)
326+
base_cdf_high = self.base_dist.cdf(self.high)
327+
328+
# Calculate the normalization constant (F(high) - F(low))
329+
normalization = base_cdf_high - base_cdf_low
330+
331+
# Calculate the truncated CDF
332+
truncated_cdf = (base_cdf_value - base_cdf_low) / normalization
333+
334+
# Handle values outside the truncation interval
335+
result = jnp.where(
336+
value < self.low,
337+
0.0,
338+
jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0)),
339+
)
340+
return result
341+
293342
@validate_sample
294343
def log_prob(self, value: ArrayLike) -> ArrayLike:
295344
# NB: we use a more numerically stable formula for a symmetric base distribution

test/test_distributions.py

Lines changed: 244 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1651,7 +1651,18 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
16511651
samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
16521652
quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
16531653
try:
1654-
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.LogNormal, dist.StudentT) else 1e-5
1654+
rtol = (
1655+
2e-3
1656+
if jax_dist
1657+
in (
1658+
_TruncatedCauchy,
1659+
_TruncatedNormal,
1660+
dist.Gamma,
1661+
dist.LogNormal,
1662+
dist.StudentT,
1663+
)
1664+
else 1e-5
1665+
)
16551666
if d.shape() == () and not d.is_discrete:
16561667
assert_allclose(
16571668
jax.vmap(jax.grad(d.cdf))(samples),
@@ -3602,3 +3613,235 @@ def test_distribution_repr():
36023613
result = repr(dist.Wishart(7, jnp.eye(5)).expand([3, 4]).to_event(1))
36033614
assert "batch shape (3,)" in result
36043615
assert "event shape (4, 5, 5)"
3616+
3617+
3618+
@pytest.mark.parametrize(
3619+
"base_dist_class, base_params",
3620+
[
3621+
(dist.Normal, (0.0, 1.0)),
3622+
(dist.Normal, (2.0, 0.5)),
3623+
(dist.Cauchy, (0.0, 1.0)),
3624+
(dist.Laplace, (0.0, 1.0)),
3625+
(dist.Logistic, (0.0, 1.0)),
3626+
(dist.StudentT, (2.0, 0.0, 1.0)),
3627+
],
3628+
)
3629+
@pytest.mark.parametrize("low", [-2.0, -1.0, 0.0])
3630+
def test_left_truncated_cdf(base_dist_class, base_params, low):
3631+
"""Test CDF for left truncated distributions."""
3632+
base_dist = base_dist_class(*base_params)
3633+
truncated_dist = dist.LeftTruncatedDistribution(base_dist, low)
3634+
3635+
# Test points
3636+
test_values = jnp.array([low - 1.0, low, low + 0.5, low + 1.0, low + 2.0])
3637+
3638+
# Compute CDF
3639+
cdf_values = truncated_dist.cdf(test_values)
3640+
3641+
# Basic properties
3642+
assert cdf_values.shape == test_values.shape
3643+
assert jnp.all(cdf_values >= 0.0)
3644+
assert jnp.all(cdf_values <= 1.0)
3645+
3646+
# Values below truncation point should have CDF = 0
3647+
assert_allclose(cdf_values[0], 0.0, atol=1e-6)
3648+
3649+
# CDF should be monotonically increasing
3650+
assert jnp.all(jnp.diff(cdf_values[1:]) >= -1e-6) # Allow small numerical errors
3651+
3652+
# Test consistency with icdf (inverse CDF)
3653+
quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
3654+
icdf_values = truncated_dist.icdf(quantiles)
3655+
recovered_quantiles = truncated_dist.cdf(icdf_values)
3656+
assert_allclose(recovered_quantiles, quantiles, atol=1e-5)
3657+
3658+
3659+
@pytest.mark.parametrize(
3660+
"base_dist_class, base_params",
3661+
[
3662+
(dist.Normal, (0.0, 1.0)),
3663+
(dist.Normal, (-1.0, 2.0)),
3664+
(dist.Cauchy, (0.0, 1.0)),
3665+
(dist.Laplace, (0.0, 1.0)),
3666+
(dist.Logistic, (0.0, 1.0)),
3667+
(dist.StudentT, (2.0, 0.0, 1.0)),
3668+
],
3669+
)
3670+
@pytest.mark.parametrize("high", [0.0, 1.0, 2.0])
3671+
def test_right_truncated_cdf(base_dist_class, base_params, high):
3672+
"""Test CDF for right truncated distributions."""
3673+
base_dist = base_dist_class(*base_params)
3674+
truncated_dist = dist.RightTruncatedDistribution(base_dist, high)
3675+
3676+
# Test points
3677+
test_values = jnp.array([high - 2.0, high - 1.0, high - 0.5, high, high + 1.0])
3678+
3679+
# Compute CDF
3680+
cdf_values = truncated_dist.cdf(test_values)
3681+
3682+
# Basic properties
3683+
assert cdf_values.shape == test_values.shape
3684+
assert jnp.all(cdf_values >= 0.0)
3685+
assert jnp.all(cdf_values <= 1.0)
3686+
3687+
# Values above truncation point should have CDF = 1
3688+
assert_allclose(cdf_values[-1], 1.0, atol=1e-6)
3689+
3690+
# CDF should be monotonically increasing
3691+
assert jnp.all(jnp.diff(cdf_values[:-1]) >= -1e-6) # Allow small numerical errors
3692+
3693+
# Test consistency with icdf (inverse CDF)
3694+
quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
3695+
icdf_values = truncated_dist.icdf(quantiles)
3696+
recovered_quantiles = truncated_dist.cdf(icdf_values)
3697+
assert_allclose(recovered_quantiles, quantiles, atol=1e-5)
3698+
3699+
3700+
@pytest.mark.parametrize(
3701+
"base_dist_class, base_params",
3702+
[
3703+
(dist.Normal, (0.0, 1.0)),
3704+
(dist.Normal, (1.0, 0.8)),
3705+
(dist.Cauchy, (0.0, 1.0)),
3706+
(dist.Laplace, (0.0, 1.0)),
3707+
(dist.Logistic, (0.0, 1.0)),
3708+
(dist.StudentT, (2.0, 0.0, 1.0)),
3709+
],
3710+
)
3711+
@pytest.mark.parametrize("low, high", [(-2.0, 2.0), (-1.0, 1.0), (0.0, 3.0)])
3712+
def test_two_sided_truncated_cdf(base_dist_class, base_params, low, high):
3713+
"""Test CDF for two-sided truncated distributions."""
3714+
base_dist = base_dist_class(*base_params)
3715+
truncated_dist = dist.TwoSidedTruncatedDistribution(base_dist, low, high)
3716+
3717+
# Test points
3718+
test_values = jnp.array([low - 1.0, low, (low + high) / 2, high, high + 1.0])
3719+
3720+
# Compute CDF
3721+
cdf_values = truncated_dist.cdf(test_values)
3722+
3723+
# Basic properties
3724+
assert cdf_values.shape == test_values.shape
3725+
assert jnp.all(cdf_values >= 0.0)
3726+
assert jnp.all(cdf_values <= 1.0)
3727+
3728+
# Values below truncation point should have CDF = 0
3729+
assert_allclose(cdf_values[0], 0.0, atol=1e-6)
3730+
3731+
# Values above truncation point should have CDF = 1
3732+
assert_allclose(cdf_values[-1], 1.0, atol=1e-6)
3733+
3734+
# CDF should be monotonically increasing
3735+
assert jnp.all(jnp.diff(cdf_values[1:-1]) >= -1e-6) # Allow small numerical errors
3736+
3737+
# Test consistency with icdf (inverse CDF)
3738+
quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
3739+
icdf_values = truncated_dist.icdf(quantiles)
3740+
recovered_quantiles = truncated_dist.cdf(icdf_values)
3741+
assert_allclose(recovered_quantiles, quantiles, atol=1e-5)
3742+
3743+
3744+
@pytest.mark.parametrize("loc, scale", [(0.0, 1.0), (2.0, 0.5), (-1.0, 2.0)])
3745+
@pytest.mark.parametrize(
3746+
"low, high", [(-2.0, 2.0), (-1.0, 1.0), (0.0, 3.0), (None, 2.0), (-2.0, None)]
3747+
)
3748+
def test_truncated_normal_cdf_scipy_consistency(loc, scale, low, high):
3749+
"""Test consistency with scipy truncated normal CDF."""
3750+
from jax.scipy.stats import truncnorm as jax_truncnorm
3751+
3752+
# Create truncated normal distribution
3753+
if low is None and high is None:
3754+
pytest.skip("Cannot test when both bounds are None")
3755+
3756+
if low is None:
3757+
truncated_dist = dist.RightTruncatedDistribution(dist.Normal(loc, scale), high)
3758+
a = -jnp.inf
3759+
b = (high - loc) / scale
3760+
elif high is None:
3761+
truncated_dist = dist.LeftTruncatedDistribution(dist.Normal(loc, scale), low)
3762+
a = (low - loc) / scale
3763+
b = jnp.inf
3764+
else:
3765+
truncated_dist = dist.TwoSidedTruncatedDistribution(
3766+
dist.Normal(loc, scale), low, high
3767+
)
3768+
a = (low - loc) / scale
3769+
b = (high - loc) / scale
3770+
3771+
# Test values within the truncation range
3772+
if low is None:
3773+
test_values = jnp.linspace(high - 3 * scale, high - 0.1 * scale, 10)
3774+
elif high is None:
3775+
test_values = jnp.linspace(low + 0.1 * scale, low + 3 * scale, 10)
3776+
else:
3777+
test_values = jnp.linspace(
3778+
low + 0.1 * (high - low), high - 0.1 * (high - low), 10
3779+
)
3780+
3781+
# Compare CDFs
3782+
numpyro_cdf = truncated_dist.cdf(test_values)
3783+
jax_cdf = jax_truncnorm.cdf(test_values, a=a, b=b, loc=loc, scale=scale)
3784+
3785+
assert_allclose(numpyro_cdf, jax_cdf, rtol=1e-5, atol=1e-6)
3786+
3787+
3788+
def test_truncated_cdf_edge_cases():
3789+
"""Test edge cases for truncated distribution CDFs."""
3790+
base_dist = dist.Normal(0.0, 1.0)
3791+
3792+
# Test with extreme truncation points
3793+
left_truncated = dist.LeftTruncatedDistribution(base_dist, 5.0) # Far in the tail
3794+
right_truncated = dist.RightTruncatedDistribution(
3795+
base_dist, -5.0
3796+
) # Far in the tail
3797+
two_sided = dist.TwoSidedTruncatedDistribution(base_dist, -0.1, 0.1) # Very narrow
3798+
3799+
# Test that CDFs are well-behaved
3800+
test_values = jnp.array([-10.0, 0.0, 10.0])
3801+
3802+
left_cdf = left_truncated.cdf(test_values)
3803+
assert jnp.all(jnp.isfinite(left_cdf))
3804+
assert jnp.all(left_cdf >= 0.0) and jnp.all(left_cdf <= 1.0)
3805+
3806+
right_cdf = right_truncated.cdf(test_values)
3807+
assert jnp.all(jnp.isfinite(right_cdf))
3808+
assert jnp.all(right_cdf >= 0.0) and jnp.all(right_cdf <= 1.0)
3809+
3810+
two_sided_cdf = two_sided.cdf(test_values)
3811+
assert jnp.all(jnp.isfinite(two_sided_cdf))
3812+
assert jnp.all(two_sided_cdf >= 0.0) and jnp.all(two_sided_cdf <= 1.0)
3813+
3814+
3815+
@pytest.mark.parametrize("batch_shape", [(), (3,)])
3816+
def test_truncated_cdf_batch_shapes(batch_shape):
3817+
"""Test that CDF works correctly with batch shapes."""
3818+
if batch_shape == ():
3819+
loc = 0.0
3820+
scale = 1.0
3821+
low = -1.0
3822+
high = 1.0
3823+
else:
3824+
loc = jnp.zeros(batch_shape)
3825+
scale = jnp.ones(batch_shape)
3826+
low = -jnp.ones(batch_shape)
3827+
high = jnp.ones(batch_shape)
3828+
3829+
base_dist = dist.Normal(loc, scale)
3830+
truncated_dist = dist.TwoSidedTruncatedDistribution(base_dist, low, high)
3831+
3832+
# Test with single value
3833+
value = 0.0
3834+
cdf_value = truncated_dist.cdf(value)
3835+
assert cdf_value.shape == batch_shape
3836+
3837+
# Test with multiple values - these should broadcast properly
3838+
if batch_shape == ():
3839+
values = jnp.array([-2.0, 0.0, 2.0])
3840+
cdf_values = truncated_dist.cdf(values)
3841+
expected_shape = values.shape
3842+
assert cdf_values.shape == expected_shape
3843+
else:
3844+
# For batched case, test with single values to avoid broadcasting issues
3845+
for value in [-2.0, 0.0, 2.0]:
3846+
cdf_value = truncated_dist.cdf(value)
3847+
assert cdf_value.shape == batch_shape

0 commit comments

Comments
 (0)