Commit e1b218e

anna-grim committed
major refactor: renaming, trainer, doc
1 parent a10b901 commit e1b218e

File tree: 17 files changed, +667 −553 lines

pyproject.toml

Lines changed: 0 additions & 2 deletions
@@ -44,8 +44,6 @@ dependencies = [
     'zarr',
 ]

-aind-exaspim-image-utils = { git = "https://github.com/AllenNeuralDynamics/aind-exaspim-image-utils.git", branch = "main" }
-
 [project.optional-dependencies]
 dev = [
     'black',
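
Aside: the removed line declared a git dependency as an inline table keyed directly under [project], which PEP 621 build backends generally reject as an unknown project field. If the package were still needed, the portable form would be a PEP 508 direct reference inside `dependencies` — a sketch for illustration, not part of this commit:

dependencies = [
    'zarr',
    'aind-exaspim-image-utils @ git+https://github.com/AllenNeuralDynamics/aind-exaspim-image-utils.git@main',
]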
Lines changed: 327 additions & 0 deletions
@@ -0,0 +1,327 @@
"""
Created on Wed July 25 16:00:00 2025

@author: Anna Grim


Code for a custom class for training neural networks to perform classification
tasks within the GraphTrace pipeline.

"""

from datetime import datetime
from sklearn.metrics import precision_score, recall_score, accuracy_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim

from deep_neurographs.utils import ml_util, util


class Trainer:
    """
    Trainer class for training a model to perform binary classification.

    Attributes
    ----------
    batch_size : int
        Number of samples per batch during training.
    best_f1 : float
        Best F1 score achieved so far on the validation dataset.
    criterion : torch.nn.BCEWithLogitsLoss
        Loss function used during training.
    log_dir : str
        Path to directory that tensorboard and checkpoints are saved to.
    max_epochs : int
        Maximum number of training epochs.
    model : torch.nn.Module
        Model that is trained to perform binary classification.
    model_name : str
        Name of model used for logging and checkpointing.
    optimizer : torch.optim.AdamW
        Optimizer that is used during training.
    scheduler : torch.optim.lr_scheduler.CosineAnnealingLR
        Scheduler used to adjust the learning rate.
    writer : torch.utils.tensorboard.SummaryWriter
        Writer object that writes to a tensorboard.
    """

    def __init__(
        self,
        model,
        model_name,
        output_dir,
        batch_size=32,
        lr=1e-3,
        max_epochs=200,
    ):
        """
        Instantiates a Trainer object.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is trained to perform binary classification.
        model_name : str
            Name of model used for logging and checkpointing.
        output_dir : str
            Directory that tensorboard and model checkpoints are written to.
        batch_size : int, optional
            Number of samples per batch during training. Default is 32.
        lr : float, optional
            Learning rate. Default is 1e-3.
        max_epochs : int, optional
            Maximum number of training epochs. Default is 200.
        """
        # Initializations
        exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M")
        log_dir = os.path.join(output_dir, exp_name)
        util.mkdir(log_dir)

        # Instance attributes
        self.batch_size = batch_size
        self.best_f1 = 0
        self.log_dir = log_dir
        self.max_epochs = max_epochs
        self.model_name = model_name

        self.criterion = nn.BCEWithLogitsLoss()
        self.model = model.to("cuda")
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=25)
        self.writer = SummaryWriter(log_dir=log_dir)

    # --- Core Routines ---
    def run(self, train_dataloader, val_dataloader):
        """
        Runs the full training and validation loop.

        Parameters
        ----------
        train_dataloader : torch.utils.data.DataLoader
            DataLoader used for training.
        val_dataloader : torch.utils.data.DataLoader
            DataLoader used for validation.

        Returns
        -------
        None
        """
        exp_name = os.path.basename(os.path.normpath(self.log_dir))
        print("\nExperiment:", exp_name)
        for epoch in range(self.max_epochs):
            # Train-Validate
            train_stats = self.train_step(train_dataloader, epoch)
            val_stats, new_best = self.validate_step(val_dataloader, epoch)

            # Report results
            print(f"\nEpoch {epoch}: " + ("New Best!" if new_best else " "))
            self.report_stats(train_stats, is_train=True)
            self.report_stats(val_stats, is_train=False)

            # Step scheduler
            self.scheduler.step()

    def train_step(self, train_dataloader, epoch):
        """
        Performs a single training epoch over the provided DataLoader.

        Parameters
        ----------
        train_dataloader : torch.utils.data.DataLoader
            DataLoader for the training dataset.
        epoch : int
            Current training epoch.

        Returns
        -------
        dict
            Dictionary of aggregated training metrics.
        """
        self.model.train()
        loss, y, hat_y = list(), list(), list()
        for x_i, y_i in train_dataloader:
            # Forward pass
            hat_y_i, loss_i = self.forward_pass(x_i, y_i)

            # Backward pass
            self.optimizer.zero_grad()
            loss_i.backward()
            self.optimizer.step()

            # Store results
            y.extend(ml_util.to_cpu(y_i, True).flatten().tolist())
            hat_y.extend(ml_util.to_cpu(hat_y_i, True).flatten().tolist())
            loss.append(float(ml_util.to_cpu(loss_i)))

        # Write stats to tensorboard
        stats = self.compute_stats(y, hat_y)
        stats["loss"] = np.mean(loss)
        self.update_tensorboard(stats, epoch, "train_")
        return stats

    def validate_step(self, val_dataloader, epoch):
        """
        Performs a full validation loop over the given DataLoader.

        Parameters
        ----------
        val_dataloader : torch.utils.data.DataLoader
            DataLoader for the validation dataset.
        epoch : int
            Current training epoch.

        Returns
        -------
        tuple
            stats : dict
                Dictionary of aggregated validation metrics.
            is_best : bool
                True if the current F1 score is the best so far.
        """
        loss, y, hat_y = list(), list(), list()
        with torch.no_grad():
            self.model.eval()
            for x_i, y_i in val_dataloader:
                # Run model
                hat_y_i, loss_i = self.forward_pass(x_i, y_i)

                # Store results
                y.extend(ml_util.to_cpu(y_i, True).flatten().tolist())
                hat_y.extend(ml_util.to_cpu(hat_y_i, True).flatten().tolist())
                loss.append(float(ml_util.to_cpu(loss_i)))

        # Write stats to tensorboard
        stats = self.compute_stats(y, hat_y)
        stats["loss"] = np.mean(loss)
        self.update_tensorboard(stats, epoch, "val_")

        # Check for new best; best_f1 is updated before saving so that the
        # checkpoint filename reflects the newly achieved score
        if stats["f1"] > self.best_f1:
            self.best_f1 = stats["f1"]
            self.save_model(epoch)
            return stats, True
        else:
            return stats, False

    def forward_pass(self, x, y):
        """
        Performs a forward pass through the model and computes the loss.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (B, C, D, H, W).
        y : torch.Tensor
            Ground truth labels with the same shape as the model output.

        Returns
        -------
        tuple
            hat_y : torch.Tensor
                Model predictions (logits).
            loss : torch.Tensor
                Computed loss value.
        """
        x = x.to("cuda", dtype=torch.float32)
        y = y.to("cuda", dtype=torch.float32)
        hat_y = self.model(x)
        loss = self.criterion(hat_y, y)
        return hat_y, loss

    # --- Helpers ---
    def compute_stats(self, y, hat_y):
        """
        Computes accuracy, precision, recall, and F1 score aggregated over
        all samples seen in the current epoch.

        Parameters
        ----------
        y : List[float]
            Ground truth labels.
        hat_y : List[float]
            Model predictions (logits).

        Returns
        -------
        dict
            Dictionary of aggregated metrics.
        """
        # Reformat predictions; a logit > 0 corresponds to a sigmoid
        # probability > 0.5
        hat_y = (np.array(hat_y) > 0).astype(int)
        y = np.array(y, dtype=int)

        # Compute stats; max() in the denominator guards against division
        # by zero when precision and recall are both zero
        avg_prec = precision_score(y, hat_y, zero_division=np.nan)
        avg_recall = recall_score(y, hat_y, zero_division=np.nan)
        avg_f1 = 2 * avg_prec * avg_recall / max((avg_prec + avg_recall), 1)
        avg_acc = accuracy_score(y, hat_y)
        stats = {
            "f1": avg_f1,
            "precision": avg_prec,
            "recall": avg_recall,
            "accuracy": avg_acc,
        }
        return stats

    def report_stats(self, stats, is_train=True):
        """
        Prints a summary of training or validation statistics.

        Parameters
        ----------
        stats : dict
            Dictionary of metric names to values.
        is_train : bool, optional
            Indication of whether stats were computed during training.
            Default is True.

        Returns
        -------
        None
        """
        summary = " Train: " if is_train else " Val: "
        for key, value in stats.items():
            summary += f"{key}={value:.4f}, "
        print(summary)

    def save_model(self, epoch):
        """
        Saves the current model state to a file.

        Parameters
        ----------
        epoch : int
            Current training epoch.

        Returns
        -------
        None
        """
        date = datetime.today().strftime("%Y%m%d")
        filename = f"{self.model_name}-{date}-{epoch}-{self.best_f1:.4f}.pth"
        path = os.path.join(self.log_dir, filename)
        torch.save(self.model.state_dict(), path)

    def update_tensorboard(self, stats, epoch, prefix):
        """
        Logs scalar statistics to TensorBoard.

        Parameters
        ----------
        stats : dict
            Dictionary of metric names (str) to scalar values.
        epoch : int
            Current training epoch.
        prefix : str
            Prefix to prepend to each metric name when logging.

        Returns
        -------
        None
        """
        for key, value in stats.items():
            self.writer.add_scalar(prefix + key, value, epoch)
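
A minimal usage sketch of the new Trainer. Everything here is illustrative: the module path is a guess (the new file's path is not captured above), the toy model and loaders are hypothetical stand-ins, and a CUDA device is required since the class moves the model and batches to "cuda".

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from deep_neurographs.machine_learning.trainer import Trainer  # assumed path

# Toy binary-classification data shaped like (B, C, D, H, W) patches.
x = torch.randn(256, 1, 16, 16, 16)
y = (torch.rand(256, 1) > 0.5).float()

# Stand-in model that emits one logit per patch.
model = nn.Sequential(nn.Flatten(), nn.Linear(16 ** 3, 1))

train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=32)

trainer = Trainer(model, "toy_model", "runs", max_epochs=5)
trainer.run(train_loader, val_loader)

Note that validation thresholds raw logits at zero, equivalent to thresholding sigmoid probabilities at 0.5, which is consistent with BCEWithLogitsLoss applying the sigmoid internally.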

src/deep_neurographs/machine_learning/vision_models.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def _get_flattened_size(self, patch_shape):
         ----------
         patch_shape : Tuple[int]
             Shape of input image patch.
-
+
         Returns
         -------
         int
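
The hunk above only touches whitespace in the docstring. For context, a common way to implement a helper like _get_flattened_size is a dummy forward pass through the convolutional stack — a sketch under assumed names, since the actual implementation is not shown in this diff:

import torch

def _get_flattened_size(self, patch_shape):
    # Push one dummy patch through the conv layers and count the
    # resulting features; assumes a `self.convs` stack and a
    # single-channel input, neither of which appears in this hunk.
    with torch.no_grad():
        dummy = torch.zeros(1, 1, *patch_shape)
        out = self.convs(dummy)
    return int(out.flatten(start_dim=1).shape[1])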

0 commit comments