
Commit a632cbb

add tensorboard option and update RNA_CM_exp

1 parent 6c128e4 commit a632cbb

2 files changed: +107 -24 lines changed


base/RNA_CM_exp.py

Lines changed: 71 additions & 13 deletions
@@ -8,31 +8,89 @@
 
 from rnaglib.learning.task_models import PygModel
 from rnaglib.tasks import get_task, RNA_CM
-from rnaglib.transforms import GraphRepresentation
+from rnaglib.transforms import GraphRepresentation, RNAFMTransform
+from rnaglib.dataset_transforms import CDHitComputer, ClusterSplitter, StructureDistanceComputer, RandomSplitter
+from rnaglib.encoders import ListEncoder
+from rnaglib.config.graph_keys import GRAPH_KEYS, TOOL
 
 script_dir = os.path.dirname(os.path.realpath(__file__))
 if __name__ == "__main__":
     sys.path.append(os.path.join(script_dir, '..'))
 
     from exp import RNATrainer
 
-    # Setup task
-    ta = get_task(root="roots/RNA_CM", task_id="rna_cm")
+    # Hyperparameters (to tune)
+    nb_layers = 3
+    hidden_dim = 128
+    learning_rate = 0.001
+    batch_size = 8
+    epochs = 40
+    split = "default"
+    rna_fm = False
+    representation = "2.5D"
+    layer_type = "rgcn"
+    output = "tensorboard"
+
+    # Experiment name
+    exp_name = "RNA_CM_"+str(nb_layers)+"layers_lr"+str(learning_rate)+"_"+str(epochs)+"epochs_hiddendim"+str(hidden_dim)+"_"+representation+"_layer_type_"+layer_type
+    if rna_fm:
+        exp_name += "rna_fm"
+    if split != "default":
+        exp_name += split
 
-    ta.dataset.add_representation(GraphRepresentation(framework="pyg"))
-    ta.get_split_loaders(batch_size=8)
 
     model_args = {
-        "num_node_features": ta.metadata["num_node_features"],
-        "num_classes": ta.metadata["num_classes"],
-        "graph_level": False,
-        "num_layers": 3,
+        "graph_level": False,
+        "num_layers": nb_layers,
+        "hidden_channels": hidden_dim,
+        "layer_type": layer_type,
     }
 
-    model = PygModel(**model_args)
-    trainer = RNATrainer(ta, model, epochs=40)
+    if rna_fm:
+        model_args["num_node_features"] = 644
 
-    trainer.train()
+    #model_CM = PygModel(**model_args)
+    #trainer_CM = RNATrainer(ta, model_CM, rep, exp_name=exp_name, learning_rate=learning_rate, epochs=epochs)
+    #trainer_CM.train()
 
 if __name__ == "__main__":
-    pass
+    for seed in [0, 1, 2]:
+        ta = get_task(root="roots/RNA_CM", task_id="rna_cm")
+        if split == "struc":
+            distance = "USalign"
+        else:
+            distance = "cd_hit"
+
+        if distance not in ta.dataset.distances:
+            if split == 'struc':
+                ta.dataset = StructureDistanceComputer()(ta.dataset)
+            if split == 'seq':
+                ta.dataset = CDHitComputer()(ta.dataset)
+        if split == 'rand':
+            ta.splitter = RandomSplitter()
+        elif split == 'struc' or split == 'seq':
+            ta.splitter = ClusterSplitter(distance_name=distance)
+
+        if rna_fm:
+            rnafm = RNAFMTransform()
+            [rnafm(rna) for rna in ta.dataset]
+            ta.dataset.features_computer.add_feature(feature_names=["rnafm"], custom_encoders={"rnafm": ListEncoder(640)})
+
+        if representation == "2D":
+            edge_map = GRAPH_KEYS["2D_edge_map"][TOOL]
+        elif representation == "simplified_2.5D":
+            edge_map = GRAPH_KEYS["simplified_edge_map"][TOOL]
+        else:
+            edge_map = GRAPH_KEYS["edge_map"][TOOL]
+
+        representation_args = {
+            "framework": "pyg",
+            "edge_map": edge_map,
+        }
+
+        rep = GraphRepresentation(**representation_args)
+        ta.dataset.add_representation(rep)
+        ta.get_split_loaders(batch_size=batch_size, recompute=True)
+        model = PygModel.from_task(ta, **model_args)
+        trainer = RNATrainer(ta, model, rep, exp_name=exp_name+"_seed"+str(seed), learning_rate=learning_rate, epochs=epochs, seed=seed, batch_size=batch_size, output=output)
+        trainer.train()
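
For reference, a stripped-down single-seed run of the loop added above could look like the sketch below. It only uses calls that already appear in this diff (get_task, GraphRepresentation, PygModel.from_task, RNATrainer); the experiment name is illustrative, and the default splitter and full 2.5D edge map are assumed.

from rnaglib.tasks import get_task
from rnaglib.transforms import GraphRepresentation
from rnaglib.learning.task_models import PygModel

from exp import RNATrainer

# One seed, default split, default (2.5D) edge map -- illustrative only.
ta = get_task(root="roots/RNA_CM", task_id="rna_cm")
rep = GraphRepresentation(framework="pyg")
ta.dataset.add_representation(rep)
ta.get_split_loaders(batch_size=8, recompute=True)

model = PygModel.from_task(ta, graph_level=False, num_layers=3,
                           hidden_channels=128, layer_type="rgcn")
trainer = RNATrainer(ta, model, rep, exp_name="RNA_CM_demo",  # name is illustrative
                     learning_rate=0.001, epochs=40, seed=0,
                     batch_size=8, output="tensorboard")
trainer.train()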

exp.py

Lines changed: 36 additions & 11 deletions
@@ -6,10 +6,11 @@
 
 import torch
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 
 class RNATrainer:
     def __init__(self, task, model, rep="pyg", wandb_project="", exp_name="default",
-                 learning_rate=0.001, epochs=100, seed=0, batch_size=8):
+                 learning_rate=0.001, epochs=100, seed=0, batch_size=8, output="wandb", log_dir="runs/"):
         self.task = task
         self.representation = rep
         self.model = model
@@ -20,15 +21,20 @@ def __init__(self, task, model, rep="pyg", wandb_project="", exp_name="default",
         self.training_log = []
         self.seed = seed
         self.batch_size = batch_size
+        self.output = output
+        self.log_dir = log_dir
 
     def setup(self):
         """Initialize wandb and model training"""
-        wandb.init(
-            entity="mlsb", # Replace with your team name
-            project=self.wandb_project,
-            name=self.exp_name,
-        )
-
+        if self.output == "tensorboard":
+            self.train_writer = SummaryWriter(log_dir=self.log_dir+self.task.name+"/"+self.exp_name+"/train")
+            self.val_writer = SummaryWriter(log_dir=self.log_dir+self.task.name+"/"+self.exp_name+"/val")
+        else:
+            wandb.init(
+                entity="mlsb", # Replace with your team name
+                project=self.wandb_project,
+                name=self.exp_name,
+            )
         # Set seeds for reproducibility
         torch.manual_seed(self.seed) # CPU random number generator
         if torch.cuda.is_available():
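
With the defaults added here (output="wandb", log_dir="runs/"), selecting output="tensorboard" makes setup() create one event directory per split, built as log_dir + task.name + "/" + exp_name, i.e. something like:

runs/<task.name>/<exp_name>/train
runs/<task.name>/<exp_name>/val

(The angle-bracket placeholders stand for the concrete task and experiment names at run time.)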
@@ -72,33 +78,48 @@ def train(self):
             train_metrics = self.model.evaluate(self.task, split="train")
             val_metrics = self.model.evaluate(self.task, split="val")
 
-            # Log to wandb
+            # Log to wandb or Tensorboard
             metrics = {
                 "epoch": epoch,
                 **{f"train_{k}": v for k, v in train_metrics.items()},
                 **{f"val_{k}": v for k, v in val_metrics.items()}
             }
+            if self.output == "tensorboard":
+                self.train_writer.add_scalar("Loss", train_metrics['loss'], epoch)
+                self.val_writer.add_scalar("Loss", val_metrics['loss'], epoch)
             try:
                 metrics["train_auc"] = train_metrics['auc']
                 metrics["val_auc"] = val_metrics['auc']
+                if self.output == "tensorboard":
+                    self.train_writer.add_scalar("AUC", train_metrics['auc'], epoch)
+                    self.val_writer.add_scalar("AUC", val_metrics['auc'], epoch)
             except:
                 pass
             if self.task.metadata['multi_label']:
                 metrics["train_jaccard"] = train_metrics["jaccard"]
                 metrics["val_jaccard"] = val_metrics["jaccard"]
+                if self.output == "tensorboard":
+                    self.train_writer.add_scalar("Jaccard", train_metrics['jaccard'], epoch)
+                    self.val_writer.add_scalar("Jaccard", val_metrics['jaccard'], epoch)
             else:
                 try:
                     metrics["train_balanced_accuracy"] = train_metrics["balanced_accuracy"]
                     metrics["val_balanced_accuracy"] = val_metrics["balanced_accuracy"]
+                    if self.output == "tensorboard":
+                        self.train_writer.add_scalar("Balanced_acc", train_metrics['balanced_accuracy'], epoch)
+                        self.val_writer.add_scalar("Balanced_acc", val_metrics['balanced_accuracy'], epoch)
                 except:
                     pass
                 try:
                     metrics["train_mcc"] = train_metrics["mcc"]
                     metrics["val_mcc"] = val_metrics["mcc"]
+                    if self.output == "tensorboard":
+                        self.train_writer.add_scalar("MCC", train_metrics['mcc'], epoch)
+                        self.val_writer.add_scalar("MCC", val_metrics['mcc'], epoch)
                 except:
                     pass
-
-            wandb.log(metrics)
+            if self.output == "wandb":
+                wandb.log(metrics)
             self.training_log.append(metrics)
 
             # Print progress
@@ -110,7 +131,11 @@ def train(self):
             )
 
         self.save_results()
-        wandb.finish()
+        if self.output == "tensorboard":
+            self.train_writer.flush()
+            self.val_writer.flush()
+        else:
+            wandb.finish()
 
     def save_results(self):
         """Save final results and metrics"""
