Skip to content
This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Commit f886ce2

Browse files
Merge pull request #41 from basalt-org/matmul
2 parents ea9e4a4 + 20e7488 commit f886ce2

File tree

15 files changed

+389
-1282
lines changed

15 files changed

+389
-1282
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ basalt.📦
55

66
examples/data/mnist_test.csv
77
examples/data/mnist_train.csv
8+
examples/data/mnist_train_small.csv
89

910
output_model.onnx
1011
Makefile

basalt/autograd/graph.mojo

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ struct Graph:
5757
self.symbol_count += 1
5858
return scalar_id
5959

60-
fn constant(
61-
inout self, shape: TensorShape, data: List[SIMD[dtype, 1]]
62-
) -> Symbol:
60+
fn constant(inout self, shape: TensorShape, data: List[SIMD[dtype, 1]]) -> Symbol:
6361
var cst = Param(data)
6462
var constant_id = Symbol(self.symbol_count, dtype, shape, trainable=False)
6563
self.params.put(constant_id, cst)
@@ -111,9 +109,7 @@ struct Graph:
111109
self.result_trainable(operand_1),
112110
)
113111

114-
self.nodes.append(
115-
Node(op, res, operand_1, operand_2, operand_3, attributes)
116-
)
112+
self.nodes.append(Node(op, res, operand_1, operand_2, operand_3, attributes))
117113
self.symbol_count += 1
118114
return res
119115

basalt/autograd/ops/basics.mojo

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ struct DIV:
177177
/ (t2.load[nelts](index2) ** 2)
178178
* ug.load[nelts](i),
179179
)
180+
180181
vectorize[vec_div_bw_broadcast, 1](ug_shape.num_elements())
181182

182183
else:
@@ -257,9 +258,7 @@ struct EXP:
257258

258259
@parameter
259260
fn vec_exp_bw[nelts: Int](i: Int):
260-
res_grad.store[nelts](
261-
i, exp(t1.load[nelts](i)) * ug.load[nelts](i)
262-
)
261+
res_grad.store[nelts](i, exp(t1.load[nelts](i)) * ug.load[nelts](i))
263262

264263
vectorize[vec_exp_bw, nelts](ug_shape.num_elements())
265264
return res_grad ^

0 commit comments

Comments
 (0)