Skip to content

Commit 0f11457

Browse files
committed
2 parents dbef5b7 + bb93144 commit 0f11457

File tree

47 files changed

+46294
-40
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+46294
-40
lines changed

plotting_scripts/make_table_splitting_publication.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import seaborn as sns
44
import matplotlib.pyplot as plt
5-
import matplotlib.lines as mlines # Import mlines
5+
import matplotlib.lines as mlines # Import mlines
66
import matplotlib.patches as mpatches
77

88
plt.rcParams["text.usetex"] = True
@@ -114,7 +114,7 @@
114114
# Add a vertical dashed line between rna_site and rna_site_redundant
115115
ax = g.ax
116116
line_pos = list(task_names.keys()).index("rna_site") + 0.5
117-
ax.axvline(x=line_pos, linestyle="--", color="dimgray", linewidth=2)
117+
ax.axvline(x=line_pos, ymax=0.75, linestyle="--", color="dimgray", linewidth=2)
118118

119119
# Create handles and labels manually
120120
handles = []
@@ -132,7 +132,7 @@
132132
labels.append(distance)
133133
plt.legend(handles, labels, loc="upper center", ncol=3, title=r"Splitting strategy:", handletextpad=-0.3)
134134

135-
plt.savefig("plotting_scripts/splitting_publication.pdf", format="pdf")
136135
plt.subplots_adjust(bottom=0.15) # Adjust the values as needed
136+
plt.savefig("plotting_scripts/splitting_publication.pdf", format="pdf")
137137
plt.show()
138138
plt.clf()
Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
import json
22
import pandas as pd
3-
import matplotlib.pyplot as plt
43
import seaborn as sns
54

5+
import matplotlib.pyplot as plt
6+
import matplotlib.lines as mlines # Import mlines
7+
import matplotlib.patches as mpatches
8+
9+
plt.rcParams["text.usetex"] = True
10+
plt.rc("font", size=16) # default font size
11+
plt.rc("ytick", labelsize=13) # fontsize of the tick labels
12+
plt.rc("xtick", labelsize=13) # fontsize of the tick labels
13+
plt.rc("grid", color="grey", alpha=0.2)
14+
615
# Config
716
thresholds = [0.4, 0.6, 0.8, 1]
817
SEEDS = [0, 1, 2]
@@ -22,67 +31,97 @@
2231
rows.append(
2332
{
2433
"score": score,
25-
"metric": METRIC,
2634
"threshold": f"{threshold:.1f}",
2735
"seed": seed,
2836
"distance": distance,
2937
}
3038
)
3139

3240
# Add random performance over three seeds
33-
for seed in SEEDS:
34-
path = f"results/thresholds_rna_site_redundant_rand_0.4_{seed}.json"
35-
with open(path) as result:
36-
result = json.load(result)
37-
score = result[METRIC]
38-
rows.append(
39-
{
40-
"score": score,
41-
"metric": METRIC,
42-
"threshold": "random",
43-
"seed": seed, # simulate different seeds
44-
"distance": "rand", # used for filtering
45-
}
46-
)
41+
for threshold in thresholds:
42+
for seed in SEEDS:
43+
path = f"results/thresholds_rna_site_redundant_rand_0.4_{seed}.json"
44+
with open(path) as result:
45+
result = json.load(result)
46+
score = result[METRIC]
47+
rows.append(
48+
{
49+
"score": score,
50+
"threshold": f"{threshold:.1f}",
51+
"seed": seed, # simulate different seeds
52+
"distance": "rand", # used for filtering
53+
}
54+
)
4755

4856
# Build DataFrame
4957
df = pd.DataFrame(rows)
5058
df.to_csv("thresholds_with_random.csv")
5159

52-
# Ensure x-axis order: random first, then thresholds
53-
task_order = ["random"] + [f"{t:.1f}" for t in thresholds]
54-
5560
# Create a column to use for legend hue, exclude 'rand'
56-
df["distance"] = df["distance"].apply(lambda d: d if d in ["struc", "seq"] else "random")
61+
# df["distance"] = df["distance"].apply(lambda d: d if d in ["struc", "seq"] else "random")
5762

5863
# Custom color palette (you can tweak these)
5964
palette = {
6065
"struc": sns.color_palette()[0],
6166
"seq": sns.color_palette()[1],
6267
"random": sns.color_palette()[2],
6368
}
69+
70+
dist_names = {
71+
"struc": r"Structure",
72+
"seq": r"Sequence",
73+
"rand": r"Random"
74+
}
75+
df["distance"] = df["distance"].replace(dist_names)
76+
77+
palette_dict = sns.color_palette("muted")
78+
palette_dict = {
79+
r"Structure": palette_dict[0],
80+
r"Sequence": palette_dict[9],
81+
r"Random": palette_dict[3],
82+
}
6483
print(df)
6584

66-
# Plot
67-
g = sns.catplot(
85+
# df_no_rand = df[df["distance"] != r"Random"]
86+
ax = sns.lineplot(
6887
data=df,
69-
x="threshold",
70-
y="score",
71-
hue="distance",
72-
kind="bar",
73-
height=4,
74-
aspect=1,
75-
order=task_order,
76-
palette=palette,
88+
x="threshold", # X-axis is the threshold
89+
y="score", # Y-axis is the score
90+
hue="distance", # Differentiate lines by distance
91+
errorbar='sd', # Use standard deviation for the error band
7792
legend=False,
93+
palette=palette_dict,
94+
linewidth=2,
7895
)
7996

80-
# Format axes
81-
g.set_axis_labels("Threshold", "Test Score")
82-
g.set(ylim=(0.5, 0.85))
83-
g.despine(left=True)
97+
# Set axis labels
98+
plt.xlabel("Threshold")
99+
plt.ylabel("Test Score")
100+
plt.tight_layout()
101+
102+
103+
# Create handles and labels manually
104+
handles = []
105+
labels = []
106+
for i, distance in enumerate(dist_names.values()):
107+
# Create a proxy circular-marker handle for each distance, using its color from the palette
108+
color = palette_dict[distance] # Get color for this distance
109+
# color = sns.color_palette()[i] # Get color for this distance
110+
# handle = mpatches.Circle((0, 0), radius=0.5, color=color)
111+
# Using mlines.Line2D with marker 'o'
112+
handle = mlines.Line2D([], [], color=color, marker='o', linestyle='None', markersize=10, )
113+
# label=distance_name) # Use Line2D with marker 'o' and set markersize
114+
# handle = plt.Rectangle((0, 0), 1, 1, color=color) # Create a rectangle with that color
115+
handles.append(handle)
116+
labels.append(distance)
117+
plt.legend(handles, labels, loc="upper center", ncol=3, title=r"Splitting strategy:", handletextpad=-0.3)
84118

85-
# Save + show
86-
plt.savefig("thresholds.pdf", format="pdf")
119+
# # Format axes
120+
# g.set_axis_labels("Threshold", "Test Score")
121+
ax.set(ylim=(0.5, 1))
122+
sns.despine()
123+
#
124+
# # Save + show
125+
plt.savefig("plotting_scripts/thresholds.pdf", format="pdf")
87126
plt.show()
88-
plt.clf()
127+
plt.clf()

0 commit comments

Comments
 (0)