Skip to content

Commit 0449370

Browse files
authored
Merge pull request #494 from pymc-labs/multi-cell-geolift
Enable multiple treated units in synthetic control quasi experiments
2 parents 185a03e + b213b0f commit 0449370

19 files changed

+1513
-545
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,18 @@ def __init__(
114114
},
115115
)
116116
self.y = xr.DataArray(
117-
self.y[:, 0],
118-
dims=["obs_ind"],
119-
coords={"obs_ind": np.arange(self.y.shape[0])},
117+
self.y,
118+
dims=["obs_ind", "treated_units"],
119+
coords={"obs_ind": np.arange(self.y.shape[0]), "treated_units": ["unit_0"]},
120120
)
121121

122122
# fit model
123123
if isinstance(self.model, PyMCModel):
124-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
124+
COORDS = {
125+
"coeffs": self.labels,
126+
"obs_ind": np.arange(self.X.shape[0]),
127+
"treated_units": ["unit_0"],
128+
}
125129
self.model.fit(X=self.X, y=self.y, coords=COORDS)
126130
elif isinstance(self.model, RegressorMixin):
127131
self.model.fit(X=self.X, y=self.y)
@@ -203,7 +207,7 @@ def __init__(
203207
# TODO: CHECK FOR CORRECTNESS
204208
self.causal_impact = (
205209
self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
206-
)
210+
).item()
207211
else:
208212
raise ValueError("Model type not recognized")
209213

@@ -321,7 +325,7 @@ def _plot_causal_impact_arrow(results, ax):
321325
time_points = self.x_pred_control[self.time_variable_name].values
322326
h_line, h_patch = plot_xY(
323327
time_points,
324-
self.y_pred_control.posterior_predictive.mu,
328+
self.y_pred_control["posterior_predictive"].mu.isel(treated_units=0),
325329
ax=ax,
326330
plot_hdi_kwargs={"color": "C0"},
327331
label="Control group",
@@ -333,7 +337,7 @@ def _plot_causal_impact_arrow(results, ax):
333337
time_points = self.x_pred_control[self.time_variable_name].values
334338
h_line, h_patch = plot_xY(
335339
time_points,
336-
self.y_pred_treatment.posterior_predictive.mu,
340+
self.y_pred_treatment["posterior_predictive"].mu.isel(treated_units=0),
337341
ax=ax,
338342
plot_hdi_kwargs={"color": "C1"},
339343
label="Treatment group",
@@ -345,12 +349,20 @@ def _plot_causal_impact_arrow(results, ax):
345349
# had occurred.
346350
time_points = self.x_pred_counterfactual[self.time_variable_name].values
347351
if len(time_points) == 1:
352+
y_pred_cf = az.extract(
353+
self.y_pred_counterfactual,
354+
group="posterior_predictive",
355+
var_names="mu",
356+
)
357+
# Select single unit data for plotting
358+
y_pred_cf_single = y_pred_cf.isel(treated_units=0)
359+
violin_data = (
360+
y_pred_cf_single.values
361+
if hasattr(y_pred_cf_single, "values")
362+
else y_pred_cf_single
363+
)
348364
parts = ax.violinplot(
349-
az.extract(
350-
self.y_pred_counterfactual,
351-
group="posterior_predictive",
352-
var_names="mu",
353-
).values.T,
365+
violin_data.T,
354366
positions=self.x_pred_counterfactual[self.time_variable_name].values,
355367
showmeans=False,
356368
showmedians=False,
@@ -363,7 +375,9 @@ def _plot_causal_impact_arrow(results, ax):
363375
else:
364376
h_line, h_patch = plot_xY(
365377
time_points,
366-
self.y_pred_counterfactual.posterior_predictive.mu,
378+
self.y_pred_counterfactual.posterior_predictive.mu.isel(
379+
treated_units=0
380+
),
367381
ax=ax,
368382
plot_hdi_kwargs={"color": "C2"},
369383
label="Counterfactual",

causalpy/experiments/interrupted_time_series.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def __init__(
120120
},
121121
)
122122
self.pre_y = xr.DataArray(
123-
self.pre_y[:, 0],
124-
dims=["obs_ind"],
125-
coords={"obs_ind": self.datapre.index},
123+
self.pre_y, # Keep 2D shape
124+
dims=["obs_ind", "treated_units"],
125+
coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]},
126126
)
127127
self.post_X = xr.DataArray(
128128
self.post_X,
@@ -133,17 +133,22 @@ def __init__(
133133
},
134134
)
135135
self.post_y = xr.DataArray(
136-
self.post_y[:, 0],
137-
dims=["obs_ind"],
138-
coords={"obs_ind": self.datapost.index},
136+
self.post_y, # Keep 2D shape
137+
dims=["obs_ind", "treated_units"],
138+
coords={"obs_ind": self.datapost.index, "treated_units": ["unit_0"]},
139139
)
140140

141141
# fit the model to the observed (pre-intervention) data
142142
if isinstance(self.model, PyMCModel):
143-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
143+
COORDS = {
144+
"coeffs": self.labels,
145+
"obs_ind": np.arange(self.pre_X.shape[0]),
146+
"treated_units": ["unit_0"],
147+
}
144148
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
145149
elif isinstance(self.model, RegressorMixin):
146-
self.model.fit(X=self.pre_X, y=self.pre_y)
150+
# For OLS models, use 1D y data
151+
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
147152
else:
148153
raise ValueError("Model type not recognized")
149154

@@ -155,8 +160,21 @@ def __init__(
155160

156161
# calculate the counterfactual
157162
self.post_pred = self.model.predict(X=self.post_X)
158-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
159-
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
163+
164+
# calculate impact - use appropriate y data format for each model type
165+
if isinstance(self.model, PyMCModel):
166+
# PyMC models work with 2D data
167+
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
168+
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
169+
elif isinstance(self.model, RegressorMixin):
170+
# SKL models work with 1D data
171+
self.pre_impact = self.model.calculate_impact(
172+
self.pre_y.isel(treated_units=0), self.pre_pred
173+
)
174+
self.post_impact = self.model.calculate_impact(
175+
self.post_y.isel(treated_units=0), self.post_pred
176+
)
177+
160178
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
161179
self.post_impact
162180
)
@@ -202,35 +220,53 @@ def _bayesian_plot(
202220
# pre-intervention period
203221
h_line, h_patch = plot_xY(
204222
self.datapre.index,
205-
self.pre_pred["posterior_predictive"].mu,
223+
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
206224
ax=ax[0],
207225
plot_hdi_kwargs={"color": "C0"},
208226
)
209227
handles = [(h_line, h_patch)]
210228
labels = ["Pre-intervention period"]
211229

212-
(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
230+
(h,) = ax[0].plot(
231+
self.datapre.index,
232+
self.pre_y.isel(treated_units=0)
233+
if hasattr(self.pre_y, "isel")
234+
else self.pre_y[:, 0],
235+
"k.",
236+
label="Observations",
237+
)
213238
handles.append(h)
214239
labels.append("Observations")
215240

216241
# post intervention period
217242
h_line, h_patch = plot_xY(
218243
self.datapost.index,
219-
self.post_pred["posterior_predictive"].mu,
244+
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
220245
ax=ax[0],
221246
plot_hdi_kwargs={"color": "C1"},
222247
)
223248
handles.append((h_line, h_patch))
224249
labels.append(counterfactual_label)
225250

226-
ax[0].plot(self.datapost.index, self.post_y, "k.")
251+
ax[0].plot(
252+
self.datapost.index,
253+
self.post_y.isel(treated_units=0)
254+
if hasattr(self.post_y, "isel")
255+
else self.post_y[:, 0],
256+
"k.",
257+
)
227258
# Shaded causal effect
259+
post_pred_mu = (
260+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
261+
.isel(treated_units=0)
262+
.mean("sample")
263+
) # Add .mean("sample") to get 1D array
228264
h = ax[0].fill_between(
229265
self.datapost.index,
230-
y1=az.extract(
231-
self.post_pred, group="posterior_predictive", var_names="mu"
232-
).mean("sample"),
233-
y2=np.squeeze(self.post_y),
266+
y1=post_pred_mu,
267+
y2=self.post_y.isel(treated_units=0)
268+
if hasattr(self.post_y, "isel")
269+
else self.post_y[:, 0],
234270
color="C0",
235271
alpha=0.25,
236272
)
@@ -239,28 +275,28 @@ def _bayesian_plot(
239275

240276
ax[0].set(
241277
title=f"""
242-
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
243-
(std = {round_num(self.score.r2_std, round_to)})
278+
Pre-intervention Bayesian $R^2$: {round_num(self.score["unit_0_r2"], round_to)}
279+
(std = {round_num(self.score["unit_0_r2_std"], round_to)})
244280
"""
245281
)
246282

247283
# MIDDLE PLOT -----------------------------------------------
248284
plot_xY(
249285
self.datapre.index,
250-
self.pre_impact,
286+
self.pre_impact.isel(treated_units=0),
251287
ax=ax[1],
252288
plot_hdi_kwargs={"color": "C0"},
253289
)
254290
plot_xY(
255291
self.datapost.index,
256-
self.post_impact,
292+
self.post_impact.isel(treated_units=0),
257293
ax=ax[1],
258294
plot_hdi_kwargs={"color": "C1"},
259295
)
260296
ax[1].axhline(y=0, c="k")
261297
ax[1].fill_between(
262298
self.datapost.index,
263-
y1=self.post_impact.mean(["chain", "draw"]),
299+
y1=self.post_impact.mean(["chain", "draw"]).isel(treated_units=0),
264300
color="C0",
265301
alpha=0.25,
266302
label="Causal impact",
@@ -271,7 +307,7 @@ def _bayesian_plot(
271307
ax[2].set(title="Cumulative Causal Impact")
272308
plot_xY(
273309
self.datapost.index,
274-
self.post_impact_cumulative,
310+
self.post_impact_cumulative.isel(treated_units=0),
275311
ax=ax[2],
276312
plot_hdi_kwargs={"color": "C1"},
277313
)
@@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
387423
pre_data["prediction"] = (
388424
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
389425
.mean("sample")
426+
.isel(treated_units=0)
390427
.values
391428
)
392429
post_data["prediction"] = (
393430
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
394431
.mean("sample")
432+
.isel(treated_units=0)
395433
.values
396434
)
397-
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
435+
hdi_pre_pred = get_hdi_to_df(
398436
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
399-
).set_index(pre_data.index)
400-
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
437+
)
438+
hdi_post_pred = get_hdi_to_df(
401439
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
440+
)
441+
# Select the single unit from the MultiIndex results
442+
pre_data[[pred_lower_col, pred_upper_col]] = hdi_pre_pred.xs(
443+
"unit_0", level="treated_units"
444+
).set_index(pre_data.index)
445+
post_data[[pred_lower_col, pred_upper_col]] = hdi_post_pred.xs(
446+
"unit_0", level="treated_units"
402447
).set_index(post_data.index)
403448

404-
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
405-
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
406-
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
407-
self.pre_impact, hdi_prob=hdi_prob
449+
pre_data["impact"] = (
450+
self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
451+
)
452+
post_data["impact"] = (
453+
self.post_impact.mean(dim=["chain", "draw"])
454+
.isel(treated_units=0)
455+
.values
456+
)
457+
hdi_pre_impact = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
458+
hdi_post_impact = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
459+
# Select the single unit from the MultiIndex results
460+
pre_data[[impact_lower_col, impact_upper_col]] = hdi_pre_impact.xs(
461+
"unit_0", level="treated_units"
408462
).set_index(pre_data.index)
409-
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
410-
self.post_impact, hdi_prob=hdi_prob
463+
post_data[[impact_lower_col, impact_upper_col]] = hdi_post_impact.xs(
464+
"unit_0", level="treated_units"
411465
).set_index(post_data.index)
412466

413467
self.plot_data = pd.concat([pre_data, post_data])

causalpy/experiments/prepostnegd.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,18 @@ def __init__(
122122
},
123123
)
124124
self.y = xr.DataArray(
125-
self.y[:, 0],
126-
dims=["obs_ind"],
127-
coords={"obs_ind": self.data.index},
125+
self.y,
126+
dims=["obs_ind", "treated_units"],
127+
coords={"obs_ind": self.data.index, "treated_units": ["unit_0"]},
128128
)
129129

130130
# fit the model to the observed (pre-intervention) data
131131
if isinstance(self.model, PyMCModel):
132-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
132+
COORDS = {
133+
"coeffs": self.labels,
134+
"obs_ind": np.arange(self.X.shape[0]),
135+
"treated_units": ["unit_0"],
136+
}
133137
self.model.fit(X=self.X, y=self.y, coords=COORDS)
134138
elif isinstance(self.model, RegressorMixin):
135139
raise NotImplementedError("Not implemented for OLS model")
@@ -239,7 +243,7 @@ def _bayesian_plot(
239243
# plot posterior predictive of untreated
240244
h_line, h_patch = plot_xY(
241245
self.pred_xi,
242-
self.pred_untreated["posterior_predictive"].mu,
246+
self.pred_untreated["posterior_predictive"].mu.isel(treated_units=0),
243247
ax=ax[0],
244248
plot_hdi_kwargs={"color": "C0"},
245249
label="Control group",
@@ -250,7 +254,7 @@ def _bayesian_plot(
250254
# plot posterior predictive of treated
251255
h_line, h_patch = plot_xY(
252256
self.pred_xi,
253-
self.pred_treated["posterior_predictive"].mu,
257+
self.pred_treated["posterior_predictive"].mu.isel(treated_units=0),
254258
ax=ax[0],
255259
plot_hdi_kwargs={"color": "C1"},
256260
label="Treatment group",

causalpy/experiments/regression_discontinuity.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,19 @@ def __init__(
131131
},
132132
)
133133
self.y = xr.DataArray(
134-
self.y[:, 0],
135-
dims=["obs_ind"],
136-
coords={"obs_ind": np.arange(self.y.shape[0])},
134+
self.y,
135+
dims=["obs_ind", "treated_units"],
136+
coords={"obs_ind": np.arange(self.y.shape[0]), "treated_units": ["unit_0"]},
137137
)
138138

139139
# fit model
140140
if isinstance(self.model, PyMCModel):
141141
# fit the model to the observed (pre-intervention) data
142-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
142+
COORDS = {
143+
"coeffs": self.labels,
144+
"obs_ind": np.arange(self.X.shape[0]),
145+
"treated_units": ["unit_0"],
146+
}
143147
self.model.fit(X=self.X, y=self.y, coords=COORDS)
144148
elif isinstance(self.model, RegressorMixin):
145149
self.model.fit(X=self.X, y=self.y)
@@ -248,15 +252,15 @@ def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]
248252
# Plot model fit to data
249253
h_line, h_patch = plot_xY(
250254
self.x_pred[self.running_variable_name],
251-
self.pred["posterior_predictive"].mu,
255+
self.pred["posterior_predictive"].mu.isel(treated_units=0),
252256
ax=ax,
253257
plot_hdi_kwargs={"color": "C1"},
254258
)
255259
handles = [(h_line, h_patch)]
256260
labels = ["Posterior mean"]
257261

258262
# create strings to compose title
259-
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
263+
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
260264
r2 = f"Bayesian $R^2$ on all data = {title_info}"
261265
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
262266
ci = (

0 commit comments

Comments
 (0)