Skip to content

Commit 3ed7d22

Browse files
authored
Fix race condition during the tracing phase of automatic differentiation (vjp, value_and_grad) (#338)
* Fix race condition during the tracing phase of automatic differentiation (vjp, value_and_grad) While the execution of the graph is thread-safe and lazy, the construction of that graph is not.
1 parent 741bc7e commit 3ed7d22

2 files changed

Lines changed: 14 additions & 3 deletions

File tree

Source/MLX/Transforms+Internal.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ private func valueAndGradient(
1515

1616
var r0 = mlx_vector_array_new()
1717
var r1 = mlx_vector_array_new()
18-
mlx_closure_value_and_grad_apply(&r0, &r1, valueAndGrad, input_vector)
18+
19+
_ = evalLock.withLock {
20+
mlx_closure_value_and_grad_apply(&r0, &r1, valueAndGrad, input_vector)
21+
}
1922

2023
defer { mlx_vector_array_free(r0) }
2124
defer { mlx_vector_array_free(r1) }

Source/MLX/Transforms.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ public func jvp(
2727
var r1 = mlx_vector_array_new()
2828

2929
let closure = new_mlx_closure(f)
30-
mlx_jvp(&r0, &r1, closure, primals_mlx, tangents_mlx)
30+
31+
_ = evalLock.withLock {
32+
mlx_jvp(&r0, &r1, closure, primals_mlx, tangents_mlx)
33+
}
34+
3135
mlx_closure_free(closure)
3236

3337
defer { mlx_vector_array_free(r0) }
@@ -60,7 +64,11 @@ public func vjp(
6064
var r1 = mlx_vector_array_new()
6165

6266
let closure = new_mlx_closure(f)
63-
mlx_vjp(&r0, &r1, closure, primals_mlx, cotangents_mlx)
67+
68+
_ = evalLock.withLock {
69+
mlx_vjp(&r0, &r1, closure, primals_mlx, cotangents_mlx)
70+
}
71+
6472
mlx_closure_free(closure)
6573

6674
defer { mlx_vector_array_free(r0) }

0 commit comments

Comments
 (0)