@@ -1651,7 +1651,18 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
16511651 samples = d .sample (key = random .PRNGKey (0 ), sample_shape = (100 ,))
16521652 quantiles = random .uniform (random .PRNGKey (1 ), (100 ,) + d .shape ())
16531653 try :
1654- rtol = 2e-3 if jax_dist in (dist .Gamma , dist .LogNormal , dist .StudentT ) else 1e-5
1654+ rtol = (
1655+ 2e-3
1656+ if jax_dist
1657+ in (
1658+ _TruncatedCauchy ,
1659+ _TruncatedNormal ,
1660+ dist .Gamma ,
1661+ dist .LogNormal ,
1662+ dist .StudentT ,
1663+ )
1664+ else 1e-5
1665+ )
16551666 if d .shape () == () and not d .is_discrete :
16561667 assert_allclose (
16571668 jax .vmap (jax .grad (d .cdf ))(samples ),
@@ -3602,3 +3613,235 @@ def test_distribution_repr():
36023613 result = repr (dist .Wishart (7 , jnp .eye (5 )).expand ([3 , 4 ]).to_event (1 ))
36033614 assert "batch shape (3,)" in result
36043615 assert "event shape (4, 5, 5)"
3616+
3617+
3618+ @pytest .mark .parametrize (
3619+ "base_dist_class, base_params" ,
3620+ [
3621+ (dist .Normal , (0.0 , 1.0 )),
3622+ (dist .Normal , (2.0 , 0.5 )),
3623+ (dist .Cauchy , (0.0 , 1.0 )),
3624+ (dist .Laplace , (0.0 , 1.0 )),
3625+ (dist .Logistic , (0.0 , 1.0 )),
3626+ (dist .StudentT , (2.0 , 0.0 , 1.0 )),
3627+ ],
3628+ )
3629+ @pytest .mark .parametrize ("low" , [- 2.0 , - 1.0 , 0.0 ])
3630+ def test_left_truncated_cdf (base_dist_class , base_params , low ):
3631+ """Test CDF for left truncated distributions."""
3632+ base_dist = base_dist_class (* base_params )
3633+ truncated_dist = dist .LeftTruncatedDistribution (base_dist , low )
3634+
3635+ # Test points
3636+ test_values = jnp .array ([low - 1.0 , low , low + 0.5 , low + 1.0 , low + 2.0 ])
3637+
3638+ # Compute CDF
3639+ cdf_values = truncated_dist .cdf (test_values )
3640+
3641+ # Basic properties
3642+ assert cdf_values .shape == test_values .shape
3643+ assert jnp .all (cdf_values >= 0.0 )
3644+ assert jnp .all (cdf_values <= 1.0 )
3645+
3646+ # Values below truncation point should have CDF = 0
3647+ assert_allclose (cdf_values [0 ], 0.0 , atol = 1e-6 )
3648+
3649+ # CDF should be monotonically increasing
3650+ assert jnp .all (jnp .diff (cdf_values [1 :]) >= - 1e-6 ) # Allow small numerical errors
3651+
3652+ # Test consistency with icdf (inverse CDF)
3653+ quantiles = jnp .array ([0.1 , 0.3 , 0.5 , 0.7 , 0.9 ])
3654+ icdf_values = truncated_dist .icdf (quantiles )
3655+ recovered_quantiles = truncated_dist .cdf (icdf_values )
3656+ assert_allclose (recovered_quantiles , quantiles , atol = 1e-5 )
3657+
3658+
3659+ @pytest .mark .parametrize (
3660+ "base_dist_class, base_params" ,
3661+ [
3662+ (dist .Normal , (0.0 , 1.0 )),
3663+ (dist .Normal , (- 1.0 , 2.0 )),
3664+ (dist .Cauchy , (0.0 , 1.0 )),
3665+ (dist .Laplace , (0.0 , 1.0 )),
3666+ (dist .Logistic , (0.0 , 1.0 )),
3667+ (dist .StudentT , (2.0 , 0.0 , 1.0 )),
3668+ ],
3669+ )
3670+ @pytest .mark .parametrize ("high" , [0.0 , 1.0 , 2.0 ])
3671+ def test_right_truncated_cdf (base_dist_class , base_params , high ):
3672+ """Test CDF for right truncated distributions."""
3673+ base_dist = base_dist_class (* base_params )
3674+ truncated_dist = dist .RightTruncatedDistribution (base_dist , high )
3675+
3676+ # Test points
3677+ test_values = jnp .array ([high - 2.0 , high - 1.0 , high - 0.5 , high , high + 1.0 ])
3678+
3679+ # Compute CDF
3680+ cdf_values = truncated_dist .cdf (test_values )
3681+
3682+ # Basic properties
3683+ assert cdf_values .shape == test_values .shape
3684+ assert jnp .all (cdf_values >= 0.0 )
3685+ assert jnp .all (cdf_values <= 1.0 )
3686+
3687+ # Values above truncation point should have CDF = 1
3688+ assert_allclose (cdf_values [- 1 ], 1.0 , atol = 1e-6 )
3689+
3690+ # CDF should be monotonically increasing
3691+ assert jnp .all (jnp .diff (cdf_values [:- 1 ]) >= - 1e-6 ) # Allow small numerical errors
3692+
3693+ # Test consistency with icdf (inverse CDF)
3694+ quantiles = jnp .array ([0.1 , 0.3 , 0.5 , 0.7 , 0.9 ])
3695+ icdf_values = truncated_dist .icdf (quantiles )
3696+ recovered_quantiles = truncated_dist .cdf (icdf_values )
3697+ assert_allclose (recovered_quantiles , quantiles , atol = 1e-5 )
3698+
3699+
3700+ @pytest .mark .parametrize (
3701+ "base_dist_class, base_params" ,
3702+ [
3703+ (dist .Normal , (0.0 , 1.0 )),
3704+ (dist .Normal , (1.0 , 0.8 )),
3705+ (dist .Cauchy , (0.0 , 1.0 )),
3706+ (dist .Laplace , (0.0 , 1.0 )),
3707+ (dist .Logistic , (0.0 , 1.0 )),
3708+ (dist .StudentT , (2.0 , 0.0 , 1.0 )),
3709+ ],
3710+ )
3711+ @pytest .mark .parametrize ("low, high" , [(- 2.0 , 2.0 ), (- 1.0 , 1.0 ), (0.0 , 3.0 )])
3712+ def test_two_sided_truncated_cdf (base_dist_class , base_params , low , high ):
3713+ """Test CDF for two-sided truncated distributions."""
3714+ base_dist = base_dist_class (* base_params )
3715+ truncated_dist = dist .TwoSidedTruncatedDistribution (base_dist , low , high )
3716+
3717+ # Test points
3718+ test_values = jnp .array ([low - 1.0 , low , (low + high ) / 2 , high , high + 1.0 ])
3719+
3720+ # Compute CDF
3721+ cdf_values = truncated_dist .cdf (test_values )
3722+
3723+ # Basic properties
3724+ assert cdf_values .shape == test_values .shape
3725+ assert jnp .all (cdf_values >= 0.0 )
3726+ assert jnp .all (cdf_values <= 1.0 )
3727+
3728+ # Values below truncation point should have CDF = 0
3729+ assert_allclose (cdf_values [0 ], 0.0 , atol = 1e-6 )
3730+
3731+ # Values above truncation point should have CDF = 1
3732+ assert_allclose (cdf_values [- 1 ], 1.0 , atol = 1e-6 )
3733+
3734+ # CDF should be monotonically increasing
3735+ assert jnp .all (jnp .diff (cdf_values [1 :- 1 ]) >= - 1e-6 ) # Allow small numerical errors
3736+
3737+ # Test consistency with icdf (inverse CDF)
3738+ quantiles = jnp .array ([0.1 , 0.3 , 0.5 , 0.7 , 0.9 ])
3739+ icdf_values = truncated_dist .icdf (quantiles )
3740+ recovered_quantiles = truncated_dist .cdf (icdf_values )
3741+ assert_allclose (recovered_quantiles , quantiles , atol = 1e-5 )
3742+
3743+
3744+ @pytest .mark .parametrize ("loc, scale" , [(0.0 , 1.0 ), (2.0 , 0.5 ), (- 1.0 , 2.0 )])
3745+ @pytest .mark .parametrize (
3746+ "low, high" , [(- 2.0 , 2.0 ), (- 1.0 , 1.0 ), (0.0 , 3.0 ), (None , 2.0 ), (- 2.0 , None )]
3747+ )
3748+ def test_truncated_normal_cdf_scipy_consistency (loc , scale , low , high ):
3749+ """Test consistency with scipy truncated normal CDF."""
3750+ from jax .scipy .stats import truncnorm as jax_truncnorm
3751+
3752+ # Create truncated normal distribution
3753+ if low is None and high is None :
3754+ pytest .skip ("Cannot test when both bounds are None" )
3755+
3756+ if low is None :
3757+ truncated_dist = dist .RightTruncatedDistribution (dist .Normal (loc , scale ), high )
3758+ a = - jnp .inf
3759+ b = (high - loc ) / scale
3760+ elif high is None :
3761+ truncated_dist = dist .LeftTruncatedDistribution (dist .Normal (loc , scale ), low )
3762+ a = (low - loc ) / scale
3763+ b = jnp .inf
3764+ else :
3765+ truncated_dist = dist .TwoSidedTruncatedDistribution (
3766+ dist .Normal (loc , scale ), low , high
3767+ )
3768+ a = (low - loc ) / scale
3769+ b = (high - loc ) / scale
3770+
3771+ # Test values within the truncation range
3772+ if low is None :
3773+ test_values = jnp .linspace (high - 3 * scale , high - 0.1 * scale , 10 )
3774+ elif high is None :
3775+ test_values = jnp .linspace (low + 0.1 * scale , low + 3 * scale , 10 )
3776+ else :
3777+ test_values = jnp .linspace (
3778+ low + 0.1 * (high - low ), high - 0.1 * (high - low ), 10
3779+ )
3780+
3781+ # Compare CDFs
3782+ numpyro_cdf = truncated_dist .cdf (test_values )
3783+ jax_cdf = jax_truncnorm .cdf (test_values , a = a , b = b , loc = loc , scale = scale )
3784+
3785+ assert_allclose (numpyro_cdf , jax_cdf , rtol = 1e-5 , atol = 1e-6 )
3786+
3787+
3788+ def test_truncated_cdf_edge_cases ():
3789+ """Test edge cases for truncated distribution CDFs."""
3790+ base_dist = dist .Normal (0.0 , 1.0 )
3791+
3792+ # Test with extreme truncation points
3793+ left_truncated = dist .LeftTruncatedDistribution (base_dist , 5.0 ) # Far in the tail
3794+ right_truncated = dist .RightTruncatedDistribution (
3795+ base_dist , - 5.0
3796+ ) # Far in the tail
3797+ two_sided = dist .TwoSidedTruncatedDistribution (base_dist , - 0.1 , 0.1 ) # Very narrow
3798+
3799+ # Test that CDFs are well-behaved
3800+ test_values = jnp .array ([- 10.0 , 0.0 , 10.0 ])
3801+
3802+ left_cdf = left_truncated .cdf (test_values )
3803+ assert jnp .all (jnp .isfinite (left_cdf ))
3804+ assert jnp .all (left_cdf >= 0.0 ) and jnp .all (left_cdf <= 1.0 )
3805+
3806+ right_cdf = right_truncated .cdf (test_values )
3807+ assert jnp .all (jnp .isfinite (right_cdf ))
3808+ assert jnp .all (right_cdf >= 0.0 ) and jnp .all (right_cdf <= 1.0 )
3809+
3810+ two_sided_cdf = two_sided .cdf (test_values )
3811+ assert jnp .all (jnp .isfinite (two_sided_cdf ))
3812+ assert jnp .all (two_sided_cdf >= 0.0 ) and jnp .all (two_sided_cdf <= 1.0 )
3813+
3814+
3815+ @pytest .mark .parametrize ("batch_shape" , [(), (3 ,)])
3816+ def test_truncated_cdf_batch_shapes (batch_shape ):
3817+ """Test that CDF works correctly with batch shapes."""
3818+ if batch_shape == ():
3819+ loc = 0.0
3820+ scale = 1.0
3821+ low = - 1.0
3822+ high = 1.0
3823+ else :
3824+ loc = jnp .zeros (batch_shape )
3825+ scale = jnp .ones (batch_shape )
3826+ low = - jnp .ones (batch_shape )
3827+ high = jnp .ones (batch_shape )
3828+
3829+ base_dist = dist .Normal (loc , scale )
3830+ truncated_dist = dist .TwoSidedTruncatedDistribution (base_dist , low , high )
3831+
3832+ # Test with single value
3833+ value = 0.0
3834+ cdf_value = truncated_dist .cdf (value )
3835+ assert cdf_value .shape == batch_shape
3836+
3837+ # Test with multiple values - these should broadcast properly
3838+ if batch_shape == ():
3839+ values = jnp .array ([- 2.0 , 0.0 , 2.0 ])
3840+ cdf_values = truncated_dist .cdf (values )
3841+ expected_shape = values .shape
3842+ assert cdf_values .shape == expected_shape
3843+ else :
3844+ # For batched case, test with single values to avoid broadcasting issues
3845+ for value in [- 2.0 , 0.0 , 2.0 ]:
3846+ cdf_value = truncated_dist .cdf (value )
3847+ assert cdf_value .shape == batch_shape
0 commit comments