-
Notifications
You must be signed in to change notification settings - Fork 20
Open
Description
I'm hitting an error with the JAX backend on a model that samples without error under the default backend. Switching to the JAX backend via:
pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)
results in the following panic in a nutpie thread:
thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[67], line 32
28 p = pm.Deterministic('p', pm.math.invlogit(logit_p), dims='obs')
30 y = pm.Binomial('y', n=pa, p=p, observed=hr)
---> 32 gp_covariate_trace = pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
804 raise ValueError(
805 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
806 )
808 with joined_blas_limiter():
--> 809 return _sample_external_nuts(
810 sampler=nuts_sampler,
811 draws=draws,
812 tune=tune,
813 chains=chains,
814 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
815 random_seed=random_seed,
816 initvals=initvals,
817 model=model,
818 var_names=var_names,
819 progressbar=progress_bool,
820 idata_kwargs=idata_kwargs,
821 compute_convergence_checks=compute_convergence_checks,
822 nuts_sampler_kwargs=nuts_sampler_kwargs,
823 **kwargs,
824 )
826 if exclusive_nuts and not provided_steps:
827 # Special path for NUTS initialization
828 if "nuts" in kwargs:
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
344 compiled_model = nutpie.compile_pymc_model(
345 model,
346 **compile_kwargs,
347 )
348 t_start = time.time()
--> 349 idata = nutpie.sample(
350 compiled_model,
351 draws=draws,
352 tune=tune,
353 chains=chains,
354 target_accept=target_accept,
355 seed=_get_seeds_per_chain(random_seed, 1)[0],
356 progress_bar=progressbar,
357 **nuts_sampler_kwargs,
358 )
359 t_sample = time.time() - t_start
360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:654, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
651 return sampler
653 try:
--> 654 result = sampler.wait()
655 except KeyboardInterrupt:
656 result = sampler.abort()
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
378 def wait(self, *, timeout=None):
379 """Wait until sampling is finished and return the trace.
380
381 KeyboardInterrupt will lead to interrupt the waiting.
(...)
386 This resumes the sampler in case it had been paused.
387 """
--> 388 self._sampler.wait(timeout)
389 results = self._sampler.extract_results()
390 return self._extract(results)
RuntimeError: All initialization points failed
Caused by:
Logp function returned error: PyError(PyErr { type: <class 'AttributeError'>, value: AttributeError("module 'jax.lax' has no attribute 'mul_without_zeros'"), traceback: Some(<traceback object at 0x7f813ad75080>) })
This was run in the following environment:
Python implementation: CPython
Python version : 3.12.8
IPython version : 8.32.0
numpy : 1.26.4
scipy : 1.12.0
pymc : 5.20.1
preliz : 0.15.0
nutpie : 0.14.2
pandas : 2.2.3
pytensor : 2.27.1
matplotlib: 3.10.0
plotly : 6.0.0
polars : 1.24.0
arviz : 0.20.0
Metadata
Metadata
Assignees
Labels
No labels