Skip to content

Commit 10a017e

Browse files
committed
New feature : Model to estimate when a intervention had effect
1 parent 6a9e3ad commit 10a017e

File tree

2 files changed

+114
-3
lines changed

2 files changed

+114
-3
lines changed

causalpy/pymc_models.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,114 @@ def fit(self, X, t, coords):
497497
)
498498
)
499499
return self.idata
500+
501+
502+
class InterventionTimeEstimator(PyMCModel):
503+
r"""
504+
Custom PyMC model to estimate the time an intervetnion took place.
505+
506+
defines the PyMC model :
507+
508+
.. math::
509+
\alpha &\sim \mathrm{Normal}(0, 1) \\
510+
\beta &\sim \mathrm{Normal}(0, 1) \\
511+
s(t) &= \gamma_{i(t)} \quad \textrm{with} \quad \gamma_{k \in [0, ..., n_{seasons}-1]} \sim \mathrm{Normal}(0, 1)\\
512+
base_{\mu}(t) &= \alpha + \beta \cdot t + s_t\\
513+
\\
514+
\tau &\sim \mathrm{Uniform}(0, 1) \\
515+
w(t) &= sigmoid(t-\tau) \\
516+
\\
517+
level &\sim \mathrm{Normal}(0, 1) \\
518+
trend &\sim \mathrm{Normal}(0, 1) \\
519+
A &\sim \mathrm{Normal}(0, 1) \\
520+
\lambda &\sim \mathrm{HalfNormal}(0, 1) \\
521+
impulse(t) &= A \cdot exp(-\lambda \cdot |t-\tau|) \\
522+
intervention(t) &= level + trend \cdot (t-\tau) + impulse_t\\
523+
\\
524+
\sigma &\sim \mathrm{Normal}(0, 1) \\
525+
\mu(t) &= base_{\mu}(t) + w(t) \cdot intervention(t) \\
526+
\\
527+
y(t) &\sim \mathrm{Normal}(\mu (t), \sigma)
528+
529+
Example
530+
--------
531+
>>> import causalpy as cp
532+
>>> import numpy as np
533+
>>> from causalpy.pymc_models import InterventionTimeEstimator
534+
>>> df = cp.load("its")
535+
>>> y = df["y"].values
536+
>>> t = df["t"].values
537+
>>> coords = {"sseasons" = range(12)} # The data is monthly
538+
>>> estimator = InterventionTimeEstimator()
539+
>>> # We are trying to capture an impulse in the number of death per month due to Covid.
540+
>>> estimator.fit(
541+
... t,
542+
... y,
543+
... coords,
544+
... effect=["impulse"])
545+
Inference data...
546+
"""
547+
548+
def build_model(self, t, y, coords, effect, span, grain_season):
549+
"""
550+
Defines the PyMC model
551+
552+
:param t: An array of values representing the time over which y is spread
553+
:param y: An array of values representing our outcome y
554+
:param coords: A dictionary with the coordinate names for our instruments
555+
"""
556+
557+
with self:
558+
self.add_coords(coords)
559+
560+
if span is None:
561+
span = (t.min(), t.max())
562+
563+
# --- Priors ---
564+
switchpoint = pm.Uniform("switchpoint", lower=span[0], upper=span[1])
565+
alpha = pm.Normal(name="alpha", mu=0, sigma=10)
566+
beta = pm.Normal(name="beta", mu=0, sigma=10)
567+
seasons = 0
568+
if "seasons" in coords and len(coords["seasons"]) > 0:
569+
season_idx = np.arange(len(y)) // grain_season % len(coords["seasons"])
570+
season_effect = pm.Normal("season", mu=0, sigma=1, dims="seasons")
571+
seasons = season_effect[season_idx]
572+
573+
# --- Intervention effect ---
574+
level = trend = impulse = 0
575+
576+
if "level" in effect:
577+
level = pm.Normal("level", mu=0, sigma=10)
578+
579+
if "trend" in effect:
580+
trend = pm.Normal("trend", mu=0, sigma=10)
581+
582+
if "impulse" in effect:
583+
impulse_amplitude = pm.Normal("impulse_amplitude", mu=0, sigma=1)
584+
decay_rate = pm.HalfNormal("decay_rate", sigma=1)
585+
impulse = impulse_amplitude * pm.math.exp(
586+
-decay_rate * abs(t - switchpoint)
587+
)
588+
589+
# --- Parameterization ---
590+
weight = pm.math.sigmoid(t - switchpoint)
591+
# Compute and store the modelled time series
592+
mu_ts = pm.Deterministic(name="mu_ts", var=alpha + beta * t + seasons)
593+
# Compute and store the modelled intervention effect
594+
mu_in = pm.Deterministic(
595+
name="mu_in", var=level + trend * (t - switchpoint) + impulse
596+
)
597+
# Compute and store the the sum of the intervention and the time series
598+
mu = pm.Deterministic("mu", mu_ts + weight * mu_in)
599+
600+
# --- Likelihood ---
601+
pm.Normal("y_hat", mu=mu, sigma=2, observed=y)
602+
603+
def fit(self, t, y, coords, effect=[], span=None, grain_season=1, n=1000):
604+
"""
605+
Draw samples from posterior distribution
606+
"""
607+
self.build_model(t, y, coords, effect, span, grain_season)
608+
with self:
609+
self.idata = pm.sample(n, **self.sample_kwargs)
610+
return self.idata

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)