1
1
import json
2
2
import pandas as pd
3
- import matplotlib .pyplot as plt
4
3
import seaborn as sns
5
4
5
+ import matplotlib .pyplot as plt
6
+ import matplotlib .lines as mlines # Import mlines
7
+ import matplotlib .patches as mpatches
8
+
9
+ plt .rcParams ["text.usetex" ] = True
10
+ plt .rc ("font" , size = 16 ) # fontsize of the tick labels
11
+ plt .rc ("ytick" , labelsize = 13 ) # fontsize of the tick labels
12
+ plt .rc ("xtick" , labelsize = 13 ) # fontsize of the tick labels
13
+ plt .rc ("grid" , color = "grey" , alpha = 0.2 )
14
+
6
15
# Config
7
16
thresholds = [0.4 , 0.6 , 0.8 , 1 ]
8
17
SEEDS = [0 , 1 , 2 ]
22
31
rows .append (
23
32
{
24
33
"score" : score ,
25
- "metric" : METRIC ,
26
34
"threshold" : f"{ threshold :.1f} " ,
27
35
"seed" : seed ,
28
36
"distance" : distance ,
29
37
}
30
38
)
31
39
32
40
# Add random performance over three seeds
33
- for seed in SEEDS :
34
- path = f"results/thresholds_rna_site_redundant_rand_0.4_ { seed } .json"
35
- with open ( path ) as result :
36
- result = json . load ( result )
37
- score = result [ METRIC ]
38
- rows . append (
39
- {
40
- "score" : score ,
41
- "metric " : METRIC ,
42
- "threshold" : "random " ,
43
- "seed" : seed , # simulate different seeds
44
- "distance" : "rand" , # used for filtering
45
- }
46
- )
41
+ for threshold in thresholds :
42
+ for seed in SEEDS :
43
+ path = f"results/thresholds_rna_site_redundant_rand_0.4_ { seed } .json"
44
+ with open ( path ) as result :
45
+ result = json . load ( result )
46
+ score = result [ METRIC ]
47
+ rows . append (
48
+ {
49
+ "score " : score ,
50
+ "threshold" : f" { threshold :.1f } " ,
51
+ "seed" : seed , # simulate different seeds
52
+ "distance" : "rand" , # used for filtering
53
+ }
54
+ )
47
55
48
56
# Build DataFrame
49
57
df = pd .DataFrame (rows )
50
58
df .to_csv ("thresholds_with_random.csv" )
51
59
52
- # Ensure x-axis order: random first, then thresholds
53
- task_order = ["random" ] + [f"{ t :.1f} " for t in thresholds ]
54
-
55
60
# Create a column to use for legend hue, exclude 'rand'
56
- df ["distance" ] = df ["distance" ].apply (lambda d : d if d in ["struc" , "seq" ] else "random" )
61
+ # df["distance"] = df["distance"].apply(lambda d: d if d in ["struc", "seq"] else "random")
57
62
58
63
# Custom color palette (you can tweak these)
59
64
palette = {
60
65
"struc" : sns .color_palette ()[0 ],
61
66
"seq" : sns .color_palette ()[1 ],
62
67
"random" : sns .color_palette ()[2 ],
63
68
}
69
+
70
+ dist_names = {
71
+ "struc" : r"Structure" ,
72
+ "seq" : r"Sequence" ,
73
+ "rand" : r"Random"
74
+ }
75
+ df ["distance" ] = df ["distance" ].replace (dist_names )
76
+
77
+ palette_dict = sns .color_palette ("muted" )
78
+ palette_dict = {
79
+ r"Structure" : palette_dict [0 ],
80
+ r"Sequence" : palette_dict [9 ],
81
+ r"Random" : palette_dict [3 ],
82
+ }
64
83
print (df )
65
84
66
- # Plot
67
- g = sns .catplot (
85
+ # df_no_rand = df[df["distance"] != r"Random"]
86
+ ax = sns .lineplot (
68
87
data = df ,
69
- x = "threshold" ,
70
- y = "score" ,
71
- hue = "distance" ,
72
- kind = "bar" ,
73
- height = 4 ,
74
- aspect = 1 ,
75
- order = task_order ,
76
- palette = palette ,
88
+ x = "threshold" , # X-axis is the threshold
89
+ y = "score" , # Y-axis is the score
90
+ hue = "distance" , # Differentiate lines by distance
91
+ errorbar = 'sd' , # Use standard deviation for the error band
77
92
legend = False ,
93
+ palette = palette_dict ,
94
+ linewidth = 2 ,
78
95
)
79
96
80
- # Format axes
81
- g .set_axis_labels ("Threshold" , "Test Score" )
82
- g .set (ylim = (0.5 , 0.85 ))
83
- g .despine (left = True )
97
+ # Set axis labels and title
98
+ plt .xlabel ("Threshold" )
99
+ plt .ylabel ("Test Score" )
100
+ plt .tight_layout ()
101
+
102
+
103
+ # Create handles and labels manually
104
+ handles = []
105
+ labels = []
106
+ for i , distance in enumerate (dist_names .values ()):
107
+ # Create a dummy rectangle for each distance, using the color from the plot
108
+ color = palette_dict [distance ] # Get color for this distance
109
+ # color = sns.color_palette()[i] # Get color for this distance
110
+ # handle = mpatches.Circle((0, 0), radius=0.5, color=color)
111
+ # Using mlines.Line2D with marker 'o'
112
+ handle = mlines .Line2D ([], [], color = color , marker = 'o' , linestyle = 'None' , markersize = 10 , )
113
+ # label=distance_name) # Use Line2D with marker 'o' and set markersize
114
+ # handle = plt.Rectangle((0, 0), 1, 1, color=color) # Create a rectangle with that color
115
+ handles .append (handle )
116
+ labels .append (distance )
117
+ plt .legend (handles , labels , loc = "upper center" , ncol = 3 , title = r"Splitting strategy:" , handletextpad = - 0.3 )
84
118
85
- # Save + show
86
- plt .savefig ("thresholds.pdf" , format = "pdf" )
119
+ # # Format axes
120
+ # g.set_axis_labels("Threshold", "Test Score")
121
+ ax .set (ylim = (0.5 , 1 ))
122
+ sns .despine ()
123
+ #
124
+ # # Save + show
125
+ plt .savefig ("plotting_scripts/thresholds.pdf" , format = "pdf" )
87
126
plt .show ()
88
- plt .clf ()
127
+ plt .clf ()
0 commit comments