@@ -120,9 +120,9 @@ def __init__(
120
120
},
121
121
)
122
122
self .pre_y = xr .DataArray (
123
- self .pre_y [:, 0 ],
124
- dims = ["obs_ind" ],
125
- coords = {"obs_ind" : self .datapre .index },
123
+ self .pre_y , # Keep 2D shape
124
+ dims = ["obs_ind" , "treated_units" ],
125
+ coords = {"obs_ind" : self .datapre .index , "treated_units" : [ "unit_0" ] },
126
126
)
127
127
self .post_X = xr .DataArray (
128
128
self .post_X ,
@@ -133,17 +133,22 @@ def __init__(
133
133
},
134
134
)
135
135
self .post_y = xr .DataArray (
136
- self .post_y [:, 0 ],
137
- dims = ["obs_ind" ],
138
- coords = {"obs_ind" : self .datapost .index },
136
+ self .post_y , # Keep 2D shape
137
+ dims = ["obs_ind" , "treated_units" ],
138
+ coords = {"obs_ind" : self .datapost .index , "treated_units" : [ "unit_0" ] },
139
139
)
140
140
141
141
# fit the model to the observed (pre-intervention) data
142
142
if isinstance (self .model , PyMCModel ):
143
- COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .pre_X .shape [0 ])}
143
+ COORDS = {
144
+ "coeffs" : self .labels ,
145
+ "obs_ind" : np .arange (self .pre_X .shape [0 ]),
146
+ "treated_units" : ["unit_0" ],
147
+ }
144
148
self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
145
149
elif isinstance (self .model , RegressorMixin ):
146
- self .model .fit (X = self .pre_X , y = self .pre_y )
150
+ # For OLS models, use 1D y data
151
+ self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
147
152
else :
148
153
raise ValueError ("Model type not recognized" )
149
154
@@ -155,8 +160,21 @@ def __init__(
155
160
156
161
# calculate the counterfactual
157
162
self .post_pred = self .model .predict (X = self .post_X )
158
- self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
159
- self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
163
+
164
+ # calculate impact - use appropriate y data format for each model type
165
+ if isinstance (self .model , PyMCModel ):
166
+ # PyMC models work with 2D data
167
+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
168
+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
169
+ elif isinstance (self .model , RegressorMixin ):
170
+ # SKL models work with 1D data
171
+ self .pre_impact = self .model .calculate_impact (
172
+ self .pre_y .isel (treated_units = 0 ), self .pre_pred
173
+ )
174
+ self .post_impact = self .model .calculate_impact (
175
+ self .post_y .isel (treated_units = 0 ), self .post_pred
176
+ )
177
+
160
178
self .post_impact_cumulative = self .model .calculate_cumulative_impact (
161
179
self .post_impact
162
180
)
@@ -202,35 +220,53 @@ def _bayesian_plot(
202
220
# pre-intervention period
203
221
h_line , h_patch = plot_xY (
204
222
self .datapre .index ,
205
- self .pre_pred ["posterior_predictive" ].mu ,
223
+ self .pre_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
206
224
ax = ax [0 ],
207
225
plot_hdi_kwargs = {"color" : "C0" },
208
226
)
209
227
handles = [(h_line , h_patch )]
210
228
labels = ["Pre-intervention period" ]
211
229
212
- (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
230
+ (h ,) = ax [0 ].plot (
231
+ self .datapre .index ,
232
+ self .pre_y .isel (treated_units = 0 )
233
+ if hasattr (self .pre_y , "isel" )
234
+ else self .pre_y [:, 0 ],
235
+ "k." ,
236
+ label = "Observations" ,
237
+ )
213
238
handles .append (h )
214
239
labels .append ("Observations" )
215
240
216
241
# post intervention period
217
242
h_line , h_patch = plot_xY (
218
243
self .datapost .index ,
219
- self .post_pred ["posterior_predictive" ].mu ,
244
+ self .post_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
220
245
ax = ax [0 ],
221
246
plot_hdi_kwargs = {"color" : "C1" },
222
247
)
223
248
handles .append ((h_line , h_patch ))
224
249
labels .append (counterfactual_label )
225
250
226
- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
251
+ ax [0 ].plot (
252
+ self .datapost .index ,
253
+ self .post_y .isel (treated_units = 0 )
254
+ if hasattr (self .post_y , "isel" )
255
+ else self .post_y [:, 0 ],
256
+ "k." ,
257
+ )
227
258
# Shaded causal effect
259
+ post_pred_mu = (
260
+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
261
+ .isel (treated_units = 0 )
262
+ .mean ("sample" )
263
+ ) # Add .mean("sample") to get 1D array
228
264
h = ax [0 ].fill_between (
229
265
self .datapost .index ,
230
- y1 = az . extract (
231
- self .post_pred , group = "posterior_predictive" , var_names = "mu"
232
- ). mean ( "sample" ),
233
- y2 = np . squeeze ( self .post_y ) ,
266
+ y1 = post_pred_mu ,
267
+ y2 = self .post_y . isel ( treated_units = 0 )
268
+ if hasattr ( self . post_y , "isel" )
269
+ else self .post_y [:, 0 ] ,
234
270
color = "C0" ,
235
271
alpha = 0.25 ,
236
272
)
@@ -239,28 +275,28 @@ def _bayesian_plot(
239
275
240
276
ax [0 ].set (
241
277
title = f"""
242
- Pre-intervention Bayesian $R^2$: { round_num (self .score . r2 , round_to )}
243
- (std = { round_num (self .score . r2_std , round_to )} )
278
+ Pre-intervention Bayesian $R^2$: { round_num (self .score [ "unit_0_r2" ] , round_to )}
279
+ (std = { round_num (self .score [ "unit_0_r2_std" ] , round_to )} )
244
280
"""
245
281
)
246
282
247
283
# MIDDLE PLOT -----------------------------------------------
248
284
plot_xY (
249
285
self .datapre .index ,
250
- self .pre_impact ,
286
+ self .pre_impact . isel ( treated_units = 0 ) ,
251
287
ax = ax [1 ],
252
288
plot_hdi_kwargs = {"color" : "C0" },
253
289
)
254
290
plot_xY (
255
291
self .datapost .index ,
256
- self .post_impact ,
292
+ self .post_impact . isel ( treated_units = 0 ) ,
257
293
ax = ax [1 ],
258
294
plot_hdi_kwargs = {"color" : "C1" },
259
295
)
260
296
ax [1 ].axhline (y = 0 , c = "k" )
261
297
ax [1 ].fill_between (
262
298
self .datapost .index ,
263
- y1 = self .post_impact .mean (["chain" , "draw" ]),
299
+ y1 = self .post_impact .mean (["chain" , "draw" ]). isel ( treated_units = 0 ) ,
264
300
color = "C0" ,
265
301
alpha = 0.25 ,
266
302
label = "Causal impact" ,
@@ -271,7 +307,7 @@ def _bayesian_plot(
271
307
ax [2 ].set (title = "Cumulative Causal Impact" )
272
308
plot_xY (
273
309
self .datapost .index ,
274
- self .post_impact_cumulative ,
310
+ self .post_impact_cumulative . isel ( treated_units = 0 ) ,
275
311
ax = ax [2 ],
276
312
plot_hdi_kwargs = {"color" : "C1" },
277
313
)
@@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
387
423
pre_data ["prediction" ] = (
388
424
az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
389
425
.mean ("sample" )
426
+ .isel (treated_units = 0 )
390
427
.values
391
428
)
392
429
post_data ["prediction" ] = (
393
430
az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
394
431
.mean ("sample" )
432
+ .isel (treated_units = 0 )
395
433
.values
396
434
)
397
- pre_data [[ pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
435
+ hdi_pre_pred = get_hdi_to_df (
398
436
self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
399
- ). set_index ( pre_data . index )
400
- post_data [[ pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
437
+ )
438
+ hdi_post_pred = get_hdi_to_df (
401
439
self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
440
+ )
441
+ # Select the single unit from the MultiIndex results
442
+ pre_data [[pred_lower_col , pred_upper_col ]] = hdi_pre_pred .xs (
443
+ "unit_0" , level = "treated_units"
444
+ ).set_index (pre_data .index )
445
+ post_data [[pred_lower_col , pred_upper_col ]] = hdi_post_pred .xs (
446
+ "unit_0" , level = "treated_units"
402
447
).set_index (post_data .index )
403
448
404
- pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
405
- post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
406
- pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
407
- self .pre_impact , hdi_prob = hdi_prob
449
+ pre_data ["impact" ] = (
450
+ self .pre_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
451
+ )
452
+ post_data ["impact" ] = (
453
+ self .post_impact .mean (dim = ["chain" , "draw" ])
454
+ .isel (treated_units = 0 )
455
+ .values
456
+ )
457
+ hdi_pre_impact = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
458
+ hdi_post_impact = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
459
+ # Select the single unit from the MultiIndex results
460
+ pre_data [[impact_lower_col , impact_upper_col ]] = hdi_pre_impact .xs (
461
+ "unit_0" , level = "treated_units"
408
462
).set_index (pre_data .index )
409
- post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
410
- self . post_impact , hdi_prob = hdi_prob
463
+ post_data [[impact_lower_col , impact_upper_col ]] = hdi_post_impact . xs (
464
+ "unit_0" , level = "treated_units"
411
465
).set_index (post_data .index )
412
466
413
467
self .plot_data = pd .concat ([pre_data , post_data ])
0 commit comments