@@ -144,3 +144,42 @@ def model():
144144 render_model (model , filename = "graph.png" )
145145 assert os .path .exists ("graph.png" )
146146 os .remove ("graph.png" )
147+
148+
149+ def test_param_with_lambda_function ():
150+ """
151+ Test that get_model_relations works when params are initialized with lambda functions.
152+ Regression test for issue #2064.
153+ """
154+
155+ def guide ():
156+ numpyro .param ("p" , lambda _ : 1.0 )
157+
158+ # This should not raise a TypeError about lambda functions not being valid JAX types
159+ relations = get_model_relations (guide )
160+
161+ # Verify the param is captured correctly
162+ assert "p" in relations ["param_constraint" ]
163+ assert relations ["param_constraint" ]["p" ] == ""
164+ assert relations ["sample_sample" ] == {}
165+ assert relations ["sample_param" ] == {}
166+ assert relations ["sample_dist" ] == {}
167+ assert relations ["observed" ] == []
168+
169+
170+ def test_param_with_lambda_and_sample ():
171+ """
172+ Test that get_model_relations works with both params (lambda) and sample sites.
173+ """
174+
175+ def model ():
176+ p = numpyro .param ("p" , lambda _ : jnp .array ([0.5 , 0.5 ]))
177+ numpyro .sample ("x" , dist .Categorical (p ))
178+
179+ relations = get_model_relations (model )
180+
181+ # Verify both param and sample are captured
182+ assert "p" in relations ["param_constraint" ]
183+ assert "x" in relations ["sample_dist" ]
184+ assert relations ["sample_dist" ]["x" ] == "CategoricalProbs"
185+ assert "p" in relations ["sample_param" ]["x" ]
0 commit comments