@@ -497,3 +497,114 @@ def fit(self, X, t, coords):
497
497
)
498
498
)
499
499
return self .idata
500
+
501
+
502
+ class InterventionTimeEstimator (PyMCModel ):
503
+ r"""
504
+ Custom PyMC model to estimate the time an intervetnion took place.
505
+
506
+ defines the PyMC model :
507
+
508
+ .. math::
509
+ \alpha &\sim \mathrm{Normal}(0, 1) \\
510
+ \beta &\sim \mathrm{Normal}(0, 1) \\
511
+ s(t) &= \gamma_{i(t)} \quad \textrm{with} \quad \gamma_{k \in [0, ..., n_{seasons}-1]} \sim \mathrm{Normal}(0, 1)\\
512
+ base_{\mu}(t) &= \alpha + \beta \cdot t + s_t\\
513
+ \\
514
+ \tau &\sim \mathrm{Uniform}(0, 1) \\
515
+ w(t) &= sigmoid(t-\tau) \\
516
+ \\
517
+ level &\sim \mathrm{Normal}(0, 1) \\
518
+ trend &\sim \mathrm{Normal}(0, 1) \\
519
+ A &\sim \mathrm{Normal}(0, 1) \\
520
+ \lambda &\sim \mathrm{HalfNormal}(0, 1) \\
521
+ impulse(t) &= A \cdot exp(-\lambda \cdot |t-\tau|) \\
522
+ intervention(t) &= level + trend \cdot (t-\tau) + impulse_t\\
523
+ \\
524
+ \sigma &\sim \mathrm{Normal}(0, 1) \\
525
+ \mu(t) &= base_{\mu}(t) + w(t) \cdot intervention(t) \\
526
+ \\
527
+ y(t) &\sim \mathrm{Normal}(\mu (t), \sigma)
528
+
529
+ Example
530
+ --------
531
+ >>> import causalpy as cp
532
+ >>> import numpy as np
533
+ >>> from causalpy.pymc_models import InterventionTimeEstimator
534
+ >>> df = cp.load("its")
535
+ >>> y = df["y"].values
536
+ >>> t = df["t"].values
537
+ >>> coords = {"sseasons" = range(12)} # The data is monthly
538
+ >>> estimator = InterventionTimeEstimator()
539
+ >>> # We are trying to capture an impulse in the number of death per month due to Covid.
540
+ >>> estimator.fit(
541
+ ... t,
542
+ ... y,
543
+ ... coords,
544
+ ... effect=["impulse"])
545
+ Inference data...
546
+ """
547
+
548
+ def build_model (self , t , y , coords , effect , span , grain_season ):
549
+ """
550
+ Defines the PyMC model
551
+
552
+ :param t: An array of values representing the time over which y is spread
553
+ :param y: An array of values representing our outcome y
554
+ :param coords: A dictionary with the coordinate names for our instruments
555
+ """
556
+
557
+ with self :
558
+ self .add_coords (coords )
559
+
560
+ if span is None :
561
+ span = (t .min (), t .max ())
562
+
563
+ # --- Priors ---
564
+ switchpoint = pm .Uniform ("switchpoint" , lower = span [0 ], upper = span [1 ])
565
+ alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 10 )
566
+ beta = pm .Normal (name = "beta" , mu = 0 , sigma = 10 )
567
+ seasons = 0
568
+ if "seasons" in coords and len (coords ["seasons" ]) > 0 :
569
+ season_idx = np .arange (len (y )) // grain_season % len (coords ["seasons" ])
570
+ season_effect = pm .Normal ("season" , mu = 0 , sigma = 1 , dims = "seasons" )
571
+ seasons = season_effect [season_idx ]
572
+
573
+ # --- Intervention effect ---
574
+ level = trend = impulse = 0
575
+
576
+ if "level" in effect :
577
+ level = pm .Normal ("level" , mu = 0 , sigma = 10 )
578
+
579
+ if "trend" in effect :
580
+ trend = pm .Normal ("trend" , mu = 0 , sigma = 10 )
581
+
582
+ if "impulse" in effect :
583
+ impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = 0 , sigma = 1 )
584
+ decay_rate = pm .HalfNormal ("decay_rate" , sigma = 1 )
585
+ impulse = impulse_amplitude * pm .math .exp (
586
+ - decay_rate * abs (t - switchpoint )
587
+ )
588
+
589
+ # --- Parameterization ---
590
+ weight = pm .math .sigmoid (t - switchpoint )
591
+ # Compute and store the modelled time series
592
+ mu_ts = pm .Deterministic (name = "mu_ts" , var = alpha + beta * t + seasons )
593
+ # Compute and store the modelled intervention effect
594
+ mu_in = pm .Deterministic (
595
+ name = "mu_in" , var = level + trend * (t - switchpoint ) + impulse
596
+ )
597
+ # Compute and store the the sum of the intervention and the time series
598
+ mu = pm .Deterministic ("mu" , mu_ts + weight * mu_in )
599
+
600
+ # --- Likelihood ---
601
+ pm .Normal ("y_hat" , mu = mu , sigma = 2 , observed = y )
602
+
603
+ def fit (self , t , y , coords , effect = [], span = None , grain_season = 1 , n = 1000 ):
604
+ """
605
+ Draw samples from posterior distribution
606
+ """
607
+ self .build_model (t , y , coords , effect , span , grain_season )
608
+ with self :
609
+ self .idata = pm .sample (n , ** self .sample_kwargs )
610
+ return self .idata
0 commit comments