|
26 | 26 | " L2DecoupledSolver\n",
|
27 | 27 | " TikhonovSolver\n",
|
28 | 28 | " TikhonovDecoupledSolver\n",
|
| 29 | + " TruncatedSVDSolver\n", |
29 | 30 | " TotalLeastSquaresSolver\n",
|
30 | 31 | "```"
|
31 | 32 | ]
|
|
37 | 38 | ":::{admonition} Overview\n",
|
38 | 39 | ":class: note\n",
|
39 | 40 | "\n",
|
40 |
| - "- Solver classes defined in `opinf.lstsq` are used to solve an Operator Inference regression.\n", |
41 |
| - "- [`fit()`](SolverTemplate.fit) receives data matrices defining the regression.\n", |
42 |
| - "- [`predict()`](SolverTemplate.predict) solves the regression and returns the operator matrix.\n", |
43 |
| - "- Solver objects are passed to the constructor of {mod}`opinf.models` classes.\n", |
44 |
| - "- Models handle the construction of $\\D$ and $\\Z$ from snapshot data, pass these matrices to the solver's `fit()` method, call the solver's `predict()` method to compute $\\Ohat$, and interpret $\\Ohat$ in the context of the model structure.\n", |
| 41 | + "- `opinf.lstsq` classes solve Operator Inference regression problems.\n", |
| 42 | + "- `opinf.models` classes handle the construction of the regression matrices from snapshot data, pass these matrices to the solver's [`fit()`](SolverTemplate.fit) method, call the solver's [`solve()`](SolverTemplate.solve) method, and interpret the solution in the context of the model structure.\n", |
45 | 43 | ":::"
|
46 | 44 | ]
|
47 | 45 | },
|
|
72 | 70 | "cell_type": "markdown",
|
73 | 71 | "metadata": {},
|
74 | 72 | "source": [
|
75 |
| - "## Least-squares Operator Inference Problems" |
| 73 | + "## Operator Inference Regression Problems" |
76 | 74 | ]
|
77 | 75 | },
|
78 | 76 | {
|
|
289 | 287 | "outputs": [],
|
290 | 288 | "source": [
|
291 | 289 | "# Solve the least-squares problem for the operator matrix.\n",
|
292 |
| - "Ohat = solver.predict()\n", |
| 290 | + "Ohat = solver.solve()\n", |
293 | 291 | "print(f\"{Ohat.shape=}\")"
|
294 | 292 | ]
|
295 | 293 | },
|
|
348 | 346 | "\n",
|
349 | 347 | "The following classes solve Tikhonov-regularized least-squares Operator Inference regressions for different choices of the regularization term $\\mathcal{R}(\\Ohat)$.\n",
|
350 | 348 | "\n",
|
351 |
| - "| Solver class | Description | Regularization $\\mathcal{R}(\\Ohat)$ |\n", |
352 |
| - "| :------------------------------- | :----------------------------------------------- | :------------: |\n", |
353 |
| - "| {class}`L2Solver` | One scalar regularizer for all $\\ohat_i$ | $\\lambda^{2}||\\Ohat\\trp||_F^2$ |\n", |
354 |
| - "| {class}`L2DecoupledSolver` | Different scalar regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\lambda_i^2||\\ohat_i||_2^2$ |\n", |
355 |
| - "| {class}`TikhonovSolver` | One matrix regularizer for all $\\ohat_i$ | $||\\bfGamma\\Ohat\\trp||_F^2$ |\n", |
356 |
| - "| {class}`TikhonovDecoupledSolver` | Different matrix regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}||\\bfGamma_i\\ohat_i||_2^2$ |" |
| 349 | + "| Solver class | Description | Regularization $\\mathcal{R}(\\Ohat)$ |\n", |
| 350 | + "| :------------------------------- | :----------------------------------------------- | :--------------------------------------------: |\n", |
| 351 | + "| {class}`L2Solver` | One scalar regularizer for all $\\ohat_i$ | $\\lambda^{2}\\|\\|\\Ohat\\trp\\|\\|_F^2$ |\n", |
| 352 | + "| {class}`L2DecoupledSolver` | Different scalar regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\lambda_i^2\\|\\|\\ohat_i\\|\\|_2^2$ |\n", |
| 353 | + "| {class}`TikhonovSolver` | One matrix regularizer for all $\\ohat_i$ | $\\|\\|\\bfGamma\\Ohat\\trp\\|\\|_F^2$ |\n", |
| 354 | + "| {class}`TikhonovDecoupledSolver` | Different matrix regularizers for each $\\ohat_i$ | $\\sum_{i=1}^{r}\\|\\|\\bfGamma_i\\ohat_i\\|\\|_2^2$ |" |
357 | 355 | ]
|
358 | 356 | },
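As an aside, the scalar Tikhonov (ridge) problem in the first table row can be sketched with plain NumPy by stacking the regularizer beneath the data matrix. This is a generic illustration of the technique, not the `opinf.lstsq` implementation; the names `D`, `Z`, and `lam` are placeholders for the data matrix, right-hand side, and regularization parameter.

```python
import numpy as np

# Illustrative dimensions: k samples, d features, r outputs (not opinf data).
rng = np.random.default_rng(0)
k, d, r = 50, 6, 3
D = rng.standard_normal((k, d))   # data matrix
Z = rng.standard_normal((r, k))   # right-hand side
lam = 0.1                         # scalar regularization parameter

# Solve min ||D Ohat^T - Z^T||_F^2 + lam^2 ||Ohat^T||_F^2 by augmenting
# the system: stack lam * I under D and a block of zeros under Z^T.
D_aug = np.vstack([D, lam * np.eye(d)])
Z_aug = np.vstack([Z.T, np.zeros((d, r))])
Ohat = np.linalg.lstsq(D_aug, Z_aug, rcond=None)[0].T   # shape (r, d)

# Cross-check against the normal-equations form of the same problem.
Ohat_normal = np.linalg.solve(D.T @ D + lam**2 * np.eye(d), D.T @ Z.T).T
assert np.allclose(Ohat, Ohat_normal)
```

The augmented formulation avoids explicitly forming `D.T @ D`, which squares the condition number of the system.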
|
359 | 357 | {
|
|
373 | 371 | "metadata": {},
|
374 | 372 | "outputs": [],
|
375 | 373 | "source": [
|
376 |
| - "Ohat_L2 = l2solver.predict()\n", |
| 374 | + "Ohat_L2 = l2solver.solve()\n", |
377 | 375 | "l2solver.residual(Ohat_L2)"
|
378 | 376 | ]
|
379 | 377 | },
|
|
396 | 394 | "metadata": {},
|
397 | 395 | "outputs": [],
|
398 | 396 | "source": [
|
399 |
| - "Ohat_L2d = l2dsolver.predict()\n", |
| 397 | + "Ohat_L2d = l2dsolver.solve()\n", |
400 | 398 | "l2dsolver.residual(Ohat_L2d)"
|
401 | 399 | ]
|
402 | 400 | },
|
|
420 | 418 | "metadata": {},
|
421 | 419 | "outputs": [],
|
422 | 420 | "source": [
|
423 |
| - "Ohat_tik = tiksolver.predict()\n", |
| 421 | + "Ohat_tik = tiksolver.solve()\n", |
424 | 422 | "tiksolver.residual(Ohat_tik)"
|
425 | 423 | ]
|
426 | 424 | },
|
|
443 | 441 | "metadata": {},
|
444 | 442 | "outputs": [],
|
445 | 443 | "source": [
|
446 |
| - "Ohat_tik = tiksolver.predict()\n", |
| 444 | + "Ohat_tik = tiksolver.solve()\n", |
447 | 445 | "tiksolver.residual(Ohat_tik)"
|
448 | 446 | ]
|
449 | 447 | },
|
|
460 | 458 | "print(tikdsolver)"
|
461 | 459 | ]
|
462 | 460 | },
|
| 461 | + { |
| 462 | + "cell_type": "markdown", |
| 463 | + "metadata": {}, |
| 464 | + "source": [ |
| 465 | + "## Truncated SVD" |
| 466 | + ] |
| 467 | + }, |
| 468 | + { |
| 469 | + "cell_type": "markdown", |
| 470 | + "metadata": {}, |
| 471 | + "source": [ |
| 472 | + "The {class}`TruncatedSVDSolver` class approximates the solution to the ordinary least-squares problem {eq}`eq:lstsq:plain` by solving the related problem\n", |
| 473 | + "\n", |
| 474 | + "$$\n", |
| 475 | + "\\begin{aligned}\n", |
| 476 | + " \\argmin_{\\Ohat}\\|\\tilde{\\D}\\Ohat\\trp - \\Z\\trp\\|_{F}^{2}\n", |
| 477 | + "\\end{aligned}\n", |
| 478 | + "$$\n", |
| 479 | + "\n", |
| 480 | + "where $\\tilde{\\D}$ is the best rank-$d'$ approximation of $\\D$ for some given $d' < \\min(k,d)$, i.e.,\n", |
| 481 | + "\n", |
| 482 | + "$$\n", |
| 483 | + "\\begin{aligned}\n", |
| 484 | + "    \tilde{\D}\n", |
| 485 | + " = \\argmin_{\\D' \\in \\RR^{k \\times d}}\n", |
| 486 | + " \\|\\D' - \\D\\|_{F}\n", |
| 487 | + " \\quad\\textrm{such that}\\quad\n", |
| 488 | + " \\operatorname{rank}(\\D') = d'.\n", |
| 489 | + "\\end{aligned}\n", |
| 490 | + "$$\n", |
| 491 | + "\n", |
| 492 | + "This approach is [related to Tikhonov regularization](https://math.stackexchange.com/questions/1084677/tikhonov-regularization-vs-truncated-svd) and is based on the [truncated singular value decomposition](https://en.wikipedia.org/wiki/Singular_value_decomposition#Truncated_SVD) of the data matrix $\\D$.\n", |
| 493 | + "The optimization residual is guaranteed to be at least as large as when using the full SVD as in {class}`PlainSolver`, but the condition number of the truncated system is lower than that of the original system.\n", |
| 494 | + "Truncation can play a similar role to regularization, but the hyperparameter here (the number of SVD modes to retain) is an integer, whereas the regularization hyperparameter $\lambda$ for {class}`L2Solver` may be any positive number." |
| 495 | + ] |
| 496 | + }, |
| 497 | + { |
| 498 | + "cell_type": "code", |
| 499 | + "execution_count": null, |
| 500 | + "metadata": {}, |
| 501 | + "outputs": [], |
| 502 | + "source": [ |
| 503 | + "tsvdsolver = opinf.lstsq.TruncatedSVDSolver(-2)\n", |
| 504 | + "tsvdsolver.fit(D, Z)\n", |
| 505 | + "print(tsvdsolver)" |
| 506 | + ] |
| 507 | + }, |
| 508 | + { |
| 509 | + "cell_type": "code", |
| 510 | + "execution_count": null, |
| 511 | + "metadata": {}, |
| 512 | + "outputs": [], |
| 513 | + "source": [ |
| 514 | + "Ohat = tsvdsolver.solve()\n", |
| 515 | + "tsvdsolver.residual(Ohat)" |
| 516 | + ] |
| 517 | + }, |
| 518 | + { |
| 519 | + "cell_type": "code", |
| 520 | + "execution_count": null, |
| 521 | + "metadata": {}, |
| 522 | + "outputs": [], |
| 523 | + "source": [ |
| 524 | + "# Change the number of columns used without recomputing the SVD.\n", |
| 525 | + "tsvdsolver.num_svdmodes = 8\n", |
| 526 | + "print(tsvdsolver)" |
| 527 | + ] |
| 528 | + }, |
| 529 | + { |
| 530 | + "cell_type": "code", |
| 531 | + "execution_count": null, |
| 532 | + "metadata": {}, |
| 533 | + "outputs": [], |
| 534 | + "source": [ |
| 535 | + "tsvdsolver.residual(tsvdsolver.solve())" |
| 536 | + ] |
| 537 | + }, |
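The residual and conditioning trade-off described above can be checked with a standalone NumPy sketch (hedged: this reimplements truncated-SVD least squares generically and is not the `TruncatedSVDSolver` internals; `D`, `Z`, and the mode counts are illustrative).

```python
import numpy as np

# Construct an ill-conditioned data matrix with prescribed singular values.
rng = np.random.default_rng(1)
k, d, r = 40, 8, 2
U, _, Vt = np.linalg.svd(rng.standard_normal((k, d)), full_matrices=False)
svals = np.logspace(0, -6, d)     # singular values spanning six decades
D = (U * svals) @ Vt
Z = rng.standard_normal((r, k))

def tsvd_solve(D, Z, num_modes):
    """Least squares via the rank-`num_modes` truncated SVD of D."""
    U, s, Vt = np.linalg.svd(D, full_matrices=False)
    Uk, sk, Vk = U[:, :num_modes], s[:num_modes], Vt[:num_modes]
    return (Vk.T @ ((Uk.T @ Z.T) / sk[:, None])).T   # shape (r, d)

Ohat_full = tsvd_solve(D, Z, d)       # full SVD: minimal residual
Ohat_trunc = tsvd_solve(D, Z, 4)      # truncated: better conditioning

res = lambda O: np.linalg.norm(D @ O.T - Z.T)
# Truncation cannot lower the residual over the original D ...
assert res(Ohat_trunc) >= res(Ohat_full) - 1e-10
# ... but the effective condition number shrinks dramatically.
print(f"truncated cond: {svals[0] / svals[3]:.2e}")
print(f"full cond:      {svals[0] / svals[-1]:.2e}")
```

Sweeping `num_modes` and watching the residual grow as the condition number falls is one way to pick the integer hyperparameter discussed above.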
463 | 538 | {
|
464 | 539 | "cell_type": "markdown",
|
465 | 540 | "metadata": {},
|
|
510 | 585 | "metadata": {},
|
511 | 586 | "outputs": [],
|
512 | 587 | "source": [
|
513 |
| - "Ohat_total = totalsolver.predict()\n", |
| 588 | + "Ohat_total = totalsolver.solve()\n", |
514 | 589 | "totalsolver.residual(Ohat_total)"
|
515 | 590 | ]
|
516 | 591 | },
|
|
547 | 622 | " # Process / store hyperparameters here.\n",
|
548 | 623 | "\n",
|
549 | 624 | " # Required methods --------------------------------------------------------\n",
|
550 |
| - " def predict(self):\n", |
| 625 | + " def solve(self):\n", |
551 | 626 | " \"\"\"Solve the regression and return the operator matrix.\"\"\"\n",
|
552 | 627 | " raise NotImplementedError\n",
|
553 | 628 | "\n",
|
|
615 | 690 | " super().__init__()\n",
|
616 | 691 | " self.options = dict(maxiter=maxiter, atol=atol)\n",
|
617 | 692 | "\n",
|
618 |
| - " def predict(self):\n", |
| 693 | + " def solve(self):\n", |
619 | 694 | " \"\"\"Solve the regression and return the operator matrix.\"\"\"\n",
|
620 | 695 | " # Allocate space for the operator matrix entries.\n",
|
621 | 696 | " Ohat = np.empty((self.r, self.d))\n",
|
|
643 | 718 | "metadata": {},
|
644 | 719 | "outputs": [],
|
645 | 720 | "source": [
|
646 |
| - "solver = NNSolver()\n", |
647 |
| - "solver.verify()" |
| 721 | + "solver = NNSolver().fit(D, Z)\n", |
| 722 | + "print(solver)" |
648 | 723 | ]
|
649 | 724 | },
|
650 | 725 | {
|
|
653 | 728 | "metadata": {},
|
654 | 729 | "outputs": [],
|
655 | 730 | "source": [
|
656 |
| - "solver.fit(D, Z)\n", |
657 |
| - "print(solver)" |
| 731 | + "solver.verify()" |
658 | 732 | ]
|
659 | 733 | },
|
660 | 734 | {
|
|
663 | 737 | "metadata": {},
|
664 | 738 | "outputs": [],
|
665 | 739 | "source": [
|
666 |
| - "Ohat_nn = solver.predict()\n", |
| 740 | + "Ohat_nn = solver.solve()\n", |
667 | 741 | "print(f\"{Ohat_nn.shape=}\")\n",
|
668 | 742 | "\n",
|
669 | 743 | "# Check that the entries of the operator matrix are nonnegative.\n",
|
|