This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Commit a3b7de7 (2 parents: f886ce2 + c0de971)

Merge pull request #40 from basalt-org/perf
Add a simple perf metrics module for model

File tree: 4 files changed (+266, -1 lines)

basalt/nn/model.mojo

Lines changed: 39 additions & 0 deletions
@@ -1,10 +1,18 @@
 from collections.optional import Optional
 
+from sys import env_get_int
+
 from basalt import Graph, Symbol, Tensor, TensorShape
 from basalt.autograd.ops import forward_op, backward_op
 from basalt.utils.collection import Collection
 from basalt.utils.tensorutils import fill
 from .initializers import initialize_tensor
+from basalt.utils.perf_utils import PerfMetrics
+
+
+# When runing mojo -D DEBUG=1 -I . file, a crash happens at some point at runtime because of an error in linking it seems (because of using -I .)
+# For now it seems one has to change this variable manually to be able to run model with performance metrics.
+alias DEBUG = env_get_int["DEBUG", 0]()
 
 
 fn dv_contains(dv: List[Symbol], symbol: Symbol) -> Bool:
@@ -61,8 +69,15 @@ struct Model[
     n_inference_nodes: Optional[Int] = calc_n_inference_nodes(g), # TODO: remove this
 ]():
     var parameters: Parameters[g]
+    var perf_metrics: PerfMetrics
 
     fn __init__(inout self, inference_only: Bool = False):
+        @parameter
+        if DEBUG == 1:
+            self.perf_metrics = PerfMetrics(g)
+        else:
+            self.perf_metrics = PerfMetrics()
+
         self.parameters = Parameters[g]()
         self.allocate_tensor_memory()
         self.allocate_grad_memory()
@@ -122,6 +137,11 @@ struct Model[
             alias out = g.nodes[i].output
             alias attrs = g.nodes[i].attributes
 
+            # Save start time for performance metrics
+            @parameter
+            if DEBUG == 1:
+                self.perf_metrics.start_forward_pass()
+
             @parameter
             if op.num_operands == 1:
                 # Unary operator
@@ -147,6 +167,11 @@ struct Model[
                     self.parameters.params[t3],
                 )
 
+            # Save end time for performance metrics
+            @parameter
+            if DEBUG == 1:
+                self.perf_metrics.end_forward_pass(i)
+
         unroll[fw_unroll, num_nodes]()
 
     fn backward(inout self):
@@ -166,6 +191,11 @@ struct Model[
             alias t1 = g.nodes[reverse_i].input_1
             alias attrs = g.nodes[reverse_i].attributes
 
+            # Save start time for performance metrics
+            @parameter
+            if DEBUG == 1:
+                self.perf_metrics.start_backward_pass()
+
             @parameter
             if op.num_operands == 1:
                 # Unary operator
@@ -234,6 +264,11 @@ struct Model[
                     self.parameters.grads[t3], # grad to be updated: input_3
                 )
 
+            # Save end time for performance metrics
+            @parameter
+            if DEBUG == 1:
+                self.perf_metrics.end_backward_pass(i)
+
         unroll[bw_unroll, g.nodes.size]()
 
     fn allocate_tensor_memory(inout self):
@@ -282,3 +317,7 @@ struct Model[
             var out = g.nodes[i].output
             if out.trainable:
                 self.parameters.grads.append(Tensor[dtype](out.shape), out)
+
+    fn print_perf_metrics(self, time_format: String = "ns", print_shape: Bool = False):
+        self.perf_metrics.print_forward_perf_metrics(time_format, print_shape)
+        self.perf_metrics.print_backward_perf_metrics(time_format, print_shape)
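For context, a minimal usage sketch of the new compile-time switch and method. The imports, the helper function and the training loop below are illustrative assumptions; only DEBUG, Model and print_perf_metrics come from this diff.

# Enable the metrics at compile time, e.g.:
#   mojo -D DEBUG=1 examples/mnist.mojo
# (per the comment above, combining -D DEBUG=1 with -I . currently crashes,
#  so the DEBUG alias may have to be set to 1 manually instead)

from basalt import Graph
from basalt.nn import Model  # import path assumed for illustration

fn train_with_metrics[g: Graph](inout model: Model[g], epochs: Int):
    for _ in range(epochs):
        # ... model.forward(...) / model.backward() as usual; with DEBUG == 1
        # each node's start/end times are accumulated in model.perf_metrics ...
        pass

    # Averaged per-node timings in milliseconds, including the
    # <out> = OP( <in1>, <in2>, <in3> ) shapes of every node.
    model.print_perf_metrics("ms", True)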

basalt/utils/dataloader.mojo

Lines changed: 2 additions & 1 deletion
@@ -97,13 +97,14 @@ struct DataLoader:
         # self._data_shape[0] = end - self._current_index
         # self._label_shape[0] = end - self._current_index
 
+        var temp_current_index = self._current_index
         self._current_index += self.batch_size
         self._num_batches -= 1
 
         return Batch[dtype](
             self.data,
             self.labels,
-            self._current_index,
+            temp_current_index,
             self._data_batch_shape,
             self._label_batch_shape,
         )
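The dataloader change is an off-by-one fix: the Batch has to start at the index captured before _current_index is advanced. A small standalone sketch of the arithmetic, with illustrative numbers (nothing below is part of the DataLoader itself):

fn main():
    var batch_size = 32
    var current_index = 64             # loader position before taking this batch

    var batch_start = current_index    # what Batch should receive: 64
    current_index += batch_size        # loader now points at the next batch: 96

    # Before the fix, the already-advanced index (96) was handed to Batch,
    # so every returned batch started one batch_size too far ahead.
    print(batch_start, current_index)  # 64 96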

basalt/utils/perf_utils.mojo

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
+from time import now
+from math import min
+from memory import memset, memcpy
+
+from basalt.autograd.node import Node
+
+
+fn fit_string[num: Int](s: String) -> String:
+    var data = DTypePointer[DType.int8]().alloc(num + 1)
+
+    # Copy the the string up to the length of the buffer
+    # Fill the rest with spaces & Terminate with zero byte
+    memcpy(data, s._as_ptr(), min(num, len(s)))
+    if num - min(num, len(s)) > 0:
+        memset(data + min(num, len(s)), ord(" "), num - min(num, len(s)))
+    data[num] = 0
+
+    return String(data, num + 1)
+
+
+fn truncate_decimals[num: Int](s: String) -> String:
+    var truncated: String
+    try:
+        var p1 = s.split(delimiter=".")
+        truncated = p1[0]
+        if len(p1) > 1:
+            var p2 = p1[1].split(delimiter="e")
+            truncated += "." + fit_string[num](p2[0])
+            if len(p2) > 1:
+                truncated += "e" + p2[1]
+
+    except e:
+        print("[WARNING] could not truncate decimals: ", e)
+        truncated = s
+    return truncated
+
+
+@value
+struct PerfMetricsValues(CollectionElement):
+    var node: Node
+    var time: Float64
+
+    fn __init__(inout self, node: Node, time: Float64):
+        self.node = node
+        self.time = time
+
+
+@value
+struct PerfMetrics:
+    # values are in "ns"
+    # using perf_metrics can reduce the speed of each epoch of the model a little bit
+    var forward_perf_metrics: List[PerfMetricsValues]
+    var backward_perf_metrics: List[PerfMetricsValues]
+    var epochs_forward: Int
+    var epochs_backward: Int
+    var start: Int
+
+    fn __init__(inout self):
+        self.forward_perf_metrics = List[PerfMetricsValues]()
+        self.backward_perf_metrics = List[PerfMetricsValues]()
+        self.epochs_forward = 0
+        self.epochs_backward = 0
+        self.start = 0
+
+    fn __init__(inout self, graph: Graph):
+        self.forward_perf_metrics = List[PerfMetricsValues]()
+        self.backward_perf_metrics = List[PerfMetricsValues]()
+
+        for i in range(graph.nodes.size):
+            self.forward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0))
+            self.backward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0))
+
+        self.epochs_forward = 0
+        self.epochs_backward = 0
+        self.start = 0
+
+    fn start_forward_pass(inout self):
+        self.start = now()
+
+    fn end_forward_pass(inout self, pos: Int):
+        # Change this to use references when list has them available
+        var old_value = self.forward_perf_metrics[pos]
+        self.forward_perf_metrics[pos] = PerfMetricsValues(
+            old_value.node, old_value.time + (now() - self.start)
+        )
+        self.epochs_forward += 1
+
+    fn start_backward_pass(inout self):
+        self.start = now()
+
+    fn end_backward_pass(inout self, pos: Int):
+        var old_value = self.backward_perf_metrics[pos]
+        self.backward_perf_metrics[pos] = PerfMetricsValues(
+            old_value.node, old_value.time + (now() - self.start)
+        )
+        self.epochs_backward += 1
+
+    fn print_perf_metrics[
+        type_part: String
+    ](self, time_format: String = "ns", print_shape: Bool = False):
+        # Calculates the average time for each node operation.
+
+        if type_part == "Forward" and len(self.forward_perf_metrics) == 0:
+            return
+        if type_part == "Backward" and len(self.backward_perf_metrics) == 0:
+            return
+
+        if type_part == "Forward":
+            print("\n\nForward pass performance metrics:")
+        else:
+            print("\n\nBackward pass performance metrics:")
+
+        var total_time: SIMD[DType.float64, 1] = 0
+
+        var size: Int = 0
+
+        @parameter
+        if type_part == "Forward":
+            size = len(self.forward_perf_metrics)
+        elif type_part == "Backward":
+            size = len(self.backward_perf_metrics)
+        for i in range(size):
+
+            @parameter
+            if type_part == "Forward":
+                total_time += self.forward_perf_metrics[i].time / self.epochs_forward
+            elif type_part == "Backward":
+                total_time += self.backward_perf_metrics[i].time / self.epochs_backward
+
+        # 1. Header
+        var header = fit_string[5]("Node") + "| " + fit_string[15](
+            "Operator"
+        ) + "| " + fit_string[20]("Time [" + time_format + "]") + "| " + fit_string[20](
+            "Percentage [%]"
+        )
+        if print_shape:
+            header += "| " + fit_string[70]("Shape\t <out> = OP( <in1>, <in2>, <in3> )")
+        print(header)
+
+        # 2. Seperator
+        var sep = DTypePointer[DType.int8]().alloc(len(header) + 1)
+        memset(sep, ord("-"), len(header))
+        sep[len(header)] = 0
+        var seperator = String(sep, len(header) + 1)
+        print(seperator)
+
+        # 3. Perf Data
+        for i in range(len(self.forward_perf_metrics)):
+            var value: PerfMetricsValues
+
+            @parameter
+            if type_part == "Forward":
+                value = self.forward_perf_metrics[i]
+            else:
+                value = self.backward_perf_metrics[i]
+
+            var time = value.time
+
+            @parameter
+            if type_part == "Forward":
+                time = time / self.epochs_forward
+            else:
+                time = time / self.epochs_backward
+
+            var time_converted = time
+            if time_format == "ms":
+                time_converted = time / 1e6
+            elif time_format == "s":
+                time_converted = time / 1e9
+
+            var print_value = fit_string[5](str(i)) + "| " + fit_string[15](
+                value.node.operator
+            ) + "| " + fit_string[20](
+                truncate_decimals[4](time_converted)
+            ) + "| " + fit_string[
+                20
+            ](
+                truncate_decimals[3]((time / total_time) * 100) + " %"
+            ) + "| "
+
+            if print_shape:
+                var shape_str: String = ""
+                shape_str += fit_string[15]("<" + str(value.node.output.shape) + ">")
+                shape_str += fit_string[7](" = OP(")
+                shape_str += fit_string[15]("<" + str(value.node.input_1.shape) + ">")
+                if value.node.input_2:
+                    shape_str += ", " + fit_string[15](
+                        "<" + str(value.node.input_2.value().shape) + ">"
+                    )
+                if value.node.input_3:
+                    shape_str += ", " + fit_string[15](
+                        "<" + str(value.node.input_3.value().shape) + ">"
+                    )
+                shape_str += ")"
+
+                print_value += shape_str
+
+            print(print_value)
+
+        var total_time_converted = total_time
+        if time_format == "ms":
+            total_time_converted = total_time / 1e6
+        elif time_format == "s":
+            total_time_converted = total_time / 1e9
+        print(
+            "\nTotal average "
+            + type_part
+            + " time: "
+            + str(total_time_converted)
+            + " "
+            + time_format
+        )
+
+    fn print_forward_perf_metrics(
+        self, time_format: String = "ns", print_shape: Bool = False
+    ):
+        self.print_perf_metrics["Forward"](time_format, print_shape)
+
+    fn print_backward_perf_metrics(
+        self, time_format: String = "ns", print_shape: Bool = False
+    ):
+        self.print_perf_metrics["Backward"](time_format, print_shape)
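The module can also be driven outside of Model; a minimal sketch, assuming a compile-time Graph g and placeholder node execution (only the PerfMetrics calls are from this file):

from basalt import Graph
from basalt.utils.perf_utils import PerfMetrics

fn profile_forward[g: Graph]():
    # One PerfMetricsValues slot per graph node, each starting at 0.0 ns.
    var metrics = PerfMetrics(g)

    for _ in range(100):
        for i in range(g.nodes.size):
            metrics.start_forward_pass()   # records now() as the start time
            # ... execute node i here ...
            metrics.end_forward_pass(i)    # adds the elapsed ns to node i's total

    # One row per node: operator, averaged time (here in ms) and its share of the
    # total; shapes are omitted because print_shape is False.
    metrics.print_forward_perf_metrics("ms", False)

Note that end_forward_pass also bumps epochs_forward, which print_perf_metrics later divides each accumulated time by when averaging.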

examples/mnist.mojo

Lines changed: 3 additions & 0 deletions
@@ -122,3 +122,6 @@ fn main():
         print("Epoch time: ", (now() - epoch_start) / 1e9, "seconds")
 
     print("Training finished: ", (now() - start) / 1e9, "seconds")
+
+
+    model.print_perf_metrics("ms", True)
