Commit c2baffa

JasonPekos, yebai, and penelopeysm authored
add new section in HMM tutorial (#508)
* add new section in HMM tutorial
* Update tutorials/04-hidden-markov-model/index.qmd
  Co-authored-by: Penelope Yong <[email protected]>
* Update tutorials/04-hidden-markov-model/index.qmd
  Co-authored-by: Penelope Yong <[email protected]>
* Update tutorials/04-hidden-markov-model/index.qmd
  Co-authored-by: Penelope Yong <[email protected]>
* Remove code fold / output suppression
* Updates for new versions of Turing

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>
1 parent 9b7136d commit c2baffa

3 files changed: +105 -42 lines changed


Manifest.toml

Lines changed: 11 additions & 1 deletion
@@ -2,7 +2,7 @@

julia_version = "1.11.5"
manifest_format = "2.0"
-project_hash = "451ef37239e6a37e494e660005a325d7a0d8517f"
+project_hash = "abb3c770eb08cdd80c9627a8a7c327584291f4c8"

[[deps.ADTypes]]
git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32"
@@ -1348,6 +1348,16 @@ git-tree-sha1 = "2eaa69a7cab70a52b9687c8bf950a5a93ec895ae"
uuid = "076d061b-32b6-4027-95e0-9a2c6f6d7e74"
version = "0.2.0"

+[[deps.HiddenMarkovModels]]
+deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
+git-tree-sha1 = "aaeadfb78b874b30f0ce2109a6799547c3f72e89"
+uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
+version = "0.7.0"
+weakdeps = ["Distributions"]
+
+[deps.HiddenMarkovModels.extensions]
+HiddenMarkovModelsDistributionsExt = "Distributions"
+
[[deps.Hwloc]]
deps = ["CEnum", "Hwloc_jll", "Printf"]
git-tree-sha1 = "6a3d80f31ff87bc94ab22a7b8ec2f263f9a6a583"

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
+HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"

tutorials/hidden-markov-models/index.qmd

Lines changed: 93 additions & 41 deletions
@@ -12,17 +12,19 @@ using Pkg;
Pkg.instantiate();
```

-This tutorial illustrates training Bayesian [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM) using Turing. The main goals are learning the transition matrix, emission parameter, and hidden states. For a more rigorous academic overview on Hidden Markov Models, see [An introduction to Hidden Markov Models and Bayesian Networks](http://mlg.eng.cam.ac.uk/zoubin/papers/ijprai.pdf) (Ghahramani, 2001).
+This tutorial illustrates training Bayesian [hidden Markov models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMMs) using Turing.
+The main goals are learning the transition matrix, emission parameter, and hidden states.
+For a more rigorous academic overview of hidden Markov models, see [An Introduction to Hidden Markov Models and Bayesian Networks](https://mlg.eng.cam.ac.uk/zoubin/papers/ijprai.pdf) (Ghahramani, 2001).

In this tutorial, we assume there are $k$ discrete hidden states; the observations are continuous and normally distributed - centered around the hidden states. This assumption reduces the number of parameters to be estimated in the emission matrix.

-Let's load the libraries we'll need. We also set a random seed (for reproducibility) and the automatic differentiation backend to forward mode (more [here]({{<meta using-turing-autodiff>}}) on why this is useful).
+Let's load the libraries we'll need, and set a random seed for reproducibility.

```{julia}
# Load libraries.
-using Turing, StatsPlots, Random
+using Turing, StatsPlots, Random, Bijectors

-# Set a random seed and use the forward_diff AD mode.
+# Set a random seed
Random.seed!(12345678);
```

@@ -32,52 +34,23 @@ In this example, we'll use something where the states and emission parameters ar

```{julia}
# Define the emission parameter.
-y = [
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    2.0,
-    2.0,
-    2.0,
-    2.0,
-    2.0,
-    2.0,
-    3.0,
-    3.0,
-    3.0,
-    3.0,
-    3.0,
-    3.0,
-    3.0,
-    2.0,
-    2.0,
-    2.0,
-    2.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-    1.0,
-];
+y = [fill(1.0, 6)..., fill(2.0, 6)..., fill(3.0, 7)...,
+    fill(2.0, 4)..., fill(1.0, 7)...]
N = length(y);
K = 3;

# Plot the data we just made.
-plot(y; xlim=(0, 30), ylim=(-1, 5), size=(500, 250))
+plot(y; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)
+scatter!(y, color = :blue; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)
```

We can see that we have three states, one for each height of the plot (1, 2, 3). This height is also our emission parameter, so state one produces a value of one, state two produces a value of two, and so on.

Ultimately, we would like to understand three major parameters:

1. The transition matrix. This is a matrix that assigns a probability of switching from one state to any other state, including the state that we are already in.
-2. The emission matrix, which describes a typical value emitted by some state. In the plot above, the emission parameter for state one is simply one.
-3. The state sequence is our understanding of what state we were actually in when we observed some data. This is very important in more sophisticated HMM models, where the emission value does not equal our state.
+2. The emission parameters, which describes a typical value emitted by some state. In the plot above, the emission parameter for state one is simply one.
+3. The state sequence is our understanding of what state we were actually in when we observed some data. This is very important in more sophisticated HMMs, where the emission value does not equal our state.

With this in mind, let's set up our model. We are going to use some of our knowledge as modelers to provide additional information about our system. This takes the form of the prior on our emission parameter.

@@ -127,18 +100,22 @@ We will use a combination of two samplers (HMC and Particle Gibbs) by passing th

In this case, we use HMC for `m` and `T`, representing the emission and transition matrices respectively. We use the Particle Gibbs sampler for `s`, the state sequence. You may wonder why it is that we are not assigning `s` to the HMC sampler, and why it is that we need compositional Gibbs sampling at all.

-The parameter `s` is not a continuous variable. It is a vector of **integers**, and thus Hamiltonian methods like HMC and NUTS won't work correctly. Gibbs allows us to apply the right tools to the best effect. If you are a particularly advanced user interested in higher performance, you may benefit from setting up your Gibbs sampler to use [different automatic differentiation]({{<meta using-turing-autodiff>}}#compositional-sampling-with-differing-ad-modes) backends for each parameter space.
+The parameter `s` is not a continuous variable.
+It is a vector of **integers**, and thus Hamiltonian methods like HMC and NUTS won't work correctly.
+Gibbs allows us to apply the right tools to the best effect.
+If you are a particularly advanced user interested in higher performance, you may benefit from setting up your Gibbs sampler to use [different automatic differentiation]({{<meta using-turing-autodiff>}}#compositional-sampling-with-differing-ad-modes) backends for each parameter space.

Time to run our sampler.

```{julia}
#| output: false
+#| echo: false
setprogress!(false)
```

```{julia}
g = Gibbs((:m, :T) => HMC(0.01, 50), :s => PG(120))
-chn = sample(BayesHmm(y, 3), g, 1000);
+chn = sample(BayesHmm(y, 3), g, 1000)
```

Let's see how well our chain performed.
@@ -193,3 +170,78 @@ heideldiag(MCMCChains.group(chn, :T))[1]
```

The p-values on the test suggest that we cannot reject the hypothesis that the observed sequence comes from a stationary distribution, so we can be reasonably confident that our transition matrix has converged to something reasonable.
+
+## Efficient Inference With The Forward Algorithm
+
+While the above method works well for the simple example in this tutorial, some users may desire a more efficient method, especially when their model is more complicated.
+One simple way to improve inference is to marginalize out the hidden states of the model with an appropriate algorithm, calculating only the posterior over the continuous random variables.
+Not only does this allow more efficient inference via Rao-Blackwellization, but now we can sample our model with `NUTS()` alone, which is usually a much more performant MCMC kernel.
+
+Thankfully, [HiddenMarkovModels.jl](https://github.com/gdalle/HiddenMarkovModels.jl) provides an extremely efficient implementation of many algorithms related to hidden Markov models. This allows us to rewrite our model as:
+
+```{julia}
+using HiddenMarkovModels
+using FillArrays
+using LinearAlgebra
+using LogExpFunctions
+
+
+@model function BayesHmm2(y, K)
+    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
+    T ~ filldist(Dirichlet(fill(1/K, K)), K)
+
+    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
+    Turing.@addlogprob! logdensityof(hmm, y)
+end
+
+chn2 = sample(BayesHmm2(y, 3), NUTS(), 1000)
+```
+
+
+We can compare the chains of these two models, confirming the posterior estimate is similar (modulo label switching concerns with the Gibbs model):
+```{julia}
+#| code-fold: true
+#| code-summary: "Plotting Chains"
+
+plot(chn["m[1]"], label = "m[1], Model 1, Gibbs", color = :lightblue)
+plot!(chn2["m[1]"], label = "m[1], Model 2, NUTS", color = :blue)
+plot!(chn["m[2]"], label = "m[2], Model 1, Gibbs", color = :pink)
+plot!(chn2["m[2]"], label = "m[2], Model 2, NUTS", color = :red)
+plot!(chn["m[3]"], label = "m[3], Model 1, Gibbs", color = :yellow)
+plot!(chn2["m[3]"], label = "m[3], Model 2, NUTS", color = :orange)
+```
+
+
+### Recovering Marginalized Trajectories
+
+We can use the `viterbi()` algorithm, also from the `HiddenMarkovModels` package, to recover the most probable state for each parameter set in our posterior sample:
+```{julia}
+@model function BayesHmmRecover(y, K, IncludeGenerated = false)
+    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
+    T ~ filldist(Dirichlet(fill(1/K, K)), K)
+
+    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
+    Turing.@addlogprob! logdensityof(hmm, y)
+
+    # Conditional generation of the hidden states.
+    if IncludeGenerated
+        seq, _ = viterbi(hmm, y)
+        s := [m[s] for s in seq]
+    end
+end
+
+chn_recover = sample(BayesHmmRecover(y, 3, true), NUTS(), 1000)
+```
+
+Plotting the estimated states, we can see that the results align well with our expectations:
+
+```{julia}
+p = plot(xlim=(0, 30), ylim=(-1, 5), size=(500, 250))
+for i in 1:100
+    ind = rand(DiscreteUniform(1, 1000))
+    plot!(MCMCChains.group(chn_recover, :s).value[ind,:], color = :grey, opacity = 0.1, legend = :false)
+end
+scatter!(y, color = :blue)
+
+p
+```
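
For readers curious what `Turing.@addlogprob! logdensityof(hmm, y)` is marginalizing in the new section, the block below is a minimal log-space forward recursion written against Distributions.jl and LogExpFunctions.jl. It is a sketch, not the HiddenMarkovModels.jl implementation: the name `forward_logpdf`, the row-stochastic matrix `T_row`, and the example values are illustrative only and are not part of the commit.

```julia
using Distributions, LogExpFunctions

# Sketch of the forward algorithm: log p(y_{1:N}) with the hidden states summed out.
# `T` is assumed row-stochastic (T[j, k] = P(state k | state j)), matching the matrix
# passed to `HMM(...)` in the tutorial (the transpose of the column-stochastic draw).
function forward_logpdf(init, T, emissions, y)
    K = length(init)
    # logα_k(1) = log init_k + log p(y_1 | state k)
    logalpha = log.(init) .+ [logpdf(emissions[k], y[1]) for k in 1:K]
    for t in 2:length(y)
        # logα_k(t) = logsumexp_j(logα_j(t-1) + log T[j, k]) + log p(y_t | state k)
        logalpha = [logsumexp(logalpha .+ log.(T[:, k])) + logpdf(emissions[k], y[t]) for k in 1:K]
    end
    return logsumexp(logalpha)
end

# Hypothetical parameter values, just to show the call shape:
T_row = [0.8 0.1 0.1; 0.1 0.8 0.1; 0.1 0.1 0.8]
forward_logpdf(fill(1/3, 3), T_row, [Normal(m, 0.1) for m in [1.0, 2.0, 3.0]], [1.0, 1.0, 2.0, 3.0])
```

Because this marginal log-density is a smooth function of `m` and `T`, gradient-based samplers such as `NUTS()` can be applied directly, which is the Rao-Blackwellization payoff described in the new section.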
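
In the same spirit, the `viterbi(hmm, y)` call in `BayesHmmRecover` returns the single most probable hidden-state sequence under the current draw of `m` and `T`. The following is a minimal sketch of that recursion; `viterbi_path` and its arguments are again illustrative, not the package's API.

```julia
using Distributions

# Sketch of the Viterbi recursion: track, for each state k and time t, the best
# log-probability of any path ending in k, then backtrack to read off the path.
function viterbi_path(init, T, emissions, y)
    K, N = length(init), length(y)
    logdelta = log.(init) .+ [logpdf(emissions[k], y[1]) for k in 1:K]
    back = zeros(Int, K, N)                    # backpointers
    for t in 2:N
        newdelta = similar(logdelta)
        for k in 1:K
            scores = logdelta .+ log.(T[:, k]) # T row-stochastic, as above
            back[k, t] = argmax(scores)
            newdelta[k] = maximum(scores) + logpdf(emissions[k], y[t])
        end
        logdelta = newdelta
    end
    path = zeros(Int, N)
    path[N] = argmax(logdelta)
    for t in (N - 1):-1:1                      # backtrack
        path[t] = back[path[t + 1], t + 1]
    end
    return path
end
```

Mapping each recovered state index through `m`, as the model does with `s := [m[s] for s in seq]`, turns this index sequence into the emission-scale trajectory plotted at the end of the new section.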
