
Commit b36ef94 (1 parent: 435b686)

Updating plotting scripts
6 files changed: +146, -29 lines

prepare_results.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def fix_duplicate_runs(df):
 df_config = prepare_config(config_general)
 group_keys = df_config.columns
 df_test = df_config.join(df, on=group_keys, how="inner")
-df_test.join(df_query_time_all_datalakes, on="jd_method").with_columns(
+df_test = df_test.join(df_query_time_all_datalakes, on="jd_method").with_columns(
     total_runtime=pl.col("time_run") + pl.col("time_query")
 )
 df_test.write_parquet("results/results_general.parquet")
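
Why this fix matters: polars DataFrames are immutable, so the old line built the joined frame with total_runtime and then discarded it, and the parquet written on the next line never contained the column. A minimal sketch of the pitfall (toy values, not the benchmark data):

import polars as pl

df = pl.DataFrame({"time_run": [1.0], "time_query": [2.0]})

# Without assignment: with_columns returns a NEW frame; df is unchanged.
df.with_columns(total_runtime=pl.col("time_run") + pl.col("time_query"))
print(df.columns)  # ['time_run', 'time_query']

# With assignment, the derived column is kept.
df = df.with_columns(total_runtime=pl.col("time_run") + pl.col("time_query"))
print(df.columns)  # ['time_run', 'time_query', 'total_runtime']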

results_pivot.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# %%
import polars as pl

from src.utils.constants import LABEL_MAPPING

# %%
df_aggregation = pl.read_parquet("results/results_aggregation.parquet")
df_general = pl.read_parquet("results/results_general.parquet")
df_retrieval = pl.read_parquet("results/results_retrieval.parquet")
df_master = pl.read_parquet("results/master_list.parquet")

# %%
variables = ["chosen_model", "jd_method", "estimator", "target_dl", "base_table"]

# %%
df_general.pivot(
    on="estimator",
    index="chosen_model",
    values="prediction_metric",
    aggregate_function="median",
)
# %%
for var_1 in variables:
    df_list = []
    for var_2 in variables:
        if var_1 == var_2:
            continue
        _this_df = df_general.pivot(
            on=var_2,
            index=var_1,
            values="prediction_metric",
            aggregate_function="median",
        )
        _index = _this_df.get_column(var_1).replace(LABEL_MAPPING[var_1])
        _this_df.drop_in_place(var_1)
        _this_df = _this_df.rename(lambda c: LABEL_MAPPING[var_2][c])
        _col_order = [var_1] + _this_df.columns
        _this_df = _this_df.with_columns(_index.alias(var_1)).select(_col_order)
        df_list.append(_this_df)

    df_aligned = pl.concat(df_list, how="align")
    df_aligned.write_csv(f"results/results_pivot_{var_1}.csv")

# %%
# %%
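
Two polars features carry this new script: pivot spreads the distinct values of one column into new columns, collapsing collisions with the given aggregate, and pl.concat(..., how="align") joins frames on their shared columns instead of stacking rows, which is what lets the per-variable pivots share one index column. A toy sketch with made-up estimator and model names:

import polars as pl

df = pl.DataFrame({
    "estimator": ["full_join", "full_join", "best_single", "best_single"],
    "chosen_model": ["catboost", "linear", "catboost", "linear"],
    "prediction_metric": [0.61, 0.48, 0.57, 0.44],
})

# One column per estimator, one row per model; duplicate cells are
# reduced with the median, as in the script above.
wide = df.pivot(
    on="estimator",
    index="chosen_model",
    values="prediction_metric",
    aggregate_function="median",
)
print(wide)

# how="align" outer-joins on the common column "k" instead of stacking.
left = pl.DataFrame({"k": ["a", "b"], "x": [1, 2]})
right = pl.DataFrame({"k": ["b", "a"], "y": [3, 4]})
print(pl.concat([left, right], how="align"))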

scripts/plotting/plot_comparison_large.py

Lines changed: 17 additions & 8 deletions
@@ -4,9 +4,7 @@

 # %%
 # %cd ~/bench
-# %load_ext autoreload
-# %autoreload 2
-#%%
+# %%
 import matplotlib.pyplot as plt
 import polars as pl

@@ -16,6 +14,7 @@
 plot_case = "dep"
 savefig = False

+
 # %%
 def read_and_format(file_path):
     return (
@@ -35,14 +34,24 @@ def read_and_format(file_path):
     )


-#%%
+# %%
 _results_general = read_and_format("results/results_general.parquet")
 _results_aggr = read_and_format("results/results_aggregation.parquet")
-
-_results_aggr = _results_aggr.filter(pl.col("estimator") != "nojoin")
 _results_retrieval = read_and_format("results/results_retrieval.parquet")
-# _results_aggr = _results_aggr.filter(pl.col("jd_method") == "exact_matching")
-# _results_general = _results_general.filter(pl.col("jd_method") == "exact_matching")
+
+# _results_aggr = _results_aggr.with_columns(time_run = pl.col("time_run")*10 + pl.col("time_query"))
+# _results_general = _results_general.with_columns(time_run = pl.col("time_run")*10 + pl.col("time_query"))
+# _results_retrieval = _results_retrieval.with_columns(time_run = pl.col("time_run")*10 + pl.col("time_query"))
+
+_results_aggr = _results_aggr.filter(pl.col("estimator") != "nojoin").with_columns(
+    time_run=pl.col("time_run") * 10 + pl.col("time_query")
+)
+_results_general = _results_general.filter(
+    pl.col("estimator") != "nojoin"
+).with_columns(time_run=pl.col("time_run") * 10 + pl.col("time_query"))
+_results_retrieval = _results_retrieval.filter(
+    pl.col("estimator") != "nojoin"
+).with_columns(time_run=pl.col("time_run") * 10 + pl.col("time_query"))


 # %%
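
The pattern now applied to all three frames drops the no-join baseline and folds the query time into the runtime as time_run * 10 + time_query. The factor 10 is copied from the diff as-is; a plausible but unconfirmed reading is that one retrieval query is amortized over ten runs. A self-contained sketch with toy numbers:

import polars as pl

results = pl.DataFrame({
    "estimator": ["best_single", "nojoin"],  # toy rows
    "time_run": [12.0, 3.0],
    "time_query": [5.0, 5.0],
})

adjusted = results.filter(pl.col("estimator") != "nojoin").with_columns(
    # Ten runs' worth of training time plus a single query cost.
    time_run=pl.col("time_run") * 10 + pl.col("time_query")
)
print(adjusted)  # time_run -> 12.0 * 10 + 5.0 = 125.0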

scripts/plotting/plot_pareto_topk.py

Lines changed: 10 additions & 6 deletions
@@ -1,3 +1,6 @@
+"""This script is used to prepare the Pareto plots that compare the performance
+by value of top-k, considering both the run time and the peak RAM.
+"""
 # %%
 # %cd ~/bench
 # %%
@@ -7,6 +10,9 @@
 from src.utils import constants
 from src.utils.plotting import pareto_frontier_plot

+plt.style.use("seaborn-v0_8-talk")
+plt.rc("font", family="sans-serif")
+
 # %%
 df = pl.read_csv("results/results_topk.csv")
 df = df.group_by(constants.GROUPING_KEYS + ["top_k"]).agg(
@@ -47,7 +53,7 @@ def prepare_sem_df(df, variable):
 xerr = df_sem["sem_time_run"].to_numpy()
 yerr = df_sem["sem_pred"].to_numpy()

-fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(6, 4), layout="constrained")
+fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(5, 3.5), layout="constrained")
 (h, l), _ = pareto_frontier_plot(
     df_pareto.to_pandas(),
     x_var="time_run",
@@ -78,14 +84,12 @@ def prepare_sem_df(df, variable):
     ecolor=c,
 )

-ax.legend(h, l, title="Value of k", loc="upper right", bbox_to_anchor=(1.30, 1))
-
+ax.legend(h, l, title="Value of k", loc="upper right", bbox_to_anchor=(1.35, 1.05), frameon=False)
 _x, _y = df_pareto.filter(top_k=30).select("time_run", "prediction_metric")

 x_text = _x.item()
 y_text = _y.item()

-# Annotate the point (36, 0.52)
 ax.annotate(
     "k used in experiments",  # Annotation text
     xy=(x_text, y_text),  # Point to annotate
@@ -104,7 +108,7 @@ def prepare_sem_df(df, variable):
 xerr = df_sem["sem_peak_fit"].to_numpy()
 yerr = df_sem["sem_pred"].to_numpy()

-fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(6, 4), layout="constrained")
+fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(5, 3.5), layout="constrained")
 (h, l), _ = pareto_frontier_plot(
     df_pareto.to_pandas(),
     x_var="peak_fit",
@@ -135,7 +139,7 @@ def prepare_sem_df(df, variable):
     ecolor=c,
 )

-ax.legend(h, l, title="Value of k", loc="upper right", bbox_to_anchor=(1.30, 1))
+ax.legend(h, l, title="Value of k", loc="upper right", bbox_to_anchor=(1.35, 1.05), frameon=False)
 fig.savefig("images/pareto_topk_ram.png", bbox_inches="tight")
 fig.savefig("images/pareto_topk_ram.pdf", bbox_inches="tight")
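
On the legend tweak: bbox_to_anchor is interpreted in axes coordinates, so (1.35, 1.05) places the legend just outside the top-right corner of the axes, and frameon=False drops its border; bbox_inches="tight" at save time then grows the canvas so the legend is not clipped. A standalone sketch with placeholder curves (not the Pareto data):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(5, 3.5), layout="constrained")
for k in (5, 10, 30):
    ax.plot([0, 1, 2], [0.0, k / 30, k / 20], label=str(k))  # placeholder curves

h, l = ax.get_legend_handles_labels()
# x > 1 in axes coordinates lies to the right of the plotting area.
ax.legend(h, l, title="Value of k", loc="upper right",
          bbox_to_anchor=(1.35, 1.05), frameon=False)
fig.savefig("legend_demo.png", bbox_inches="tight")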

stats/compile_stats.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# %%
import polars as pl
import pandas as pd

# %%
df_starmie = (
    pl.read_csv("stats_retrieval_starmie.csv")
    .drop("time_save")
    .with_columns(index_name=pl.lit("starmie"))
)
df_others = pl.read_csv("stats_retrieval_others.csv").drop("time_save", "n_candidates")
# %%
val_starmie = (
    df_starmie.group_by("data_lake_version", "base_table", "index_name")
    .agg(pl.mean("time_create", "time_load", "time_query"))
    .with_columns(total_query=pl.col("time_load") + pl.col("time_query"))["total_query"]
    .mean()
) / 6

# %% Include only the data lakes that starmie works on
_d = (
    df_others.filter(~pl.col("data_lake_version").is_in(["wordnet_vldb_50", "open_data_us"])).with_columns(
        total_query=pl.when(pl.col("index_name") == "exact_matching")
        .then(pl.sum_horizontal("time_create", "time_load", "time_query"))
        .otherwise(pl.sum_horizontal("time_load", "time_query"))
    )
    .group_by("index_name")
    .agg(pl.mean("total_query"))
    .with_columns(total_query=pl.col("total_query") / 6)
)

r_dict = dict(_d.rows())
r_dict["starmie"] = val_starmie

pl.from_dict({"jd_method": r_dict.keys(), "time_query": r_dict.values()}).write_csv(
    "avg_query_time_for_pareto_plot_retrieval.csv"
)
# %% Now all data lakes and no starmie
_d = (
    df_others.with_columns(
        total_query=pl.when(pl.col("index_name") == "exact_matching")
        .then(pl.sum_horizontal("time_create", "time_load", "time_query"))
        .otherwise(pl.sum_horizontal("time_load", "time_query"))
    )
    .group_by("index_name")
    .agg(pl.mean("total_query"))
    .with_columns(total_query=pl.col("total_query") / 6)
)

r_dict = dict(_d.rows())

pl.from_dict({"jd_method": r_dict.keys(), "time_query": r_dict.values()}).write_csv(
    "avg_query_time_for_pareto_plot_all_datalakes.csv"
)

# %%
df_others_max_ram = df_others.filter(~pl.col("data_lake_version").is_in(["wordnet_vldb_50", "open_data_us"])).with_columns(
    max_ram=pl.max_horizontal("peak_create", "peak_query")
)
# %%
import seaborn as sns

sns.displot(data=df_others_max_ram.to_pandas(), x="max_ram", col="index_name", binwidth=200)
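
The per-method query cost here is a conditional sum: for exact_matching the index-creation time is charged to every query, while the other methods pay only load plus query time (the branch suggests exact_matching has no reusable persisted index, though the commit does not say so). A toy sketch of the when/then/otherwise with sum_horizontal pattern, using a made-up second method:

import polars as pl

stats = pl.DataFrame({
    "index_name": ["exact_matching", "minhash"],  # "minhash" is made up
    "time_create": [30.0, 300.0],
    "time_load": [1.0, 2.0],
    "time_query": [4.0, 6.0],
})

totals = stats.with_columns(
    # exact_matching: create + load + query; everything else: load + query.
    total_query=pl.when(pl.col("index_name") == "exact_matching")
    .then(pl.sum_horizontal("time_create", "time_load", "time_query"))
    .otherwise(pl.sum_horizontal("time_load", "time_query"))
)
print(totals)  # total_query: 35.0 and 8.0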

summary_results.py

Lines changed: 10 additions & 14 deletions
@@ -1,3 +1,7 @@
+"""This script is used to build the ablation table to compare the performance of each
+configuration against the reference.
+"""
+
 # %%
 # %cd ~/bench
 # %%
@@ -64,23 +68,15 @@
 df_reference.write_csv("results/results_reference.csv")

 # %%
-query_times_retrieval = pl.read_csv(
-    "stats/avg_query_time_for_pareto_plot_retrieval.csv"
-)
-query_times_all_datalakes = pl.read_csv(
-    "stats/avg_query_time_for_pareto_plot_all_datalakes.csv"
+df_retrieval = df_retrieval.filter(pl.col("estimator") != "nojoin").with_columns(
+    time_run=pl.col("time_run") * 10 + pl.col("time_query")
 )
-
-# %%
-df_retrieval = df_retrieval.join(query_times_retrieval, on="jd_method").with_columns(
-    time_run=pl.col("time_run")*10 + pl.col("time_query")
+df_general = df_general.filter(pl.col("estimator") != "nojoin").with_columns(
+    time_run=pl.col("time_run") * 10 + pl.col("time_query")
 )
-df_general = df_general.join(query_times_retrieval, on="jd_method").with_columns(
-    time_run=pl.col("time_run")*10 + pl.col("time_query")
+df_aggregation = df_aggregation.filter(pl.col("estimator") != "nojoin").with_columns(
+    time_run=pl.col("time_run") * 10 + pl.col("time_query")
 )
-df_aggregation = df_aggregation.join(
-    query_times_retrieval, on="jd_method"
-).with_columns(time_run=pl.col("time_run")*10 + pl.col("time_query"))


 # %%