You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The issue: Enzyme.make_zero shares structural arrays (rowval, colptr)
between primal and shadow sparse matrices. Broadcast operations like
`dA .-= z * y'` can change the sparsity pattern, corrupting both shadow
AND primal matrices.
The fix: Add sparse-safe helper functions that operate directly on
the nonzeros array to preserve the sparsity pattern:
- _safe_add!: Add arrays preserving sparsity pattern
- _safe_zero!: Zero arrays preserving sparsity pattern
- _sparse_outer_sub!: Compute outer product subtraction preserving
sparsity pattern (uses non-allocating CSC loop for CPU sparse matrices)
Also added SparseArrays as a dependency for the LinearSolveEnzymeExt
extension.
Note: This PR addresses the sparsity pattern corruption issue. The
dense matrix Enzyme tests are failing on main (pre-existing issue).
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
@@ -25,10 +103,10 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
25
103
dres = func.val(prob.dval, alg.val; kwargs...)
26
104
27
105
if dres.b == res.b
28
-
dres.b .=false
106
+
_safe_zero!(dres.b)
29
107
end
30
108
if dres.A == res.A
31
-
dres.A .=false
109
+
_safe_zero!(dres.A)
32
110
end
33
111
34
112
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
@@ -125,22 +203,23 @@ function EnzymeRules.reverse(
125
203
126
204
if EnzymeRules.width(config) ==1
127
205
if d_A !== prob_d_A
128
-
prob_d_A .+= d_A
129
-
d_A .=0
206
+
# Use sparse-safe addition to preserve sparsity pattern
207
+
_safe_add!(prob_d_A, d_A)
208
+
_safe_zero!(d_A)
130
209
end
131
210
if d_b !== prob_d_b
132
-
prob_d_b.+=d_b
133
-
d_b .=0
211
+
_safe_add!(prob_d_b, d_b)
212
+
_safe_zero!(d_b)
134
213
end
135
214
else
136
215
for (_prob_d_A, _d_A, _prob_d_b, _d_b) inzip(prob_d_A, d_A, prob_d_b, d_b)
137
216
if _d_A !== _prob_d_A
138
-
_prob_d_A.+=_d_A
139
-
_d_A .=0
217
+
_safe_add!(_prob_d_A, _d_A)
218
+
_safe_zero!(_d_A)
140
219
end
141
220
if _d_b !== _prob_d_b
142
-
_prob_d_b.+=_d_b
143
-
_d_b .=0
221
+
_safe_add!(_prob_d_b, _d_b)
222
+
_safe_zero!(_d_b)
144
223
end
145
224
end
146
225
end
@@ -149,7 +228,7 @@ function EnzymeRules.reverse(
149
228
end
150
229
151
230
# y=inv(A) B
152
-
# dA −= z y^T
231
+
# dA −= z y^T
153
232
# dB += z, where z = inv(A^T) dy
154
233
function EnzymeRules.augmented_primal(
155
234
config, func::Const{typeof(LinearSolve.solve!)},
@@ -254,7 +333,8 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
254
333
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
255
334
end
256
335
257
-
dA .-= z *transpose(y)
336
+
# Use sparse-safe outer product subtraction to preserve sparsity pattern
0 commit comments