Skip to content

Commit cd363d3

Browse files
authored
patch param is a lambda function (#2078)
1 parent fb14da7 commit cd363d3

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

numpyro/infer/inspect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,9 @@ def get_trace():
320320
site["fn_name"] = _get_dist_name(site.pop("fn"))
321321
elif site["type"] == "deterministic":
322322
site["fn_name"] = "Deterministic"
323+
elif site["type"] == "param":
324+
# Remove lambda functions from param args to avoid jax.eval_shape issues
325+
site.pop("args", None)
323326
return PytreeTrace(trace)
324327

325328
# We use eval_shape to avoid any array computation.

test/test_model_rendering.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,42 @@ def model():
144144
render_model(model, filename="graph.png")
145145
assert os.path.exists("graph.png")
146146
os.remove("graph.png")
147+
148+
149+
def test_param_with_lambda_function():
150+
"""
151+
Test that get_model_relations works when params are initialized with lambda functions.
152+
Regression test for issue #2064.
153+
"""
154+
155+
def guide():
156+
numpyro.param("p", lambda _: 1.0)
157+
158+
# This should not raise a TypeError about lambda functions not being valid JAX types
159+
relations = get_model_relations(guide)
160+
161+
# Verify the param is captured correctly
162+
assert "p" in relations["param_constraint"]
163+
assert relations["param_constraint"]["p"] == ""
164+
assert relations["sample_sample"] == {}
165+
assert relations["sample_param"] == {}
166+
assert relations["sample_dist"] == {}
167+
assert relations["observed"] == []
168+
169+
170+
def test_param_with_lambda_and_sample():
171+
"""
172+
Test that get_model_relations works with both params (lambda) and sample sites.
173+
"""
174+
175+
def model():
176+
p = numpyro.param("p", lambda _: jnp.array([0.5, 0.5]))
177+
numpyro.sample("x", dist.Categorical(p))
178+
179+
relations = get_model_relations(model)
180+
181+
# Verify both param and sample are captured
182+
assert "p" in relations["param_constraint"]
183+
assert "x" in relations["sample_dist"]
184+
assert relations["sample_dist"]["x"] == "CategoricalProbs"
185+
assert "p" in relations["sample_param"]["x"]

0 commit comments

Comments
 (0)