6
6
7
7
import torch
8
8
import numpy as np
9
+ from torch .utils .tensorboard import SummaryWriter
9
10
10
11
class RNATrainer :
11
12
def __init__ (self , task , model , rep = "pyg" , wandb_project = "" , exp_name = "default" ,
12
- learning_rate = 0.001 , epochs = 100 , seed = 0 , batch_size = 8 ):
13
+ learning_rate = 0.001 , epochs = 100 , seed = 0 , batch_size = 8 , output = "wandb" , log_dir = "runs/" ):
13
14
self .task = task
14
15
self .representation = rep
15
16
self .model = model
@@ -20,15 +21,20 @@ def __init__(self, task, model, rep="pyg", wandb_project="", exp_name="default",
20
21
self .training_log = []
21
22
self .seed = seed
22
23
self .batch_size = batch_size
24
+ self .output = output
25
+ self .log_dir = log_dir
23
26
24
27
def setup (self ):
25
28
"""Initialize the logging backend (TensorBoard or wandb) and model training"""
26
- wandb .init (
27
- entity = "mlsb" , # Replace with your team name
28
- project = self .wandb_project ,
29
- name = self .exp_name ,
30
- )
31
-
29
+ if self .output == "tensorboard" :
30
+ self .train_writer = SummaryWriter (log_dir = self .log_dir + self .task .name + "/" + self .exp_name + "/train" )
31
+ self .val_writer = SummaryWriter (log_dir = self .log_dir + self .task .name + "/" + self .exp_name + "/val" )
32
+ else :
33
+ wandb .init (
34
+ entity = "mlsb" , # Replace with your team name
35
+ project = self .wandb_project ,
36
+ name = self .exp_name ,
37
+ )
32
38
# Set seeds for reproducibility
33
39
torch .manual_seed (self .seed ) # CPU random number generator
34
40
if torch .cuda .is_available ():
@@ -72,33 +78,48 @@ def train(self):
72
78
train_metrics = self .model .evaluate (self .task , split = "train" )
73
79
val_metrics = self .model .evaluate (self .task , split = "val" )
74
80
75
- # Log to wandb
81
+ # Log to wandb or TensorBoard
76
82
metrics = {
77
83
"epoch" : epoch ,
78
84
** {f"train_{ k } " : v for k , v in train_metrics .items ()},
79
85
** {f"val_{ k } " : v for k , v in val_metrics .items ()}
80
86
}
87
+ if self .output == "tensorboard" :
88
+ self .train_writer .add_scalar ("Loss" , train_metrics ['loss' ], epoch )
89
+ self .val_writer .add_scalar ("Loss" , val_metrics ['loss' ], epoch )
81
90
try :
82
91
metrics ["train_auc" ] = train_metrics ['auc' ]
83
92
metrics ["val_auc" ] = val_metrics ['auc' ]
93
+ if self .output == "tensorboard" :
94
+ self .train_writer .add_scalar ("AUC" , train_metrics ['auc' ], epoch )
95
+ self .val_writer .add_scalar ("AUC" , val_metrics ['auc' ], epoch )
84
96
except :
85
97
pass
86
98
if self .task .metadata ['multi_label' ]:
87
99
metrics ["train_jaccard" ] = train_metrics ["jaccard" ]
88
100
metrics ["val_jaccard" ] = val_metrics ["jaccard" ]
101
+ if self .output == "tensorboard" :
102
+ self .train_writer .add_scalar ("Jaccard" , train_metrics ['jaccard' ], epoch )
103
+ self .val_writer .add_scalar ("Jaccard" , val_metrics ['jaccard' ], epoch )
89
104
else :
90
105
try :
91
106
metrics ["train_balanced_accuracy" ] = train_metrics ["balanced_accuracy" ]
92
107
metrics ["val_balanced_accuracy" ] = val_metrics ["balanced_accuracy" ]
108
+ if self .output == "tensorboard" :
109
+ self .train_writer .add_scalar ("Balanced_acc" , train_metrics ['balanced_accuracy' ], epoch )
110
+ self .val_writer .add_scalar ("Balanced_acc" , val_metrics ['balanced_accuracy' ], epoch )
93
111
except :
94
112
pass
95
113
try :
96
114
metrics ["train_mcc" ] = train_metrics ["mcc" ]
97
115
metrics ["val_mcc" ] = val_metrics ["mcc" ]
116
+ if self .output == "tensorboard" :
117
+ self .train_writer .add_scalar ("MCC" , train_metrics ['mcc' ], epoch )
118
+ self .val_writer .add_scalar ("MCC" , val_metrics ['mcc' ], epoch )
98
119
except :
99
120
pass
100
-
101
- wandb .log (metrics )
121
+ if self . output == "wandb" :
122
+ wandb .log (metrics )
102
123
self .training_log .append (metrics )
103
124
104
125
# Print progress
@@ -110,7 +131,11 @@ def train(self):
110
131
)
111
132
112
133
self .save_results ()
113
- wandb .finish ()
134
+ if self .output == "tensorboard" :
135
+ self .train_writer .flush ()
136
+ self .val_writer .flush ()
137
+ else :
138
+ wandb .finish ()
114
139
115
140
def save_results (self ):
116
141
"""Save final results and metrics"""
0 commit comments