Skip to content

Commit 294c2ef

Browse files
committed
partially support plotnine v0.12.1
1 parent aaa5160 commit 294c2ef

File tree

1 file changed

+58
-46
lines changed

1 file changed

+58
-46
lines changed

patchworklib/patchworklib.py

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _reset_ggplot_legend(bricks):
179179
tmp_artist.remove()
180180
else:
181181
bricks._case.artists.remove(bricks._ggplot_legend)
182-
182+
183183
anchored_box = AnchoredOffsetbox(
184184
loc=bricks._ggplot_legend_loc,
185185
child=bricks._ggplot_legend_box,
@@ -189,9 +189,10 @@ def _reset_ggplot_legend(bricks):
189189
bbox_transform = bricks._case.transAxes,
190190
borderpad=0.)
191191
anchored_box.set_zorder(90.1)
192+
anchored_box.set_in_layout(True)
192193
try:
193194
bricks._case.add_artist(anchored_box)
194-
except:
195+
except Exception as e:
195196
pass
196197
bricks._ggplot_legend = anchored_box
197198
bricks.case
@@ -239,7 +240,7 @@ def draw_labels(bricks, gori, gcp, figsize):
239240
pad_x = 4
240241
else:
241242
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
242-
pad_x = 4
243+
pad_x = 14 + (get_property('axis_text_x', 'size') - 11) * 0.5 + (get_property('axis_title_x', 'size') - 11) * 0.5
243244
else:
244245
pad_x = margin.get_as('t', 'pt')
245246

@@ -249,7 +250,7 @@ def draw_labels(bricks, gori, gcp, figsize):
249250
pad_y = 4
250251
else:
251252
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
252-
pad_y = 8
253+
pad_y = 12 + (get_property('axis_text_y', 'size') - 11) * 0.5 + (get_property('axis_title_y', 'size') - 11) * 0.5
253254
else:
254255
pad_y = margin.get_as('r', 'pt')
255256

@@ -276,34 +277,27 @@ def draw_labels(bricks, gori, gcp, figsize):
276277

277278
else:
278279
xlabel = bricks.set_xlabel(labels.x, labelpad=pad_x, va="top")
279-
ylabel = bricks.set_ylabel(labels.y, labelpad=pad_y)
280-
280+
ylabel = bricks.set_ylabel(labels.y, labelpad=pad_y)
281+
281282
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
282283
gori.theme._targets['axis_title_x'] = xlabel
283284
gori.theme._targets['axis_title_y'] = ylabel
284285
if 'axis_title_x' in gori.theme.themeables:
285286
gori.theme.themeables['axis_title_x'].apply_figure(gori.figure, gori.theme._targets)
286287
for ax in gori.axs:
287-
gori.theme.themeables['axis_title_x'].apply_ax(ax)
288+
gori.theme.themeables['axis_title_x'].apply_ax(ax)
288289

289290
if 'axis_title_y' in gori.theme.themeables:
290291
gori.theme.themeables['axis_title_y'].apply_figure(gori.figure, gori.theme._targets)
291292
for ax in gori.axs:
292-
gori.theme.themeables['axis_title_y'].apply_ax(ax)
293-
294-
if bricks._type == "Bricks":
295-
xlabel = bricks.case.set_xlabel(labels.x, labelpad=pad_x, va=va)
296-
x,y = xlabel.get_position()
297-
xlabel.set_position([(px1+px2) / 2, y])
298-
299-
ylabel = bricks.case.set_ylabel(labels.y, labelpad=pad_y)
300-
x,y = ylabel.get_position()
301-
ylabel.set_position([x, (py1+py2) / 2])
302-
303-
else:
304-
bricks.set_xlabel(labels.x, labelpad=pad_x, va=va)
305-
bricks.set_ylabel(labels.y, labelpad=pad_y)
306-
293+
gori.theme.themeables['axis_title_y'].apply_ax(ax)
294+
295+
for key in gori.theme.themeables:
296+
if "legend" in key:
297+
gori.theme.themeables[key].apply_figure(gori.figure, gori.theme._targets)
298+
for ax in gori.axs:
299+
gori.theme.themeables[key].apply_ax(ax)
300+
307301
else:
308302
gori.figure._themeable['axis_title_x'] = xlabel
309303
gori.figure._themeable['axis_title_y'] = ylabel
@@ -316,11 +310,12 @@ def draw_labels(bricks, gori, gcp, figsize):
316310
gori.theme.themeables['axis_title_y'].apply_figure(gori.figure)
317311
for ax in gori.axs:
318312
gori.theme.themeables['axis_title_y'].apply(ax)
313+
314+
return labels.x, labels.y
319315

320316
def draw_legend(bricks, gori, gcp, figsize):
321317
get_property = gcp.theme.themeables.property
322318
legend_box = gcp.guides.build(gcp)
323-
324319
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
325320
wratio = 1
326321
hratio = 1
@@ -358,7 +353,7 @@ def draw_legend(bricks, gori, gcp, figsize):
358353
else:
359354
loc = 1
360355
x, y = 0, 0
361-
356+
362357
if legend_box is None:
363358
pass
364359
else:
@@ -372,6 +367,7 @@ def draw_legend(bricks, gori, gcp, figsize):
372367
borderpad=0.)
373368

374369
anchored_box.set_zorder(90.1)
370+
anchored_box.set_in_layout(True)
375371
bricks.case.add_artist(anchored_box)
376372
bricks._ggplot_legend = anchored_box
377373
bricks._ggplot_legend_box = legend_box
@@ -475,11 +471,12 @@ def draw_title(bricks, gori, gcp, figsize):
475471
figure_subplot_wspace_ori = matplotlib.rcParams["figure.subplot.wspace"]
476472
figure_subplot_hspace_ori = matplotlib.rcParams["figure.subplot.hspace"]
477473
figsize_ori = gcp.theme.themeables['figure_size'].properties["value"]
474+
if figsize is None:
475+
figsize = gcp.theme.themeables['figure_size'].properties["value"]
478476
matplotlib.rcParams["figure.subplot.wspace"] = figure_subplot_wspace_ori / figsize[0]
479477
matplotlib.rcParams["figure.subplot.hspace"] = figure_subplot_hspace_ori / figsize[1]
480478
fig, gcp = gcp.draw(return_ggplot=True)
481-
if figsize is None:
482-
figsize = gcp.theme.themeables['figure_size'].properties["value"]
479+
483480
else:
484481
fig, gcp = gcp.draw(return_ggplot=True)
485482
_themeable = fig._themeable
@@ -562,7 +559,6 @@ def draw_title(bricks, gori, gcp, figsize):
562559
new = themeable.from_class_name
563560
ggplot.theme.themeables["figure_size"] = new("figure_size",(1,1))
564561
ggplot.theme.apply()
565-
#ggplot.figure.set_layout_engine(PlotnineLayoutEngine(ggplot))
566562

567563
elif StrictVersion(plotnine_version) >= StrictVersion("0.9"):
568564
ggplot._resize_panels()
@@ -590,7 +586,7 @@ def draw_title(bricks, gori, gcp, figsize):
590586
ax.change_aspectratio((figsize[0], figsize[1]))
591587

592588
if StrictVersion(plotnine_version) >= StrictVersion("0.9"):
593-
draw_labels(ax, ggplot, gcp, figsize)
589+
xl, yl = draw_labels(ax, ggplot, gcp, figsize)
594590
draw_legend(ax, ggplot, gcp, figsize)
595591
draw_title(ax, ggplot, gcp, figsize)
596592

@@ -606,12 +602,12 @@ def draw_title(bricks, gori, gcp, figsize):
606602
del gcp
607603
for key in tmp_axes_keys:
608604
axtmp = _axes_dict[key]
609-
axtmp.set_position(position_dict[key])
605+
axtmp.set_position(position_dict[key])
610606

611607
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
612-
matplotlib.rcParams["figure.subplot.wspace"] = figure_subplot_wspace_ori
613-
matplotlib.rcParams["figure.subplot.hspace"] = figure_subplot_hspace_ori
614-
return ax
608+
ax.set_xlabel(xl)
609+
ax.set_ylabel(yl)
610+
return_obj = ax
615611

616612
else:
617613
width, height = figsize
@@ -628,10 +624,11 @@ def draw_title(bricks, gori, gcp, figsize):
628624
bricks = expand(bricks, width, height)
629625

630626
if StrictVersion(plotnine_version) >= StrictVersion("0.9"):
631-
draw_labels(bricks, ggplot, gcp, figsize)
627+
xl, yl = draw_labels(bricks, ggplot, gcp, figsize)
632628
draw_legend(bricks, ggplot, gcp, figsize)
633629
draw_title(bricks, ggplot, gcp, figsize)
634-
630+
pass
631+
635632
elif StrictVersion("0.8") <= StrictVersion(plotnine_version) < StrictVersion("0.9"):
636633
draw_labels(bricks, ggplot, gcp, figsize)
637634
draw_legend(bricks, ggplot, gcp, figsize)
@@ -644,16 +641,22 @@ def draw_title(bricks, gori, gcp, figsize):
644641
del gcp
645642
for key in tmp_axes_keys:
646643
ax = _axes_dict[key]
647-
ax.set_position(position_dict[key])
648-
644+
ax.set_position(position_dict[key])
645+
649646
x0, x1, y0, y1 = bricks.get_outer_corner()
650647
bricks._originalsize = (abs(x1-x0), abs(y0-y1))
651648
bricks.set_originalpositions()
652-
653649
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
654-
matplotlib.rcParams["figure.subplot.wspace"] = figure_subplot_wspace_ori
655-
matplotlib.rcParams["figure.subplot.hspace"] = figure_subplot_hspace_ori
656-
return bricks
650+
bricks.case.set_xlabel(xl)
651+
bricks.case.set_ylabel(yl)
652+
return_obj = bricks
653+
654+
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
655+
matplotlib.rcParams["figure.subplot.wspace"] = figure_subplot_wspace_ori
656+
matplotlib.rcParams["figure.subplot.hspace"] = figure_subplot_hspace_ori
657+
return_obj.savefig(_ggplot=True)
658+
659+
return return_obj
657660

658661
def overwrite_axisgrid():
659662
"""
@@ -2366,7 +2369,7 @@ def get_middle_corner(self, labels=None):
23662369
for key in self.bricks_dict:
23672370
if key in labels:
23682371
ax = self.bricks_dict[key]
2369-
px0, px1, py0, py1 = ax.get_middle_corner()
2372+
px0, px1, py0, py1 = ax.get_middle_corner()
23702373
x0_list.append(px0)
23712374
x1_list.append(px1)
23722375
y0_list.append(py0)
@@ -2420,7 +2423,7 @@ def get_outer_corner(self):
24202423

24212424
return min(x0_list), max(x1_list), min(y0_list), max(y1_list)
24222425

2423-
def savefig(self, fname=None, transparent=None, quick=True, **kwargs):
2426+
def savefig(self, fname=None, transparent=None, quick=True, _ggplot=False, **kwargs):
24242427
"""
24252428
24262429
Save figure.
@@ -2443,6 +2446,7 @@ def savefig(self, fname=None, transparent=None, quick=True, **kwargs):
24432446
_axes_dict[":".join(case_label.split(":")[1:])].case
24442447

24452448
global param
2449+
global _basefigure
24462450
global _removed_axes
24472451
if quick == False:
24482452
self.case
@@ -2483,7 +2487,11 @@ def savefig(self, fname=None, transparent=None, quick=True, **kwargs):
24832487
kwargs.setdefault('bbox_inches', 'tight')
24842488
kwargs.setdefault('dpi', param['dpi'])
24852489
fig.savefig(fname, transparent=transparent, **kwargs)
2486-
2490+
else:
2491+
if _ggplot == True:
2492+
bytefig = io.BytesIO()
2493+
_basefigure.savefig(bytefig, format="pdf")
2494+
24872495
return fig
24882496

24892497
def __or__(self, other):
@@ -2859,7 +2867,7 @@ def get_outer_corner(self, labes=None):
28592867
self._outer_flag = True
28602868
return self._outer_corner
28612869

2862-
def savefig(self, fname=None, transparent=None, quick=True, **kwargs):
2870+
def savefig(self, fname=None, transparent=None, quick=True, _ggplot=False, **kwargs):
28632871
"""
28642872
28652873
Save figure.
@@ -2906,11 +2914,15 @@ def savefig(self, fname=None, transparent=None, quick=True, **kwargs):
29062914
else:
29072915
ax.remove()
29082916
_removed_axes[ax.get_label()] = ax
2909-
2910-
if fname is not None:
2917+
if fname is not None:
29112918
kwargs.setdefault('bbox_inches', 'tight')
29122919
kwargs.setdefault('dpi', param['dpi'])
29132920
fig.savefig(fname, transparent=transparent, **kwargs)
2921+
else:
2922+
if _ggplot == True:
2923+
bytefig = io.BytesIO()
2924+
_basefigure.savefig(bytefig, format="pdf")
2925+
29142926
return fig
29152927

29162928
def change_plotsize(self, new_size):

0 commit comments

Comments
 (0)