Commit 9d7e168

franklinic and claude committed
fix: address PR review comments for pruning strategies
- PruningMask: use PointwiseMultiply vectorization for SIMD acceleration
- StructuredPruningStrategy: implement Filter and Channel pruning types
- LotteryTicketPruningStrategy: O(n⁴)→O(n²) optimization, add validation
- PruningStrategyTests: fix potential integer overflow
- ModelCompressionBenchmarks: fix weight-matrix sizing losing tail elements

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 3bd6d72 commit 9d7e168

File tree

5 files changed: +180 −62 lines changed


AiDotNetBenchmarkTests/BenchmarkTests/ModelCompressionBenchmarks.cs

Lines changed: 11 additions & 5 deletions

@@ -61,13 +61,19 @@ public void Setup()
         }

         // Initialize weight matrix (for structured pruning)
+        // Use ceiling division to ensure all weights fit in the matrix
         int rows = (int)Math.Sqrt(WeightCount);
-        int cols = WeightCount / rows;
+        int cols = (WeightCount + rows - 1) / rows; // Ceiling division: ensures rows * cols >= WeightCount
         _weightMatrix = new Matrix<double>(rows, cols);
-        int idx = 0;
-        for (int i = 0; i < rows && idx < WeightCount; i++)
-            for (int j = 0; j < cols && idx < WeightCount; j++)
-                _weightMatrix[i, j] = _weights[idx++];
+        for (int i = 0; i < rows; i++)
+        {
+            for (int j = 0; j < cols; j++)
+            {
+                int idx = i * cols + j;
+                // Fill with weight data if available, otherwise zero (padding for alignment)
+                _weightMatrix[i, j] = idx < WeightCount ? _weights[idx] : 0.0;
+            }
+        }

         // Initialize compression algorithms
         _deepCompression = new DeepCompression<double>(
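The original sizing truncated: `cols = WeightCount / rows` drops up to `rows - 1` tail weights. For example, WeightCount = 10 gives rows = 3 and cols = 3, so the 3×3 matrix holds only 9 of the 10 weights. A standalone sketch of the arithmetic (10 is an illustrative count, not the benchmark's actual parameter):

```csharp
// Sketch of the sizing bug and its fix; weightCount = 10 is an illustrative value.
int weightCount = 10;
int rows = (int)Math.Sqrt(weightCount);            // rows = 3
int colsTruncated = weightCount / rows;            // 3 -> 3 * 3 = 9 slots, one weight lost
int colsCeiling = (weightCount + rows - 1) / rows; // 4 -> 3 * 4 = 12 slots, all 10 fit (rest zero-padded)
Console.WriteLine($"{colsTruncated} vs {colsCeiling}"); // prints "3 vs 4"
```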

src/Pruning/LotteryTicketPruningStrategy.cs

Lines changed: 22 additions & 6 deletions

@@ -100,6 +100,9 @@ public class LotteryTicketPruningStrategy<T> : IPruningStrategy<T>
     /// </remarks>
     public LotteryTicketPruningStrategy(int iterativeRounds = 5)
     {
+        if (iterativeRounds <= 0)
+            throw new ArgumentOutOfRangeException(nameof(iterativeRounds), "iterativeRounds must be greater than 0.");
+
         _numOps = MathHelper.GetNumericOperations<T>();
         _initialWeights = new Dictionary<string, Matrix<T>>();
         _iterativeRounds = iterativeRounds;

@@ -134,10 +137,10 @@ public void StoreInitialWeights(string layerName, Matrix<T> weights)
     /// <exception cref="InvalidOperationException">Thrown when no weights stored for the layer</exception>
     public Matrix<T> GetInitialWeights(string layerName)
     {
-        if (!_initialWeights.ContainsKey(layerName))
+        if (!_initialWeights.TryGetValue(layerName, out var weights))
             throw new InvalidOperationException($"No initial weights stored for layer {layerName}");

-        return _initialWeights[layerName].Clone();
+        return weights.Clone();
     }

     /// <summary>

@@ -210,6 +213,9 @@ public Tensor<T> ComputeImportanceScores(Tensor<T> weights, Tensor<T>? gradients
     /// <returns>Binary mask (1 = keep, 0 = prune)</returns>
     public IPruningMask<T> CreateMask(Vector<T> importanceScores, double targetSparsity)
     {
+        if (targetSparsity < 0.0 || targetSparsity > 1.0)
+            throw new ArgumentException("targetSparsity must be between 0 and 1 (inclusive).", nameof(targetSparsity));
+
         double prunePerRound = 1.0 - Math.Pow(1.0 - targetSparsity, 1.0 / _iterativeRounds);
         var currentMask = new PruningMask<T>(1, importanceScores.Length);

@@ -230,9 +236,10 @@ public IPruningMask<T> CreateMask(Vector<T> importanceScores, double targetSpars
         flatScores.Sort((a, b) => a.score.CompareTo(b.score));

+        // Reuse maskedScores instead of recomputing currentMask.Apply(importanceScores) in loop
         var keepIndices = new bool[importanceScores.Length];
         for (int i = 0; i < importanceScores.Length; i++)
-            keepIndices[i] = !_numOps.Equals(currentMask.Apply(importanceScores)[i], _numOps.Zero);
+            keepIndices[i] = !_numOps.Equals(maskedScores[i], _numOps.Zero);

         for (int i = 0; i < numToPrune && i < flatScores.Count; i++)
         {

@@ -267,6 +274,9 @@ public IPruningMask<T> CreateMask(Vector<T> importanceScores, double targetSpars
     /// </remarks>
     public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSparsity)
     {
+        if (targetSparsity < 0.0 || targetSparsity > 1.0)
+            throw new ArgumentException("targetSparsity must be between 0 and 1 (inclusive).", nameof(targetSparsity));
+
         // Iterative magnitude pruning to target sparsity
         // Each round prunes (1 - (1 - targetSparsity)^(1/rounds)) of remaining weights
         double prunePerRound = 1.0 - Math.Pow(1.0 - targetSparsity, 1.0 / _iterativeRounds);

@@ -298,11 +308,12 @@ public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSpars
         flatScores.Sort((a, b) => a.score.CompareTo(b.score));

+        // Reuse maskedScores instead of recomputing currentMask.Apply(importanceScores) in loop (O(n⁴) → O(n²))
         var keepIndices = new bool[importanceScores.Rows, importanceScores.Columns];

         for (int i = 0; i < importanceScores.Rows; i++)
             for (int j = 0; j < importanceScores.Columns; j++)
-                keepIndices[i, j] = !_numOps.Equals(currentMask.Apply(importanceScores)[i, j], _numOps.Zero);
+                keepIndices[i, j] = !_numOps.Equals(maskedScores[i, j], _numOps.Zero);

         for (int i = 0; i < numToPrune && i < flatScores.Count; i++)
         {

@@ -324,6 +335,9 @@ public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSpars
     /// <returns>Binary mask (1 = keep, 0 = prune)</returns>
     public IPruningMask<T> CreateMask(Tensor<T> importanceScores, double targetSparsity)
     {
+        if (targetSparsity < 0.0 || targetSparsity > 1.0)
+            throw new ArgumentException("targetSparsity must be between 0 and 1 (inclusive).", nameof(targetSparsity));
+
         double prunePerRound = 1.0 - Math.Pow(1.0 - targetSparsity, 1.0 / _iterativeRounds);
         var flatScoresInit = importanceScores.ToVector();
         // Initialize mask to all true

@@ -616,14 +630,16 @@ public void ResetToInitialWeights(string layerName, Matrix<T> weights, IPruningM
         if (initial.Rows != weights.Rows || initial.Columns != weights.Columns)
             throw new ArgumentException("Weight dimensions don't match initial weights");

+        // Compute masked initial weights once (O(n²) instead of O(n⁴))
+        var maskedInitial = mask.Apply(initial);
+
         // Reset non-pruned weights to their initialization
         for (int i = 0; i < weights.Rows; i++)
         {
             for (int j = 0; j < weights.Columns; j++)
             {
                 // Keep initial value where mask is 1, zero otherwise
-                var maskValue = mask.Apply(initial)[i, j];
-                weights[i, j] = maskValue;
+                weights[i, j] = maskedInitial[i, j];
             }
         }
     }
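Two of the changes above deserve a note. The schedule constant comes from compounding: each round keeps a fraction (1 - prunePerRound) of the surviving weights, so choosing prunePerRound = 1 - (1 - targetSparsity)^(1/rounds) makes (1 - prunePerRound)^rounds equal exactly 1 - targetSparsity, which is why targetSparsity must be validated into [0, 1]. The complexity fix follows from Apply being a full O(rows × cols) pass of its own: calling it once per element made mask construction O(n⁴) for an n×n matrix, while hoisting it out of the loop computes the masked scores once. A standalone check of the schedule (0.8 and 5 are illustrative values):

```csharp
// Standalone check of the iterative-pruning schedule; values are illustrative.
double targetSparsity = 0.8;
int rounds = 5;
double prunePerRound = 1.0 - Math.Pow(1.0 - targetSparsity, 1.0 / rounds); // ~0.275
double remaining = 1.0;
for (int r = 0; r < rounds; r++)
    remaining *= 1.0 - prunePerRound; // each round keeps ~72.5% of surviving weights
Console.WriteLine(remaining);         // ~0.20 remaining, i.e. the requested 80% sparsity
```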

src/Pruning/PruningMask.cs

Lines changed: 59 additions & 36 deletions

@@ -1,3 +1,4 @@
+using AiDotNet.Extensions;
 using AiDotNet.Interfaces;

 namespace AiDotNet.Pruning;

@@ -158,17 +159,8 @@ public Matrix<T> Apply(Matrix<T> weights)
         if (weights.Rows != _mask.Rows || weights.Columns != _mask.Columns)
             throw new ArgumentException("Weight matrix shape must match mask shape");

-        var result = new Matrix<T>(weights.Rows, weights.Columns);
-
-        for (int i = 0; i < weights.Rows; i++)
-        {
-            for (int j = 0; j < weights.Columns; j++)
-            {
-                result[i, j] = _numOps.Multiply(weights[i, j], _mask[i, j]);
-            }
-        }
-
-        return result;
+        // Use vectorized PointwiseMultiply for SIMD acceleration
+        return weights.PointwiseMultiply(_mask);
     }

     /// <summary>

@@ -195,29 +187,69 @@ public Tensor<T> Apply(Tensor<T> weights)
         // For 4D tensors (convolutional layers: [filters, channels, height, width])
         if (weights.Rank == 4)
         {
-            // Clone the tensor manually since Clone() might not exist
-            var flatData = weights.ToVector();
-            var result = Tensor<T>.FromVector(flatData, (int[])weights.Shape.Clone());
             int filters = weights.Shape[0];
             int channels = weights.Shape[1];
+            int height = weights.Shape[2];
+            int width = weights.Shape[3];
+            int spatialSize = height * width;

-            // Apply mask element-wise for now (unstructured pruning)
-            // For structured pruning, this would need to be modified
-            for (int f = 0; f < filters; f++)
+            // Check if mask dimensions match filter/channel for structured pruning
+            if (_mask.Rows == filters && _mask.Columns == channels)
             {
-                for (int c = 0; c < channels; c++)
+                // Apply filter/channel-level pruning using vectorized operations
+                // Broadcast mask [filters, channels] to [filters, channels, height, width]
+                int totalElements = filters * channels * spatialSize;
+                var maskData = new T[totalElements];
+
+                // Fill mask data by broadcasting each mask value across spatial dimensions
+                int idx = 0;
+                for (int f = 0; f < filters; f++)
                 {
-                    for (int h = 0; h < weights.Shape[2]; h++)
+                    for (int c = 0; c < channels; c++)
                     {
-                        for (int w = 0; w < weights.Shape[3]; w++)
+                        T maskValue = _mask[f, c];
+                        // Fill spatial elements with the same mask value
+                        for (int s = 0; s < spatialSize; s++)
                         {
-                            result[f, c, h, w] = weights[f, c, h, w];
+                            maskData[idx++] = maskValue;
                         }
                     }
                 }
+
+                var maskTensor = Tensor<T>.FromVector(new Vector<T>(maskData), new int[] { filters, channels, height, width });
+
+                // Use vectorized PointwiseMultiply for SIMD acceleration
+                return weights.PointwiseMultiply(maskTensor);
             }
+            else
+            {
+                // Unstructured pruning: apply mask element-by-element
+                var flatWeights = weights.ToVector();
+                int totalElements = flatWeights.Length;
+
+                if (_mask.Rows * _mask.Columns != totalElements)
+                {
+                    throw new ArgumentException(
+                        $"Mask shape [{_mask.Rows}, {_mask.Columns}] does not match 4D tensor total elements ({totalElements}) " +
+                        $"or filter/channel dimensions [{filters}, {channels}]");
+                }

-            return result;
+                // Convert mask to flat vector and use vectorized PointwiseMultiply
+                var flatMask = new T[totalElements];
+                int idx = 0;
+                for (int i = 0; i < _mask.Rows; i++)
+                {
+                    for (int j = 0; j < _mask.Columns; j++)
+                    {
+                        flatMask[idx++] = _mask[i, j];
+                    }
+                }
+
+                var flatMaskVector = new Vector<T>(flatMask);
+                var flatResult = flatWeights.PointwiseMultiply(flatMaskVector);
+
+                return Tensor<T>.FromVector(flatResult, (int[])weights.Shape.Clone());
+            }
         }

         throw new NotSupportedException($"Tensor rank {weights.Rank} not supported for pruning");

@@ -237,21 +269,12 @@ public Tensor<T> Apply(Tensor<T> weights)
     public Vector<T> Apply(Vector<T> weights)
     {
         // For 1D vectors, the mask is stored as a single-row matrix
-        int length = Math.Min(weights.Length, _mask.Columns);
-        var result = new T[weights.Length];
-
-        for (int i = 0; i < length; i++)
-        {
-            result[i] = _numOps.Multiply(weights[i], _mask[0, i]);
-        }
-
-        // Keep remaining elements as-is if vector is longer than mask
-        for (int i = length; i < weights.Length; i++)
-        {
-            result[i] = weights[i];
-        }
+        if (weights.Length != _mask.Columns)
+            throw new ArgumentException($"Weight vector length ({weights.Length}) must match mask columns ({_mask.Columns})");

-        return new Vector<T>(result);
+        // Extract mask row as vector and use vectorized PointwiseMultiply
+        var maskVector = _mask.GetRow(0);
+        return weights.PointwiseMultiply(maskVector);
     }

     /// <summary>
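The 4D structured path replaces four nested loops with one flattened pointwise multiply: each [filters, channels] mask entry is repeated across its height × width block so the broadcast mask lines up with the tensor's row-major layout. A plain-array sketch of that broadcast (the repo's Tensor<T>/PointwiseMultiply types are deliberately avoided so the example is self-contained; shapes and values are illustrative):

```csharp
// Plain-array sketch: broadcast a [filters, channels] mask over spatial dims.
int filters = 2, channels = 2, height = 1, width = 2;
int spatialSize = height * width;
double[] weights = { 1, 2, 3, 4, 5, 6, 7, 8 }; // [f, c, h, w] flattened row-major
double[,] mask = { { 1, 0 }, { 0, 1 } };       // keep (f0, c0) and (f1, c1)

var broadcast = new double[filters * channels * spatialSize];
int idx = 0;
for (int f = 0; f < filters; f++)
    for (int c = 0; c < channels; c++)
        for (int s = 0; s < spatialSize; s++)
            broadcast[idx++] = mask[f, c];     // repeat each mask value spatialSize times

var result = new double[weights.Length];
for (int i = 0; i < weights.Length; i++)
    result[i] = weights[i] * broadcast[i];     // { 1, 2, 0, 0, 0, 0, 7, 8 }
```

The unstructured branch rests on the same flattening argument: multiplying the flattened weights by the flattened mask element-for-element is exactly the old quadruple loop, just in a form a SIMD-backed PointwiseMultiply can consume.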

src/Pruning/StructuredPruningStrategy.cs

Lines changed: 87 additions & 14 deletions

@@ -194,8 +194,47 @@ public Matrix<T> ComputeImportanceScores(Matrix<T> weights, Matrix<T>? gradients
             }
             break;

-        default:
-            throw new NotImplementedException($"Pruning type {_pruningType} not yet implemented");
+        case StructurePruningType.Filter:
+            // For 2D matrix, Filter pruning treats rows as filters
+            // Score for each filter (row) = L2 norm of its weights
+            for (int row = 0; row < weights.Rows; row++)
+            {
+                double rowNorm = 0;
+                for (int col = 0; col < weights.Columns; col++)
+                {
+                    double val = Convert.ToDouble(weights[row, col]);
+                    rowNorm += val * val;
+                }
+                rowNorm = Math.Sqrt(rowNorm);
+
+                // Assign same score to all weights in row
+                for (int col = 0; col < weights.Columns; col++)
+                {
+                    scores[row, col] = _numOps.FromDouble(rowNorm);
+                }
+            }
+            break;
+
+        case StructurePruningType.Channel:
+            // For 2D matrix, Channel pruning treats columns as channels (same as Neuron)
+            // Score for each channel (column) = L2 norm of its weights
+            for (int col = 0; col < weights.Columns; col++)
+            {
+                double columnNorm = 0;
+                for (int row = 0; row < weights.Rows; row++)
+                {
+                    double val = Convert.ToDouble(weights[row, col]);
+                    columnNorm += val * val;
+                }
+                columnNorm = Math.Sqrt(columnNorm);
+
+                // Assign same score to all weights in column
+                for (int row = 0; row < weights.Rows; row++)
+                {
+                    scores[row, col] = _numOps.FromDouble(columnNorm);
+                }
+            }
+            break;
     }

     return scores;

@@ -230,29 +269,30 @@ public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSpars
     switch (_pruningType)
     {
         case StructurePruningType.Neuron:
-            // Prune entire columns (neurons)
-            int totalNeurons = importanceScores.Columns;
-            int neuronsToPrune = (int)(totalNeurons * targetSparsity);
+        case StructurePruningType.Channel:
+            // Prune entire columns (neurons/channels)
+            int totalColumns = importanceScores.Columns;
+            int columnsToPrune = (int)(totalColumns * targetSparsity);

-            // Get score for each neuron (all rows in column have same score)
-            var neuronScores = new List<(int col, double score)>();
+            // Get score for each column (all rows in column have same score)
+            var columnScores = new List<(int col, double score)>();
             for (int col = 0; col < importanceScores.Columns; col++)
             {
                 double score = Convert.ToDouble(importanceScores[0, col]);
-                neuronScores.Add((col, score));
+                columnScores.Add((col, score));
             }

-            // Sort by score (ascending)
-            neuronScores.Sort((a, b) => a.score.CompareTo(b.score));
+            // Sort by score (ascending - lowest scores get pruned first)
+            columnScores.Sort((a, b) => a.score.CompareTo(b.score));

             // Mark columns to keep
             var keepColumns = new bool[importanceScores.Columns];
             for (int i = 0; i < importanceScores.Columns; i++)
                 keepColumns[i] = true;

-            for (int i = 0; i < neuronsToPrune; i++)
+            for (int i = 0; i < columnsToPrune; i++)
             {
-                keepColumns[neuronScores[i].col] = false;
+                keepColumns[columnScores[i].col] = false;
             }

             // Create mask

@@ -265,8 +305,41 @@ public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSpars
             }
             break;

-        default:
-            throw new NotImplementedException($"Pruning type {_pruningType} not yet implemented");
+        case StructurePruningType.Filter:
+            // Prune entire rows (filters)
+            int totalRows = importanceScores.Rows;
+            int rowsToPrune = (int)(totalRows * targetSparsity);
+
+            // Get score for each row (all columns in row have same score)
+            var rowScores = new List<(int row, double score)>();
+            for (int row = 0; row < importanceScores.Rows; row++)
+            {
+                double score = Convert.ToDouble(importanceScores[row, 0]);
+                rowScores.Add((row, score));
+            }
+
+            // Sort by score (ascending - lowest scores get pruned first)
+            rowScores.Sort((a, b) => a.score.CompareTo(b.score));
+
+            // Mark rows to keep
+            var keepRows = new bool[importanceScores.Rows];
+            for (int i = 0; i < importanceScores.Rows; i++)
+                keepRows[i] = true;
+
+            for (int i = 0; i < rowsToPrune; i++)
+            {
+                keepRows[rowScores[i].row] = false;
+            }
+
+            // Create mask
+            for (int row = 0; row < importanceScores.Rows; row++)
+            {
+                for (int col = 0; col < importanceScores.Columns; col++)
+                {
+                    keepIndices[row, col] = keepRows[row];
+                }
+            }
+            break;
     }

     var mask = new PruningMask<T>(importanceScores.Rows, importanceScores.Columns);
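Filter and Channel scoring differ only in the axis the L2 norm reduces over, rows for filters and columns for channels/neurons, and every element inherits its structure's score, which is what lets CreateMask later zero whole rows or columns as units. A self-contained sketch of the row (filter) case with illustrative values:

```csharp
// Plain-array sketch of filter (row) scoring: L2 norm per row, broadcast back.
double[,] weights = { { 3.0, 4.0 }, { 0.3, 0.4 } }; // row norms: 5.0 and 0.5
int rows = weights.GetLength(0), cols = weights.GetLength(1);
var scores = new double[rows, cols];

for (int row = 0; row < rows; row++)
{
    double sumSq = 0;
    for (int col = 0; col < cols; col++)
        sumSq += weights[row, col] * weights[row, col];
    double norm = Math.Sqrt(sumSq);
    for (int col = 0; col < cols; col++)
        scores[row, col] = norm;                     // whole row shares one score
}
// At 50% target sparsity, row 1 (norm 0.5) is pruned as a unit and row 0 survives.
```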

tests/AiDotNet.Tests/Pruning/PruningStrategyTests.cs

Lines changed: 1 addition & 1 deletion

@@ -343,7 +343,7 @@ public void LotteryTicket_IterativePruning_AchievesTargetSparsity()
         var weights = new Matrix<double>(5, 5);
         for (int i = 0; i < 5; i++)
             for (int j = 0; j < 5; j++)
-                weights[i, j] = (i + 1) * (j + 1) * 0.1; // Varying magnitudes
+                weights[i, j] = ((double)(i + 1)) * (j + 1) * 0.1; // Varying magnitudes

         var strategy = new LotteryTicketPruningStrategy<double>(iterativeRounds: 5);
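The cast matters because `(i + 1) * (j + 1)` is evaluated in int arithmetic before `* 0.1` promotes the result to double; the 5×5 test itself cannot overflow, but the pattern can for large indices, which is the "potential" overflow the commit message refers to. A minimal illustration with deliberately large example values:

```csharp
// Why the cast matters: int * int wraps (unchecked) before the double promotion.
int i = 60000, j = 60000;                         // illustrative; the 5x5 test never overflows
double bad = (i + 1) * (j + 1) * 0.1;             // 60001 * 60001 wraps in int arithmetic
double good = ((double)(i + 1)) * (j + 1) * 0.1;  // promoted to double before multiplying
Console.WriteLine($"{bad} vs {good}");            // bad: -69484729.5, good: 360012000.1
```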
