Skip to content

Commit 4cf8489

Browse files
committed
feat: implement adversarial robustness and safety module
1 parent 61ce76e commit 4cf8489

25 files changed

+1933
-124
lines changed

src/DecompositionMethods/MatrixDecomposition/LuDecomposition.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,10 @@ public override Vector<T> Solve(Vector<T> b)
186186
if (i > j)
187187
L[i, j] = A[i, j];
188188
else if (i == j)
189+
{
189190
L[i, j] = NumOps.One;
191+
U[i, j] = A[i, j]; // Also copy diagonal to U
192+
}
190193
else
191194
U[i, j] = A[i, j];
192195
}
@@ -275,7 +278,10 @@ public override Vector<T> Solve(Vector<T> b)
275278
if (i > j)
276279
L[i, j] = A[i, j];
277280
else if (i == j)
281+
{
278282
L[i, j] = NumOps.One;
283+
U[i, j] = A[i, j]; // Also copy diagonal to U
284+
}
279285
else
280286
U[i, j] = A[i, j];
281287
}

src/FitDetectors/LearningCurveFitDetector.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,13 @@ protected override T CalculateConfidenceLevel(ModelEvaluationData<T, TInput, TOu
148148
var validationVariance = CalculateVariance(evaluationData.ValidationSet.PredictionStats.LearningCurve);
149149

150150
var totalVariance = NumOps.Add(trainingVariance, validationVariance);
151-
return NumOps.Subtract(NumOps.One, NumOps.Divide(totalVariance, NumOps.FromDouble(2)));
151+
var result = NumOps.Subtract(NumOps.One, NumOps.Divide(totalVariance, NumOps.FromDouble(2)));
152+
153+
// Clamp confidence to [0, 1]
154+
if (NumOps.LessThan(result, NumOps.Zero)) result = NumOps.Zero;
155+
if (NumOps.GreaterThan(result, NumOps.One)) result = NumOps.One;
156+
157+
return result;
152158
}
153159

154160
/// <summary>

src/FitDetectors/ResidualAnalysisFitDetector.cs

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,48 @@ protected override FitType DetermineFitType(ModelEvaluationData<T, TInput, TOutp
200200
/// </remarks>
201201
protected override T CalculateConfidenceLevel(ModelEvaluationData<T, TInput, TOutput> evaluationData)
202202
{
203-
var trainingConfidence = NumOps.Subtract(NumOps.One, NumOps.Divide(evaluationData.TrainingSet.ErrorStats.PopulationStandardError, evaluationData.TrainingSet.ErrorStats.MeanBiasError));
204-
var validationConfidence = NumOps.Subtract(NumOps.One, NumOps.Divide(evaluationData.ValidationSet.ErrorStats.PopulationStandardError, evaluationData.ValidationSet.ErrorStats.MeanBiasError));
205-
var testConfidence = NumOps.Subtract(NumOps.One, NumOps.Divide(evaluationData.TestSet.ErrorStats.PopulationStandardError, evaluationData.TestSet.ErrorStats.MeanBiasError));
203+
// Helper function to safely calculate confidence, avoiding division by zero
204+
T SafeConfidence(T populationStdError, T meanBiasError)
205+
{
206+
// Avoid division by zero - if mean bias error is close to zero, use R-squared only
207+
if (NumOps.LessThanOrEquals(NumOps.Abs(meanBiasError), NumOps.FromDouble(1e-10)))
208+
{
209+
return NumOps.FromDouble(0.5); // Neutral confidence when bias is near zero
210+
}
211+
var ratio = NumOps.Divide(populationStdError, meanBiasError);
212+
var conf = NumOps.Subtract(NumOps.One, NumOps.Abs(ratio));
213+
// Clamp to [0, 1]
214+
if (NumOps.LessThan(conf, NumOps.Zero)) conf = NumOps.Zero;
215+
if (NumOps.GreaterThan(conf, NumOps.One)) conf = NumOps.One;
216+
return conf;
217+
}
218+
219+
var trainingConfidence = SafeConfidence(evaluationData.TrainingSet.ErrorStats.PopulationStandardError, evaluationData.TrainingSet.ErrorStats.MeanBiasError);
220+
var validationConfidence = SafeConfidence(evaluationData.ValidationSet.ErrorStats.PopulationStandardError, evaluationData.ValidationSet.ErrorStats.MeanBiasError);
221+
var testConfidence = SafeConfidence(evaluationData.TestSet.ErrorStats.PopulationStandardError, evaluationData.TestSet.ErrorStats.MeanBiasError);
206222

207223
var averageConfidence = NumOps.Divide(NumOps.Add(NumOps.Add(trainingConfidence, validationConfidence), testConfidence), NumOps.FromDouble(3));
208224

209-
// Adjust confidence based on R-squared values
210-
var r2Adjustment = NumOps.Divide(NumOps.Add(NumOps.Add(evaluationData.TrainingSet.PredictionStats.R2, evaluationData.ValidationSet.PredictionStats.R2), evaluationData.TestSet.PredictionStats.R2), NumOps.FromDouble(3));
211-
212-
return NumOps.Multiply(averageConfidence, r2Adjustment);
225+
// Adjust confidence based on R-squared values (clamp R2 values to [0, 1])
226+
var r2Training = evaluationData.TrainingSet.PredictionStats.R2;
227+
if (NumOps.LessThan(r2Training, NumOps.Zero)) r2Training = NumOps.Zero;
228+
if (NumOps.GreaterThan(r2Training, NumOps.One)) r2Training = NumOps.One;
229+
230+
var r2Validation = evaluationData.ValidationSet.PredictionStats.R2;
231+
if (NumOps.LessThan(r2Validation, NumOps.Zero)) r2Validation = NumOps.Zero;
232+
if (NumOps.GreaterThan(r2Validation, NumOps.One)) r2Validation = NumOps.One;
233+
234+
var r2Test = evaluationData.TestSet.PredictionStats.R2;
235+
if (NumOps.LessThan(r2Test, NumOps.Zero)) r2Test = NumOps.Zero;
236+
if (NumOps.GreaterThan(r2Test, NumOps.One)) r2Test = NumOps.One;
237+
var r2Adjustment = NumOps.Divide(NumOps.Add(NumOps.Add(r2Training, r2Validation), r2Test), NumOps.FromDouble(3));
238+
239+
var result = NumOps.Multiply(averageConfidence, r2Adjustment);
240+
241+
// Final clamp to ensure [0, 1]
242+
if (NumOps.LessThan(result, NumOps.Zero)) result = NumOps.Zero;
243+
if (NumOps.GreaterThan(result, NumOps.One)) result = NumOps.One;
244+
245+
return result;
213246
}
214247
}

src/Helpers/StatisticsHelper.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4882,7 +4882,11 @@ public static T CalculatePrecisionRecallAUC(Vector<T> actual, Vector<T> predicte
48824882
T totalNegatives = _numOps.Subtract(_numOps.FromDouble(actual.Length), totalPositives);
48834883

48844884
if (_numOps.Equals(totalPositives, _numOps.Zero) || _numOps.Equals(totalNegatives, _numOps.Zero))
4885-
throw new ArgumentException("Both positive and negative samples are required to calculate AUC.");
4885+
{
4886+
// Return 0 for regression data or data without both classes
4887+
// AUC is a classification metric and is not meaningful for regression
4888+
return _numOps.Zero;
4889+
}
48864890

48874891
T truePositives = _numOps.Zero;
48884892
T falsePositives = _numOps.Zero;

src/Interpolation/BicubicInterpolation.cs

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,40 @@ public BicubicInterpolation(Vector<T> x, Vector<T> y, Matrix<T> z, IMatrixDecomp
105105
/// <returns>The interpolated z-value at the specified (x,y) coordinates.</returns>
106106
public T Interpolate(T x, T y)
107107
{
108-
int i = FindInterval(_x, x);
109-
int j = FindInterval(_y, y);
108+
// Check if x and y exactly match grid points - return exact value
109+
// Use binary search for O(log n) instead of O(n×m)
110+
int exactXIndex = BinarySearchExact(_x, x);
111+
if (exactXIndex >= 0)
112+
{
113+
int exactYIndex = BinarySearchExact(_y, y);
114+
if (exactYIndex >= 0)
115+
{
116+
return _z[exactXIndex, exactYIndex];
117+
}
118+
}
119+
120+
int iOriginal = FindInterval(_x, x);
121+
int jOriginal = FindInterval(_y, y);
122+
123+
// Use original intervals for dx/dy normalization (maintains correct interpolation position)
124+
T dx = _numOps.Divide(_numOps.Subtract(x, _x[iOriginal]), _numOps.Subtract(_x[iOriginal + 1], _x[iOriginal]));
125+
T dy = _numOps.Divide(_numOps.Subtract(y, _y[jOriginal]), _numOps.Subtract(_y[jOriginal + 1], _y[jOriginal]));
126+
127+
// Clamp indices for 4×4 neighborhood extraction (ensures valid array access)
128+
int i = Math.Max(1, Math.Min(iOriginal, _x.Length - 3));
129+
int j = Math.Max(1, Math.Min(jOriginal, _y.Length - 3));
130+
131+
// Adjust dx/dy to account for neighborhood shift when clamping occurred
132+
if (i != iOriginal)
133+
{
134+
// Recalculate dx relative to the clamped cell
135+
dx = _numOps.Divide(_numOps.Subtract(x, _x[i]), _numOps.Subtract(_x[i + 1], _x[i]));
136+
}
137+
if (j != jOriginal)
138+
{
139+
// Recalculate dy relative to the clamped cell
140+
dy = _numOps.Divide(_numOps.Subtract(y, _y[j]), _numOps.Subtract(_y[j + 1], _y[j]));
141+
}
110142

111143
T[,] p = new T[4, 4];
112144
for (int m = -1; m <= 2; m++)
@@ -117,9 +149,6 @@ public T Interpolate(T x, T y)
117149
}
118150
}
119151

120-
T dx = _numOps.Divide(_numOps.Subtract(x, _x[i]), _numOps.Subtract(_x[i + 1], _x[i]));
121-
T dy = _numOps.Divide(_numOps.Subtract(y, _y[j]), _numOps.Subtract(_y[j + 1], _y[j]));
122-
123152
return InterpolateBicubicPatch(p, dx, dy);
124153
}
125154

@@ -239,4 +268,37 @@ private int FindInterval(Vector<T> values, T point)
239268

240269
return values.Length - 2;
241270
}
271+
272+
/// <summary>
273+
/// Binary search for an exact match in a sorted array.
274+
/// </summary>
275+
/// <param name="values">The sorted array to search.</param>
276+
/// <param name="target">The target value to find.</param>
277+
/// <returns>The index of the exact match, or -1 if not found.</returns>
278+
private int BinarySearchExact(Vector<T> values, T target)
279+
{
280+
int left = 0;
281+
int right = values.Length - 1;
282+
283+
while (left <= right)
284+
{
285+
int mid = left + (right - left) / 2;
286+
287+
if (_numOps.Equals(values[mid], target))
288+
{
289+
return mid;
290+
}
291+
292+
if (_numOps.LessThan(values[mid], target))
293+
{
294+
left = mid + 1;
295+
}
296+
else
297+
{
298+
right = mid - 1;
299+
}
300+
}
301+
302+
return -1; // Not found
303+
}
242304
}

src/Interpolation/CatmullRomSplineInterpolation.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ public CatmullRomSplineInterpolation(Vector<T> x, Vector<T> y, double tension =
104104
/// <returns>The interpolated y-value at the specified x-coordinate.</returns>
105105
public T Interpolate(T x)
106106
{
107+
// Check if x exactly matches a known point
108+
for (int k = 0; k < _x.Length; k++)
109+
{
110+
if (_numOps.Equals(x, _x[k]))
111+
{
112+
return _y[k];
113+
}
114+
}
115+
107116
int i = FindInterval(x);
108117

109118
// Get the four points needed for the spline calculation

src/Interpolation/HermiteInterpolation.cs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,19 @@ public T Interpolate(T x)
108108
// Calculate the normalized position within the interval (0 to 1)
109109
T t = _numOps.Divide(_numOps.Subtract(x, x0), h);
110110

111-
// Calculate the Hermite basis functions
112-
// These control how the values and slopes at the endpoints influence the curve
113-
T h00 = _numOps.Multiply(_numOps.Subtract(_numOps.FromDouble(2), _numOps.Multiply(_numOps.FromDouble(3), t)), _numOps.Add(_numOps.One, _numOps.Multiply(_numOps.FromDouble(-1), t)));
114-
T h10 = _numOps.Multiply(h, _numOps.Multiply(t, _numOps.Add(_numOps.One, _numOps.Multiply(_numOps.FromDouble(-1), t))));
115-
T h01 = _numOps.Multiply(_numOps.Multiply(_numOps.FromDouble(3), t), _numOps.Subtract(_numOps.FromDouble(2), t));
116-
T h11 = _numOps.Multiply(h, _numOps.Multiply(t, _numOps.Subtract(t, _numOps.One)));
111+
// Calculate t² and t³ for the Hermite basis functions
112+
T t2 = _numOps.Multiply(t, t);
113+
T t3 = _numOps.Multiply(t2, t);
114+
115+
// Calculate the Hermite basis functions (correct formulas):
116+
// h00(t) = 2t³ - 3t² + 1
117+
// h10(t) = t³ - 2t² + t (scaled by h in the final calculation)
118+
// h01(t) = -2t³ + 3t²
119+
// h11(t) = t³ - t² (scaled by h in the final calculation)
120+
T h00 = _numOps.Add(_numOps.Add(_numOps.Multiply(_numOps.FromDouble(2), t3), _numOps.Multiply(_numOps.FromDouble(-3), t2)), _numOps.One);
121+
T h10 = _numOps.Multiply(h, _numOps.Add(_numOps.Add(t3, _numOps.Multiply(_numOps.FromDouble(-2), t2)), t));
122+
T h01 = _numOps.Add(_numOps.Multiply(_numOps.FromDouble(-2), t3), _numOps.Multiply(_numOps.FromDouble(3), t2));
123+
T h11 = _numOps.Multiply(h, _numOps.Add(t3, _numOps.Multiply(_numOps.FromDouble(-1), t2)));
117124

118125
// Combine the basis functions with the values and slopes to get the interpolated value
119126
return _numOps.Add(

src/Interpolation/MonotoneCubicInterpolation.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,15 @@ public T Interpolate(T x)
104104
T t2 = _numOps.Multiply(t, t);
105105
T t3 = _numOps.Multiply(t2, t);
106106

107-
T h00 = _numOps.Add(_numOps.Multiply(_numOps.FromDouble(2), t3), _numOps.Subtract(_numOps.Multiply(_numOps.FromDouble(-3), t2), _numOps.FromDouble(1)));
108-
T h10 = _numOps.Add(t3, _numOps.Subtract(_numOps.Multiply(_numOps.FromDouble(-2), t2), t));
107+
// Hermite basis functions (correct formulas):
108+
// h00(t) = 2t³ - 3t² + 1
109+
// h10(t) = t³ - 2t² + t
110+
// h01(t) = -2t³ + 3t²
111+
// h11(t) = t³ - t²
112+
T h00 = _numOps.Add(_numOps.Add(_numOps.Multiply(_numOps.FromDouble(2), t3), _numOps.Multiply(_numOps.FromDouble(-3), t2)), _numOps.One);
113+
T h10 = _numOps.Add(_numOps.Add(t3, _numOps.Multiply(_numOps.FromDouble(-2), t2)), t);
109114
T h01 = _numOps.Add(_numOps.Multiply(_numOps.FromDouble(-2), t3), _numOps.Multiply(_numOps.FromDouble(3), t2));
110-
T h11 = _numOps.Subtract(t3, t2);
115+
T h11 = _numOps.Add(t3, _numOps.Multiply(_numOps.FromDouble(-1), t2));
111116

112117
return _numOps.Add(
113118
_numOps.Add(

src/Interpolation/NaturalSplineInterpolation.cs

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ public NaturalSplineInterpolation(Vector<T> x, Vector<T> y, int degree = 3, IMat
9292
_degree = degree;
9393
_decomposition = decomposition;
9494
_numOps = MathHelper.GetNumericOperations<T>();
95-
_coefficients = new Vector<T>[_degree];
95+
_coefficients = new Vector<T>[_degree + 1];
9696

97-
for (int i = 0; i < _degree; i++)
97+
for (int i = 0; i <= _degree; i++)
9898
{
9999
_coefficients[i] = new Vector<T>(x.Length - 1);
100100
}
@@ -125,7 +125,7 @@ public T Interpolate(T x)
125125
T dx = _numOps.Subtract(x, _x[i]);
126126
T result = _y[i];
127127

128-
for (int j = 1; j < _degree; j++)
128+
for (int j = 1; j <= _degree; j++)
129129
{
130130
result = _numOps.Add(result, _numOps.Multiply(_coefficients[j][i], Power(dx, j)));
131131
}
@@ -150,46 +150,69 @@ public T Interpolate(T x)
150150
private void CalculateCoefficients()
151151
{
152152
int n = _x.Length;
153-
Matrix<T> A = new Matrix<T>(n, n);
154-
Vector<T> b = new Vector<T>(n);
155153

156-
// Set up the system of equations
154+
// Calculate interval widths h[i] = x[i+1] - x[i]
155+
Vector<T> h = new Vector<T>(n - 1);
157156
for (int i = 0; i < n - 1; i++)
158157
{
159-
T h = _numOps.Subtract(_x[i + 1], _x[i]);
160-
A[i, i] = h;
161-
if (i < n - 2)
162-
A[i, i + 1] = _numOps.Multiply(_numOps.FromDouble(2), _numOps.Add(h, _numOps.Subtract(_x[i + 2], _x[i + 1])));
163-
if (i > 0)
164-
A[i, i - 1] = h;
165-
166-
if (i < n - 2)
167-
{
168-
T dy1 = _numOps.Divide(_numOps.Subtract(_y[i + 1], _y[i]), h);
169-
T dy2 = _numOps.Divide(_numOps.Subtract(_y[i + 2], _y[i + 1]), _numOps.Subtract(_x[i + 2], _x[i + 1]));
170-
b[i] = _numOps.Multiply(_numOps.FromDouble(6), _numOps.Subtract(dy2, dy1));
171-
}
158+
h[i] = _numOps.Subtract(_x[i + 1], _x[i]);
172159
}
173160

174-
// Apply natural spline conditions
161+
// Set up the tridiagonal system for natural cubic spline
162+
// We solve for second derivatives M[i] at each point
163+
Matrix<T> A = new Matrix<T>(n, n);
164+
Vector<T> b = new Vector<T>(n);
165+
166+
// Natural spline boundary conditions: M[0] = 0, M[n-1] = 0
167+
// Ensure boundary rows are completely set (all zeros except diagonal = 1)
168+
for (int j = 0; j < n; j++)
169+
{
170+
A[0, j] = _numOps.Zero;
171+
A[n - 1, j] = _numOps.Zero;
172+
}
175173
A[0, 0] = _numOps.One;
176-
A[n - 1, n - 1] = _numOps.One;
177174
b[0] = _numOps.Zero;
175+
A[n - 1, n - 1] = _numOps.One;
178176
b[n - 1] = _numOps.Zero;
179177

180-
// Solve the system
178+
// Interior equations (tridiagonal system)
179+
// h[i-1]*M[i-1] + 2*(h[i-1]+h[i])*M[i] + h[i]*M[i+1] = 6*((y[i+1]-y[i])/h[i] - (y[i]-y[i-1])/h[i-1])
180+
for (int i = 1; i < n - 1; i++)
181+
{
182+
A[i, i - 1] = h[i - 1];
183+
A[i, i] = _numOps.Multiply(_numOps.FromDouble(2), _numOps.Add(h[i - 1], h[i]));
184+
A[i, i + 1] = h[i];
185+
186+
T slope1 = _numOps.Divide(_numOps.Subtract(_y[i], _y[i - 1]), h[i - 1]);
187+
T slope2 = _numOps.Divide(_numOps.Subtract(_y[i + 1], _y[i]), h[i]);
188+
b[i] = _numOps.Multiply(_numOps.FromDouble(6), _numOps.Subtract(slope2, slope1));
189+
}
190+
191+
// Solve the system for second derivatives M
181192
var decomposition = _decomposition ?? new LuDecomposition<T>(A);
182-
Vector<T> m = MatrixSolutionHelper.SolveLinearSystem(b, decomposition);
193+
Vector<T> M = MatrixSolutionHelper.SolveLinearSystem(b, decomposition);
183194

184-
// Calculate the coefficients
195+
// Calculate the spline coefficients for each segment
196+
// S(x) = a + b*(x-x[i]) + c*(x-x[i])^2 + d*(x-x[i])^3
185197
for (int i = 0; i < n - 1; i++)
186198
{
187-
T h = _numOps.Subtract(_x[i + 1], _x[i]);
199+
// a[i] = y[i]
188200
_coefficients[0][i] = _y[i];
189-
_coefficients[1][i] = _numOps.Divide(_numOps.Subtract(_y[i + 1], _y[i]), h);
190-
_coefficients[1][i] = _numOps.Subtract(_coefficients[1][i], _numOps.Multiply(_numOps.Divide(h, _numOps.FromDouble(6)), _numOps.Add(_numOps.Multiply(_numOps.FromDouble(2), m[i]), m[i + 1])));
191-
_coefficients[2][i] = _numOps.Divide(m[i], _numOps.FromDouble(2));
192-
_coefficients[3][i] = _numOps.Divide(_numOps.Subtract(m[i + 1], m[i]), _numOps.Multiply(_numOps.FromDouble(6), h));
201+
202+
// b[i] = (y[i+1] - y[i])/h[i] - h[i]*(2*M[i] + M[i+1])/6
203+
T term1 = _numOps.Divide(_numOps.Subtract(_y[i + 1], _y[i]), h[i]);
204+
T term2 = _numOps.Divide(
205+
_numOps.Multiply(h[i], _numOps.Add(_numOps.Multiply(_numOps.FromDouble(2), M[i]), M[i + 1])),
206+
_numOps.FromDouble(6));
207+
_coefficients[1][i] = _numOps.Subtract(term1, term2);
208+
209+
// c[i] = M[i]/2
210+
_coefficients[2][i] = _numOps.Divide(M[i], _numOps.FromDouble(2));
211+
212+
// d[i] = (M[i+1] - M[i])/(6*h[i])
213+
_coefficients[3][i] = _numOps.Divide(
214+
_numOps.Subtract(M[i + 1], M[i]),
215+
_numOps.Multiply(_numOps.FromDouble(6), h[i]));
193216
}
194217
}
195218

src/LearningRateSchedulers/StepLRScheduler.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ public StepLRScheduler(
7575
/// <inheritdoc/>
7676
protected override double ComputeLearningRate(int step)
7777
{
78-
// Decay happens AFTER completing stepSize steps
79-
// With stepSize=3: steps 1,2,3 have no decay; step 4+ has first decay
80-
// Formula: floor((step - 1) / stepSize) for step > 0
81-
int decayCount = step > 0 ? (step - 1) / _stepSize : 0;
78+
// Decay happens AT stepSize steps (first decay at step = stepSize)
79+
// Formula: floor(step / stepSize)
80+
// This ensures steps 0..(stepSize-1) have decayCount=0, steps stepSize..(2*stepSize-1) have decayCount=1, etc.
81+
int decayCount = step / _stepSize;
8282
return _baseLearningRate * Math.Pow(_gamma, decayCount);
8383
}
8484

0 commit comments

Comments
 (0)