This fix addresses the issue where Enzyme AD with sparse matrices causes
the primal matrix's sparsity pattern (rowval, colptr) to be corrupted.
## Root Cause
Enzyme.make_zero shares the structural arrays (rowval, colptr) between
the primal and shadow sparse matrices. When broadcast operations like
`dA .-= z * transpose(y)` modify the shadow's sparsity pattern, they
inadvertently corrupt the primal's structure as well.
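The aliasing described above can be reproduced without Enzyme. The following is a minimal sketch, assuming a hand-built shadow `dA` that reuses the primal's `colptr` and `rowval` vectors (a stand-in for what `Enzyme.make_zero` is reported to do, not its actual implementation). Inserting a new stored entry into the shadow grows the shared structural arrays, leaving the primal with more row indices than values.

```julia
using SparseArrays
using SparseArrays: getcolptr

# Primal: 2x2 diagonal matrix with two stored entries.
A = sparse([1, 2], [1, 2], [1.0, 2.0])

# Hypothetical shadow that aliases A's structural arrays (assumption:
# this mimics the sharing described in the commit message).
dA = SparseMatrixCSC(size(A, 1), size(A, 2),
                     getcolptr(A), rowvals(A), zeros(2))

shared_before = rowvals(dA) === rowvals(A)

# A structural write to the shadow inserts into the shared rowval/colptr.
dA[2, 1] = 5.0

# The primal now has 3 row indices but still only 2 stored values.
primal_rowval_len = length(rowvals(A))
primal_nzval_len = length(nonzeros(A))
```

After the write, `primal_rowval_len` and `primal_nzval_len` disagree, which is the kind of corruption the fix is meant to prevent.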
## Solution
Add sparse-safe helper functions that operate directly on nzval arrays:
- `_safe_add!`: Adds sparse matrices by operating on nonzeros()
- `_safe_zero!`: Zeros sparse matrices by operating on nonzeros()
- `_sparse_outer_sub!`: Accumulates outer product gradients only into
  existing non-zero positions using vectorized operations
The key insight is to use vectorized indexing (`z[rows] .* y[col_indices]`)
rather than nested loops with scalar indexing, making the code more
portable (though GPU sparse matrices would need their own extension).
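As an illustration of the approach, here is a hedged sketch of what such helpers might look like. The `sketch_` names and bodies are hypothetical and are not the commit's actual code; the sketch assumes both sparse arguments share the same sparsity pattern and only touches the stored values.

```julia
using SparseArrays
using SparseArrays: getcolptr

# Zero only the stored values; rowval and colptr are left untouched.
sketch_safe_zero!(A::SparseMatrixCSC) = (fill!(nonzeros(A), zero(eltype(A))); A)
sketch_safe_zero!(A::AbstractArray) = fill!(A, zero(eltype(A)))

# Add stored values elementwise (assumes identical sparsity patterns).
function sketch_safe_add!(dst::SparseMatrixCSC, src::SparseMatrixCSC)
    nonzeros(dst) .+= nonzeros(src)
    return dst
end
sketch_safe_add!(dst::AbstractArray, src::AbstractArray) = (dst .+= src; dst)

# Expand colptr into one column index per stored entry.
function sketch_expand_colptr(colptr::AbstractVector{<:Integer}, n::Integer)
    cols = Vector{Int}(undef, n)
    for j in 1:(length(colptr) - 1)
        cols[colptr[j]:(colptr[j + 1] - 1)] .= j
    end
    return cols
end

# dA -= z * y', restricted to positions already stored in dA.
function sketch_sparse_outer_sub!(dA::SparseMatrixCSC, z::AbstractVector, y::AbstractVector)
    n = nnz(dA)
    rows = view(rowvals(dA), 1:n)
    cols = sketch_expand_colptr(getcolptr(dA), n)
    vals = view(nonzeros(dA), 1:n)
    vals .-= z[rows] .* y[cols]
    return dA
end

# Usage: a diagonal pattern receives only the diagonal of z * y'.
dA = sparse([1, 2, 3], [1, 2, 3], zeros(3))
sketch_sparse_outer_sub!(dA, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
```

Because every write goes through `nonzeros`, the structural arrays are never resized, so sharing them between primal and shadow is harmless in this sketch. The trade-off is that gradient contributions outside the stored pattern are dropped.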
## Changes
- Import SparseArrays accessor functions (nonzeros, rowvals, getcolptr)
- Dispatch on SparseMatrixCSC specifically (not AbstractSparseMatrix)
- Use vectorized operations instead of nested loops in _sparse_outer_sub!
- Add _expand_colptr_to_col_indices helper to build column index vector
- Add documentation explaining the root cause and solution
Fixes #835
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
         alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@@ -25,10 +126,10 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
     dres = func.val(prob.dval, alg.val; kwargs...)
 
     if dres.b == res.b
-        dres.b .= false
+        _safe_zero!(dres.b)
     end
     if dres.A == res.A
-        dres.A .= false
+        _safe_zero!(dres.A)
     end
 
     if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
@@ -125,22 +226,23 @@ function EnzymeRules.reverse(
 
     if EnzymeRules.width(config) == 1
         if d_A !== prob_d_A
-            prob_d_A .+= d_A
-            d_A .= 0
+            # Use sparse-safe addition to preserve sparsity pattern
+            _safe_add!(prob_d_A, d_A)
+            _safe_zero!(d_A)
         end
         if d_b !== prob_d_b
-            prob_d_b .+= d_b
-            d_b .= 0
+            _safe_add!(prob_d_b, d_b)
+            _safe_zero!(d_b)
         end
     else
         for (_prob_d_A, _d_A, _prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
             if _d_A !== _prob_d_A
-                _prob_d_A .+= _d_A
-                _d_A .= 0
+                _safe_add!(_prob_d_A, _d_A)
+                _safe_zero!(_d_A)
             end
             if _d_b !== _prob_d_b
-                _prob_d_b .+= _d_b
-                _d_b .= 0
+                _safe_add!(_prob_d_b, _d_b)
+                _safe_zero!(_d_b)
             end
         end
     end
@@ -149,7 +251,7 @@ function EnzymeRules.reverse(
 end
 
 # y=inv(A) B
-# dA −= z y^T
+# dA −= z y^T
 # dB += z, where z = inv(A^T) dy
 function EnzymeRules.augmented_primal(
         config, func::Const{typeof(LinearSolve.solve!)},
@@ -254,7 +356,8 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
         error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
     end
 
-    dA .-= z * transpose(y)
+    # Use sparse-safe outer product subtraction to preserve sparsity pattern