Skip to content

Commit e146388

Browse files
authored
TruncatedSVDSolver (#66)
* new TruncatedSVDSolver * update lstsq tests, verify() only checks after fit() * rename predict() -> solve() for lstsq classes * update lstsq doc page * update lstsq type annotations, fix model solver checker * small corrections * version 0.5.6 -> 0.5.7, update changelog
1 parent 0777346 commit e146388

File tree

18 files changed

+587
-228
lines changed

18 files changed

+587
-228
lines changed

docs/source/api/lstsq.ipynb

Lines changed: 99 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
" L2DecoupledSolver\n",
2727
" TikhonovSolver\n",
2828
" TikhonovDecoupledSolver\n",
29+
" TruncatedSVDSolver\n",
2930
" TotalLeastSquaresSolver\n",
3031
"```"
3132
]
@@ -37,11 +38,8 @@
3738
":::{admonition} Overview\n",
3839
":class: note\n",
3940
"\n",
40-
"- Solver classes defined in `opinf.lstsq` are used to solve an Operator Inference regression.\n",
41-
"- [`fit()`](SolverTemplate.fit) receives data matrices defining the regression.\n",
42-
"- [`predict()`](SolverTemplate.predict) solves the regression and returns the operator matrix.\n",
43-
"- Solver objects are passed to the constructor of {mod}`opinf.models` classes.\n",
44-
"- Models handle the construction of $\\D$ and $\\Z$ from snapshot data, pass these matrices to the solver's `fit()` method, call the solver's `predict()` method to compute $\\Ohat$, and interpret $\\Ohat$ in the context of the model structure.\n",
41+
"- `opinf.lstsq` classes are solve Operator Inference regression problems.\n",
42+
"- `opinf.models` classes handle the construction of the regression matrices from snapshot data, pass these matrices to the solver's [`fit()`](SolverTemplate.fit) method, call the solver's [`solve()`](SolverTemplate.solve) method, and interpret the solution in the context of the model structure.\n",
4543
":::"
4644
]
4745
},
@@ -72,7 +70,7 @@
7270
"cell_type": "markdown",
7371
"metadata": {},
7472
"source": [
75-
"## Least-squares Operator Inference Problems"
73+
"## Operator Inference Regression Problems"
7674
]
7775
},
7876
{
@@ -289,7 +287,7 @@
289287
"outputs": [],
290288
"source": [
291289
"# Solve the least-squares problem for the operator matrix.\n",
292-
"Ohat = solver.predict()\n",
290+
"Ohat = solver.solve()\n",
293291
"print(f\"{Ohat.shape=}\")"
294292
]
295293
},
@@ -348,12 +346,12 @@
348346
"\n",
349347
"The following classes solve Tikhonov-regularized least-squares Operator Inference regressions for different choices of the regularization term $\\mathcal{R}(\\Ohat)$.\n",
350348
"\n",
351-
"| Solver class | Description | Regularization $\\mathcal{R}(\\Ohat)$ |\n",
352-
"| :------------------------------- | :----------------------------------------------- | :------------: |\n",
353-
"| {class}`L2Solver` | One scalar regularizer for all $\\ohat_i$ | $\\lambda^{2}||\\Ohat\\trp||_F^2$ |\n",
354-
"| {class}`L2DecoupledSolver` | Different scalar regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\lambda_i^2||\\ohat_i||_2^2$ |\n",
355-
"| {class}`TikhonovSolver` | One matrix regularizer for all $\\ohat_i$ | $||\\bfGamma\\Ohat\\trp||_F^2$ |\n",
356-
"| {class}`TikhonovDecoupledSolver` | Different matrix regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}||\\bfGamma_i\\ohat_i||_2^2$ |"
349+
"| Solver class | Description | Regularization $\\mathcal{R}(\\Ohat)$ |\n",
350+
"| :------------------------------- | :----------------------------------------------- | :--------------------------------------------: |\n",
351+
"| {class}`L2Solver` | One scalar regularizer for all $\\ohat_i$ | $\\lambda^{2}\\|\\|\\Ohat\\trp\\|\\|_F^2$ |\n",
352+
"| {class}`L2DecoupledSolver` | Different scalar regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\lambda_i^2\\|\\|\\ohat_i\\|\\|_2^2$ |\n",
353+
"| {class}`TikhonovSolver` | One matrix regularizer for all $\\ohat_i$ | $\\|\\|\\bfGamma\\Ohat\\trp\\|\\|_F^2$ |\n",
354+
"| {class}`TikhonovDecoupledSolver` | Different matrix regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\|\\|\\bfGamma_i\\ohat_i\\|\\|_2^2$ |"
357355
]
358356
},
359357
{
@@ -373,7 +371,7 @@
373371
"metadata": {},
374372
"outputs": [],
375373
"source": [
376-
"Ohat_L2 = l2solver.predict()\n",
374+
"Ohat_L2 = l2solver.solve()\n",
377375
"l2solver.residual(Ohat_L2)"
378376
]
379377
},
@@ -396,7 +394,7 @@
396394
"metadata": {},
397395
"outputs": [],
398396
"source": [
399-
"Ohat_L2d = l2dsolver.predict()\n",
397+
"Ohat_L2d = l2dsolver.solve()\n",
400398
"l2dsolver.residual(Ohat_L2d)"
401399
]
402400
},
@@ -420,7 +418,7 @@
420418
"metadata": {},
421419
"outputs": [],
422420
"source": [
423-
"Ohat_tik = tiksolver.predict()\n",
421+
"Ohat_tik = tiksolver.solve()\n",
424422
"tiksolver.residual(Ohat_tik)"
425423
]
426424
},
@@ -443,7 +441,7 @@
443441
"metadata": {},
444442
"outputs": [],
445443
"source": [
446-
"Ohat_tik = tiksolver.predict()\n",
444+
"Ohat_tik = tiksolver.solve()\n",
447445
"tiksolver.residual(Ohat_tik)"
448446
]
449447
},
@@ -460,6 +458,83 @@
460458
"print(tikdsolver)"
461459
]
462460
},
461+
{
462+
"cell_type": "markdown",
463+
"metadata": {},
464+
"source": [
465+
"## Truncated SVD"
466+
]
467+
},
468+
{
469+
"cell_type": "markdown",
470+
"metadata": {},
471+
"source": [
472+
"The {class}`TruncatedSVDSolver` class approximates the solution to the ordinary least-squares problem {eq}`eq:lstsq:plain` by solving the related problem\n",
473+
"\n",
474+
"$$\n",
475+
"\\begin{aligned}\n",
476+
" \\argmin_{\\Ohat}\\|\\tilde{\\D}\\Ohat\\trp - \\Z\\trp\\|_{F}^{2}\n",
477+
"\\end{aligned}\n",
478+
"$$\n",
479+
"\n",
480+
"where $\\tilde{\\D}$ is the best rank-$d'$ approximation of $\\D$ for some given $d' < \\min(k,d)$, i.e.,\n",
481+
"\n",
482+
"$$\n",
483+
"\\begin{aligned}\n",
484+
" \\tilde{D}\n",
485+
" = \\argmin_{\\D' \\in \\RR^{k \\times d}}\n",
486+
" \\|\\D' - \\D\\|_{F}\n",
487+
" \\quad\\textrm{such that}\\quad\n",
488+
" \\operatorname{rank}(\\D') = d'.\n",
489+
"\\end{aligned}\n",
490+
"$$\n",
491+
"\n",
492+
"This approach is [related to Tikhonov regularization](https://math.stackexchange.com/questions/1084677/tikhonov-regularization-vs-truncated-svd) and is based on the [truncated singular value decomposition](https://en.wikipedia.org/wiki/Singular_value_decomposition#Truncated_SVD) of the data matrix $\\D$.\n",
493+
"The optimization residual is guaranteed to be higher than when using the full SVD as in {class}`PlainSolver`, but the condition number of the truncated SVD system is lower than that of the original system.\n",
494+
"Truncation can play a similar role to regularization, but the hyperparameter here (the number of columns to use) is an integer, whereas the regularization hyperparameter $\\lambda$ for {class}`L2Solver` may be any positive number."
495+
]
496+
},
497+
{
498+
"cell_type": "code",
499+
"execution_count": null,
500+
"metadata": {},
501+
"outputs": [],
502+
"source": [
503+
"tsvdsolver = opinf.lstsq.TruncatedSVDSolver(-2)\n",
504+
"tsvdsolver.fit(D, Z)\n",
505+
"print(tsvdsolver)"
506+
]
507+
},
508+
{
509+
"cell_type": "code",
510+
"execution_count": null,
511+
"metadata": {},
512+
"outputs": [],
513+
"source": [
514+
"Ohat = tsvdsolver.solve()\n",
515+
"tsvdsolver.residual(Ohat)"
516+
]
517+
},
518+
{
519+
"cell_type": "code",
520+
"execution_count": null,
521+
"metadata": {},
522+
"outputs": [],
523+
"source": [
524+
"# Change the number of columns used without recomputing the SVD.\n",
525+
"tsvdsolver.num_svdmodes = 8\n",
526+
"print(tsvdsolver)"
527+
]
528+
},
529+
{
530+
"cell_type": "code",
531+
"execution_count": null,
532+
"metadata": {},
533+
"outputs": [],
534+
"source": [
535+
"tsvdsolver.residual(tsvdsolver.solve())"
536+
]
537+
},
463538
{
464539
"cell_type": "markdown",
465540
"metadata": {},
@@ -510,7 +585,7 @@
510585
"metadata": {},
511586
"outputs": [],
512587
"source": [
513-
"Ohat_total = totalsolver.predict()\n",
588+
"Ohat_total = totalsolver.solve()\n",
514589
"totalsolver.residual(Ohat_total)"
515590
]
516591
},
@@ -547,7 +622,7 @@
547622
" # Process / store hyperparameters here.\n",
548623
"\n",
549624
" # Required methods --------------------------------------------------------\n",
550-
" def predict(self):\n",
625+
" def solve(self):\n",
551626
" \"\"\"Solve the regression and return the operator matrix.\"\"\"\n",
552627
" raise NotImplementedError\n",
553628
"\n",
@@ -615,7 +690,7 @@
615690
" super().__init__()\n",
616691
" self.options = dict(maxiter=maxiter, atol=atol)\n",
617692
"\n",
618-
" def predict(self):\n",
693+
" def solve(self):\n",
619694
" \"\"\"Solve the regression and return the operator matrix.\"\"\"\n",
620695
" # Allocate space for the operator matrix entries.\n",
621696
" Ohat = np.empty((self.r, self.d))\n",
@@ -643,8 +718,8 @@
643718
"metadata": {},
644719
"outputs": [],
645720
"source": [
646-
"solver = NNSolver()\n",
647-
"solver.verify()"
721+
"solver = NNSolver().fit(D, Z)\n",
722+
"print(solver)"
648723
]
649724
},
650725
{
@@ -653,8 +728,7 @@
653728
"metadata": {},
654729
"outputs": [],
655730
"source": [
656-
"solver.fit(D, Z)\n",
657-
"print(solver)"
731+
"solver.verify()"
658732
]
659733
},
660734
{
@@ -663,7 +737,7 @@
663737
"metadata": {},
664738
"outputs": [],
665739
"source": [
666-
"Ohat_nn = solver.predict()\n",
740+
"Ohat_nn = solver.solve()\n",
667741
"print(f\"{Ohat_nn.shape=}\")\n",
668742
"\n",
669743
"# Check that the entries of the operator matrix are nonnegative.\n",

docs/source/api/missing.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,5 @@ lstsq.ipynb
134134
L2DecoupledSolver
135135
TikhonovSolver
136136
TikhonovDecoupledSolver
137+
TruncatedSVDSolver
137138
TotalLeastSquaresSolver

docs/source/api/pre.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
":::\n",
7575
"\n",
7676
"You can [download the data here](https://github.com/Willcox-Research-Group/rom-operator-inference-Python3/raw/data/pre_example.npy) to repeat the experiments.\n",
77+
"The full dataset is available [here](https://doi.org/10.7302/nj7w-j319).\n",
7778
"::::"
7879
]
7980
},

docs/source/opinf/changelog.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@
55
New versions may introduce substantial new features or API adjustments.
66
:::
77

8+
## Version 0.5.7
9+
10+
Updates to `opinf.lstsq`:
11+
12+
- New `TruncatedSVDSolver` class.
13+
- `predict()` has been renamed `solve()` for `opinf.lstsq` solver classes to not clash with `predict()` from `opinf.roms` / `opinf.models` classes.
14+
- `solve()` always returns a two-dimensional array, even if $r = 1$.
15+
16+
Various small improvements to tests and documentation.
17+
818
## Version 0.5.6
919

1020
Added public templates to `opinf.operators`:
@@ -14,9 +24,7 @@ Added public templates to `opinf.operators`:
1424
- `ParametricOperatorTemplate` for general parametric operators.
1525
- `ParametricOpInfOperator` for parametric operators that can be learned through Operator Inference.
1626

17-
Also added `opinf.ddt.InterpolationDerivativeEstimator` and made various improvements to the API documentation.
18-
19-
Also made various updates for compatibility with [NumPy 2.0.0](https://numpy.org/doc/stable/release/2.0.0-notes.html).
27+
Also added a new `opinf.ddt.InterpolationDerivativeEstimator` class and made various small changes for compatibility with [NumPy 2.0.0](https://numpy.org/doc/stable/release/2.0.0-notes.html).
2028

2129
## Version 0.5.5
2230

src/opinf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
https://github.com/Willcox-Research-Group/rom-operator-inference-Python3
88
"""
99

10-
__version__ = "0.5.6"
10+
__version__ = "0.5.7"
1111

1212
from . import (
1313
basis,

src/opinf/lstsq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
r"""Solvers for Operator Inference least-squares problems."""
33

44
from ._base import *
5+
from ._tsvd import *
56
from ._tikhonov import *
67
from ._total import *

0 commit comments

Comments
 (0)