Description
For systems with many states, Kalman filtering is slow because we have to compute the so-called Kalman gain matrix at every time step:

$$K_t = P_t Z^\top F_t^{-1}, \qquad F_t = Z P_t Z^\top + H$$

Where:
name | description | shape |
---|---|---|
Z | Maps hidden states to observed states | `(n_obs, n_hidden)` |
P | Hidden state covariance matrix | `(n_hidden, n_hidden)` |
H | Measurement error covariance | `(n_obs, n_obs)` |
F | Estimated residual covariance | `(n_obs, n_obs)` |
K | Kalman gain (optimal update to hidden states) | `(n_hidden, n_obs)` |
Typically the number of observed states will be quite small relative to the number of hidden states, so the inversion of $F$ (an `n_obs x n_obs` matrix) isn't really the bottleneck; the large matrix multiplications involving $P$ are.
Then, to compute next-step forecasts, we have to do several more multiplications:

$$a_{t+1} = T\left(a_t + K_t v_t\right), \qquad P_{t+1} = T\left(P_t - K_t Z P_t\right) T^\top + R Q R^\top$$

Where $a_t$ is the hidden state mean, $v_t = y_t - Z a_t$ is the one-step forecast residual, $T$ is the state transition matrix, $R$ is the selection matrix, and $Q$ is the state innovation covariance.
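To make the per-step cost concrete, here's a minimal numpy sketch of one filter step (the variable names and signature are illustrative, not the actual statespace implementation):

```python
import numpy as np

def kalman_step(y, a, P, T, Z, R, Q, H):
    """One predict/update step. When n_hidden is large, the products
    involving P dominate the cost, not the small inverse of F."""
    v = y - Z @ a                                  # one-step forecast residual, (n_obs,)
    F = Z @ P @ Z.T + H                            # residual covariance, (n_obs, n_obs) -- small
    K = P @ Z.T @ np.linalg.inv(F)                 # Kalman gain, (n_hidden, n_obs)

    a_filtered = a + K @ v                         # filtered state mean
    P_filtered = P - K @ Z @ P                     # filtered state covariance, (n_hidden, n_hidden)

    a_next = T @ a_filtered                        # next-step forecast mean
    P_next = T @ P_filtered @ T.T + R @ Q @ R.T    # next-step forecast covariance -- expensive
    return a_next, P_next
```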
My point is that this is all quite expensive to compute. Interestingly, though, the covariance matrix $P_t$ (and with it $F_t$ and $K_t$) converges to a fixed steady-state value $\bar{P}$ that satisfies:

$$\bar{P} = T \bar{P} T^\top - T \bar{P} Z^\top \left(Z \bar{P} Z^\top + H\right)^{-1} Z \bar{P} T^\top + R Q R^\top$$
This is a discrete algebraic Riccati equation, and can be solved in pytensor as `pt.linalg.solve_discrete_are(A=T.T, B=Z.T, Q=R @ Q @ R.T, R=H)`. Once we have this, we actually don't need to compute $P_t$, $F_t$, or $K_t$ at every step at all; they are constant, and can be computed once up front.
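For concreteness, here's a small numpy/scipy sketch of the steady-state computation (the system matrices are made up for illustration; in pytensor it would be the `solve_discrete_are` call above):

```python
import numpy as np
from scipy import linalg

# Made-up system: 15 hidden states, 2 observed states
m, p = 15, 2
rng = np.random.default_rng(0)
T = np.eye(m) * 0.95                  # state transition
Z = rng.normal(size=(p, m))           # maps hidden states to observed states
R = np.eye(m)                         # selection matrix
Q = np.eye(m) * 0.1                   # state innovation covariance
H = np.eye(p) * 0.5                   # measurement error covariance

# Steady-state forecast covariance from the discrete algebraic Riccati equation
P_bar = linalg.solve_discrete_are(a=T.T, b=Z.T, q=R @ Q @ R.T, r=H)

# F and K are then fixed for every time step
F_bar = Z @ P_bar @ Z.T + H
K_bar = P_bar @ Z.T @ np.linalg.inv(F_bar)
```

With `P_bar`, `F_bar`, and `K_bar` precomputed, each filter step reduces to the cheap mean updates.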
How to use this is not 100% clear to me. I had previously made a `SteadyStateFilter` class that computed and used the steady-state covariance from `solve_discrete_are`, but right now statespace is basically 100% dependent on JAX for sampling.
The "safer" option would be to use an ifelse
in the update
step to check for convergence. At every iteration of kalman_step
, we can compute np.eye(15) * 1e7
. The two plots are the same, but the right plot begins at t=20:
Here's a table of tolerance levels and convergence iterations:
Tolerance | Convergence Iteration |
---|---|
1 | 25 |
1e-1 | 49 |
1e-2 | 108 |
1e-3 | 216 |
1e-4 | 337 |
1e-5 | 457 |
1e-6 | 583 |
1e-7 | 703 |
1e-8 | 827 |
We could leave the convergence tolerance as a free parameter for the user to play with. But we can see that if we pick `1e-2`, for instance, everything after about 100 time steps is basically free. This would be quite attractive for estimating large, expensive systems or extremely long time series.
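Here's a rough sketch of what the `ifelse` check could look like inside the scan step (names, signature, and tolerance handling are all assumptions about how this might be wired up, not the actual implementation; it relies on pytensor's `ifelse` being lazy, so the expensive branch is skipped once `converged` is True):

```python
import pytensor.tensor as pt
from pytensor.ifelse import ifelse

def kalman_step(y, a, P, K, converged, T, Z, R, Q, H, tol):
    # Full (expensive) covariance and gain update
    F_new = Z @ P @ Z.T + H
    K_new = P @ Z.T @ pt.linalg.inv(F_new)
    P_new = T @ (P - K_new @ Z @ P) @ T.T + R @ Q @ R.T

    # Once converged, re-use the previous gain/covariance; the lazy ifelse
    # then never has to evaluate the expensive branch above
    K_next = ifelse(converged, K, K_new)
    P_next = ifelse(converged, P, P_new)
    converged_next = ifelse(converged, converged, pt.max(pt.abs(P_new - P)) < tol)

    # The mean updates are cheap and always run
    v = y - Z @ a
    a_next = T @ (a + K_next @ v)
    return a_next, P_next, K_next, converged_next
```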