Skip to content

Commit b68cbbc

Browse files
committed
Fix BurrIII variance to match scipy nan/inf; update test to allow nan==nan in allclose
1 parent 15983fe commit b68cbbc

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

skpro/distributions/adapters/scipy/tests/test_scipy_adapters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_method_no_params(self, object_instance, method, scipy_method):
8585

8686
scipy_res = getattr(scipy_obj, scipy_method)(*params[0], **params[1])
8787

88-
assert np.allclose(res, scipy_res)
88+
assert np.allclose(res, scipy_res, equal_nan=True)
8989

9090
@pytest.mark.parametrize("method,scipy_method", METHOD_TESTS["X_PARAMS"])
9191
@pytest.mark.parametrize("x", X_VALUES)

skpro/distributions/burr_iii.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,18 @@ def _mean(self):
6060
scale = self._bc_params["scale"]
6161
return burr12.mean(c, 1, scale=scale)
6262

63+
6364
def _var(self):
6465
c = self._bc_params["c"]
6566
scale = self._bc_params["scale"]
66-
v = burr12.var(c, 1, scale=scale)
67-
# If variance is nan, return np.inf to pass assert res >= 0
6867
import numpy as np
69-
70-
return v if np.isfinite(v) and v >= 0 else np.inf
68+
v = burr12.var(c, 1, scale=scale)
69+
# Match scipy: if nan, return nan; if negative/overflow, return inf
70+
if np.isnan(v):
71+
return np.nan
72+
if not np.isfinite(v) or v < 0:
73+
return np.inf
74+
return v
7175

7276
@classmethod
7377
def get_test_params(cls, parameter_set="default"):

0 commit comments

Comments
 (0)