Skip to content

Commit 223cc46

Browse files
Fix Enzyme sparse matrix sparsity pattern corruption (issue #835)
The issue: Enzyme.make_zero shares structural arrays (rowval, colptr) between primal and shadow sparse matrices. Broadcast operations like `dA .-= z * y'` can change the sparsity pattern, corrupting both shadow AND primal matrices.

The fix: Add sparse-safe helper functions that operate directly on the nonzeros array to preserve the sparsity pattern:
- _safe_add!: Add arrays preserving sparsity pattern
- _safe_zero!: Zero arrays preserving sparsity pattern
- _sparse_outer_sub!: Compute outer product subtraction preserving sparsity pattern (uses non-allocating CSC loop for CPU sparse matrices)

Also added SparseArrays as a dependency for the LinearSolveEnzymeExt extension.

Note: This PR addresses the sparsity pattern corruption issue. The dense matrix Enzyme tests are failing on main (pre-existing issue).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 5bd131d commit 223cc46

File tree

2 files changed

+93
-13
lines changed

2 files changed

+93
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ LinearSolveCUDAExt = "CUDA"
6262
LinearSolveCUDSSExt = "CUDSS"
6363
LinearSolveCUSOLVERRFExt = ["CUSOLVERRF", "SparseArrays"]
6464
LinearSolveCliqueTreesExt = ["CliqueTrees", "SparseArrays"]
65-
LinearSolveEnzymeExt = "EnzymeCore"
65+
LinearSolveEnzymeExt = ["EnzymeCore", "SparseArrays"]
6666
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
6767
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
6868
LinearSolveForwardDiffExt = "ForwardDiff"

ext/LinearSolveEnzymeExt.jl

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,87 @@ using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearP
66
using LinearSolve.LinearAlgebra
77
using EnzymeCore
88
using EnzymeCore: EnzymeRules
9+
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
910

1011
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:LinearSolve.SciMLLinearSolveAlgorithm}) = true
1112

13+
# Helper functions for sparse-safe gradient accumulation
14+
# These avoid broadcast operations that can change sparsity patterns
15+
#
16+
# Key insight: Enzyme.make_zero shares structural arrays (rowval, colptr) between
17+
# primal and shadow sparse matrices. Broadcast operations like `dA .-= z * y'` can
18+
# change the sparsity pattern, corrupting both shadow AND primal. We must operate
19+
# directly on nzval to preserve the sparsity pattern.
20+
21+
using SparseArrays: nonzeros, rowvals, getcolptr
22+
23+
"""
    _safe_add!(dst, src)

Accumulate `src` into `dst` in place and return `dst`.

For `SparseMatrixCSC` arguments the addition is applied directly to the stored values
(`nonzeros`), so the structural arrays of `dst` are never rewritten. This is only
meaningful when `dst` and `src` have the same size and the same sparsity pattern (as
expected for Enzyme shadows of the same matrix), so that is verified first.

# Throws
- `DimensionMismatch` if the sizes or the number of stored entries differ.
- `ArgumentError` if the stored entries are located at different positions.
"""
function _safe_add!(dst::SparseMatrixCSC, src::SparseMatrixCSC)
    dst_vals, src_vals = nonzeros(dst), nonzeros(src)
    if size(dst) != size(src) || length(dst_vals) != length(src_vals)
        throw(DimensionMismatch(
            "cannot accumulate a $(size(src)) sparse matrix with $(length(src_vals)) " *
            "stored entries into a $(size(dst)) sparse matrix with " *
            "$(length(dst_vals)) stored entries"))
    end

    # Adding the value buffers entry by entry is only correct when entry `k` of both
    # matrices refers to the same (row, column). The identity checks make the common
    # case of shared structural arrays free; otherwise compare the contents.
    dst_colptr, src_colptr = getcolptr(dst), getcolptr(src)
    dst_rows, src_rows = rowvals(dst), rowvals(src)
    same_pattern = (dst_colptr === src_colptr || dst_colptr == src_colptr) &&
                   (dst_rows === src_rows || dst_rows == src_rows)
    same_pattern || throw(ArgumentError(
        "_safe_add! requires `dst` and `src` to share the same sparsity pattern"))

    dst_vals .+= src_vals
    return dst
end
34+
35+
# Generic fallback: element-wise in-place accumulation of `src` into `dst`.
function _safe_add!(dst::AbstractArray, src::AbstractArray)
    @. dst = dst + src
    return dst
end
39+
40+
"""
    _safe_zero!(A)

Reset every entry of `A` to zero in place and return `A`.

For a `SparseMatrixCSC` only the stored values (`nonzeros`) are overwritten, so the
sparsity pattern is left untouched.
"""
function _safe_zero!(A::SparseMatrixCSC)
    stored = nonzeros(A)
    fill!(stored, zero(eltype(stored)))
    return A
end
50+
51+
# Generic fallback: overwrite every element of `A` with the zero of its element type.
function _safe_zero!(A::AbstractArray)
    zero_value = zero(eltype(A))
    fill!(A, zero_value)
    return A
end
55+
56+
"""
    _sparse_outer_sub!(dA, z, y)

Compute `dA .-= z * transpose(y)` in a sparsity-preserving manner and return `dA`.

For sparse matrices, only the positions already stored in `dA` are updated; entries of
the outer product that fall outside the sparsity pattern are dropped. This is the
intended behavior for sparse matrix AD: gradients are only meaningful at positions where
the matrix can be modified.

`z` must have `size(dA, 1)` elements and `y` must have `size(dA, 2)` elements, and both
must use one-based indexing.

Note: SparseMatrixCSC is a CPU-only type. GPU sparse matrices (CuSparseMatrixCSC, etc.)
have their own types and would need handling in their respective extensions.

# Throws
- `DimensionMismatch` if `length(z)` or `length(y)` does not match `size(dA)`.
- `ArgumentError` if `z` or `y` does not use one-based indexing.
"""
function _sparse_outer_sub!(dA::SparseMatrixCSC, z::AbstractVector, y::AbstractVector)
    # The loop below runs under `@inbounds`, so the vector shapes must be validated
    # up front; otherwise a short `z` or `y` would be read out of bounds.
    Base.require_one_based_indexing(z, y)
    nrows, ncols = size(dA)
    if length(z) != nrows || length(y) != ncols
        throw(DimensionMismatch(
            "outer product of vectors with lengths $(length(z)) and $(length(y)) " *
            "does not match a sparse matrix of size $(size(dA))"))
    end

    rows = rowvals(dA)
    vals = nonzeros(dA)
    colptr = getcolptr(dA)

    # Non-allocating walk over the CSC structure, column by column: the stored entries
    # of column `col` occupy positions colptr[col]:(colptr[col + 1] - 1).
    @inbounds for col in 1:ncols
        y_col = y[col]
        for idx in colptr[col]:(colptr[col + 1] - 1)
            vals[idx] -= z[rows[idx]] * y_col
        end
    end

    return dA
end
84+
85+
# Generic fallback: subtract the outer product `z * transpose(y)` from `dA` in place and
# return `dA`. The product is written as a fused broadcast so that no temporary matrix is
# materialised; each element is still computed as `z[i] * y[j]` (no conjugation).
function _sparse_outer_sub!(dA::AbstractArray, z::AbstractVector, y::AbstractVector)
    dA .-= z .* transpose(y)
    return dA
end
89+
1290
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
1391
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
1492
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@@ -25,10 +103,10 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
25103
dres = func.val(prob.dval, alg.val; kwargs...)
26104

27105
if dres.b == res.b
28-
dres.b .= false
106+
_safe_zero!(dres.b)
29107
end
30108
if dres.A == res.A
31-
dres.A .= false
109+
_safe_zero!(dres.A)
32110
end
33111

34112
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
@@ -125,22 +203,23 @@ function EnzymeRules.reverse(
125203

126204
if EnzymeRules.width(config) == 1
127205
if d_A !== prob_d_A
128-
prob_d_A .+= d_A
129-
d_A .= 0
206+
# Use sparse-safe addition to preserve sparsity pattern
207+
_safe_add!(prob_d_A, d_A)
208+
_safe_zero!(d_A)
130209
end
131210
if d_b !== prob_d_b
132-
prob_d_b .+= d_b
133-
d_b .= 0
211+
_safe_add!(prob_d_b, d_b)
212+
_safe_zero!(d_b)
134213
end
135214
else
136215
for (_prob_d_A, _d_A, _prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
137216
if _d_A !== _prob_d_A
138-
_prob_d_A .+= _d_A
139-
_d_A .= 0
217+
_safe_add!(_prob_d_A, _d_A)
218+
_safe_zero!(_d_A)
140219
end
141220
if _d_b !== _prob_d_b
142-
_prob_d_b .+= _d_b
143-
_d_b .= 0
221+
_safe_add!(_prob_d_b, _d_b)
222+
_safe_zero!(_d_b)
144223
end
145224
end
146225
end
@@ -149,7 +228,7 @@ function EnzymeRules.reverse(
149228
end
150229

151230
# y=inv(A) B
152-
# dA −= z y^T
231+
# dA −= z y^T
153232
# dB += z, where z = inv(A^T) dy
154233
function EnzymeRules.augmented_primal(
155234
config, func::Const{typeof(LinearSolve.solve!)},
@@ -254,7 +333,8 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
254333
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
255334
end
256335

257-
dA .-= z * transpose(y)
336+
# Use sparse-safe outer product subtraction to preserve sparsity pattern
337+
_sparse_outer_sub!(dA, z, y)
258338
db .+= z
259339
dy .= eltype(dy)(0)
260340
end

0 commit comments

Comments
 (0)