Skip to content

Commit 089c063

Browse files
charlesknipp and yebai authored
Update to most recent SSMProblems interface (#116)
* update to most recent SSMProblems interface * fixed linear Gaussian uinit test * fixed Levy SSM * temporary fix, that yields inefficient sampler * improved GP fix * remove environment path * Update Project.toml * remove unused constructor * update Levy SSM --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 57e6afd commit 089c063

File tree

12 files changed

+307
-406
lines changed

12 files changed

+307
-406
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*.jl.cov
22
*.jl.*.cov
33
*.jl.mem
4-
/Manifest.toml
4+
Manifest.toml
55
/test/Manifest.toml

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Random = "<0.0.1, 1"
2626
Random123 = "1.3"
2727
Requires = "1.0"
2828
StatsFuns = "0.9, 1"
29-
SSMProblems = "0.1"
29+
SSMProblems = "0.5"
3030
julia = "1.7"
3131

3232
[extras]

examples/gaussian-process/script.jl

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,74 +8,76 @@ using Distributions
88
using Libtask
99
using SSMProblems
1010

11-
Parameters = @NamedTuple begin
12-
a::Float64
13-
q::Float64
14-
kernel
11+
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: LatentDynamics{T,T}
12+
proc::GP{ZeroMean{T},KT}
13+
function GaussianProcessDynamics(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
14+
return new{T,KT}(GP(ZeroMean{T}(), kernel))
15+
end
1516
end
1617

17-
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18-
X::Vector{Float64}
19-
observations::Vector{Float64}
20-
θ::Parameters
21-
22-
GPSSM(params::Parameters) = new(Vector{Float64}(), params)
23-
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
18+
struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T}
19+
a::T
20+
b::T
21+
q::T
2422
end
2523

26-
seed = 1
27-
T = 100
28-
Nₚ = 20
29-
Nₛ = 250
30-
a = 0.9
31-
q = 0.5
32-
33-
params = Parameters((a, q, SqExponentialKernel()))
24+
function SSMProblems.distribution(proc::LinearGaussianDynamics{T}) where {T<:Real}
25+
return Normal(zero(T), proc.q)
26+
end
3427

35-
f(θ::Parameters, x, t) = Normal(θ.a * x, θ.q)
36-
h(θ::Parameters) = Normal(0, θ.q)
37-
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)
28+
function SSMProblems.distribution(proc::LinearGaussianDynamics, ::Int, state)
29+
return Normal(proc.a * state + proc.b, proc.q)
30+
end
3831

39-
rng = Random.MersenneTwister(seed)
32+
struct StochasticVolatility{T<:Real} <: ObservationProcess{T,T} end
4033

41-
x = zeros(T)
42-
y = similar(x)
43-
x[1] = rand(rng, h(params))
44-
for t in 1:T
45-
if t < T
46-
x[t + 1] = rand(rng, f(params, x[t], t))
47-
end
48-
y[t] = rand(rng, g(params, x[t], t))
34+
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
35+
return Normal(zero(T), exp((1 / 2) * state))
4936
end
5037

51-
function gp_update(model::GPSSM, state, step)
52-
gp = GP(model.θ.kernel)
53-
prior = gp(1:(step - 1))
54-
post = posterior(prior, model.X[1:(step - 1)])
55-
μ, σ = mean_and_cov(post, [step])
56-
return Normal(μ[1], σ[1])
38+
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
39+
dyn = LinearGaussianDynamics(a, zero(T), q)
40+
obs = StochasticVolatility{T}()
41+
return SSMProblems.StateSpaceModel(dyn, obs)
5742
end
5843

59-
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
60-
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
61-
return rand(rng, gp_update(model, state, step))
44+
function GaussianProcessStateSpaceModel(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
45+
dyn = GaussianProcessDynamics(T, kernel)
46+
obs = StochasticVolatility{T}()
47+
return SSMProblems.StateSpaceModel(dyn, obs)
6248
end
6349

64-
function SSMProblems.emission_logdensity(model::GPSSM, state, step)
65-
return logpdf(g(model.θ, state, step), model.observations[step])
66-
end
67-
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
68-
return logpdf(gp_update(model, prev_state, step), current_state)
50+
const GPSSM{T,KT<:Kernel} = SSMProblems.StateSpaceModel{
51+
T,
52+
GaussianProcessDynamics{T,KT},
53+
StochasticVolatility{T}
54+
};
55+
56+
# for non-markovian models, we can redefine dynamics to reference the trajectory
57+
function AdvancedPS.dynamics(
58+
ssm::AdvancedPS.TracedSSM{<:GPSSM{T},T,T}, step::Int
59+
) where {T<:Real}
60+
prior = ssm.model.dyn.proc(1:(step - 1))
61+
post = posterior(prior, ssm.X[1:(step - 1)])
62+
μ, σ = mean_and_cov(post, [step])
63+
return LinearGaussianDynamics(zero(T), μ[1], sqrt(σ[1]))
6964
end
7065

71-
AdvancedPS.isdone(::GPSSM, step) = step > T
66+
# Everything is now ready to simulate some data.
67+
rng = MersenneTwister(1234);
68+
true_model = LinearGaussianStochasticVolatilityModel(0.9, 0.5);
69+
_, x, y = sample(rng, true_model, 100);
7270

73-
model = GPSSM(y, params)
74-
pg = AdvancedPS.PGAS(Nₚ)
75-
chains = sample(rng, model, pg, Nₛ)
71+
# Create the model and run the sampler
72+
gpssm = GaussianProcessStateSpaceModel(Float64, SqExponentialKernel());
73+
model = gpssm(y);
74+
pg = AdvancedPS.PGAS(20);
75+
chains = sample(rng, model, pg, 250; progress=false);
76+
#md nothing #hide
7677

77-
particles = hcat([chain.trajectory.model.X for chain in chains]...)
78+
particles = hcat([chain.trajectory.model.X for chain in chains]...);
7879
mean_trajectory = mean(particles; dims=2);
80+
#md nothing #hide
7981

8082
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
8183
plot!(x; color=:darkorange, label="Original Trajectory")

examples/gaussian-ssm/script.jl

Lines changed: 31 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,81 +28,55 @@ using SSMProblems
2828
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.
2929

3030
# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
31-
Parameters = @NamedTuple begin
32-
a::Float64
33-
q::Float64
34-
r::Float64
31+
mutable struct Parameters{T<:Real}
32+
a::T
33+
q::T
34+
r::T
3535
end
3636

37-
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
38-
X::Vector{Float64}
39-
observations::Vector{Float64}
40-
θ::Parameters
41-
LinearSSM(θ::Parameters) = new(Vector{Float64}(), θ)
42-
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
37+
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
38+
a::T
39+
q::T
4340
end
4441

45-
# and the densities defined above.
46-
f(θ::Parameters, state, t) = Normal(θ.a * state, θ.q) # Transition density
47-
g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density
48-
f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
49-
#md nothing #hide
42+
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}; kwargs...) where {T<:Real}
43+
return Normal(zero(T), sqrt(dyn.q^2 / (1 - dyn.a^2)))
44+
end
5045

51-
# We also need to specify the dynamics of the system through the transition equations:
52-
# - `AdvancedPS.initialization`: the initial state density
53-
# - `AdvancedPS.transition`: the state transition density
54-
# - `AdvancedPS.observation`: the observation score given the observed data
55-
# - `AdvancedPS.isdone`: signals the end of the execution for the model
56-
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
57-
function SSMProblems.transition!!(
58-
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
59-
)
60-
return rand(rng, f(model.θ, state, step))
46+
function SSMProblems.distribution(dyn::LinearGaussianDynamics, step::Int, state; kwargs...)
47+
return Normal(dyn.a * state, dyn.q)
6148
end
6249

63-
function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
64-
return logpdf(g(model.θ, state, step), model.observations[step])
50+
struct LinearGaussianObservation{T<:Real} <: SSMProblems.ObservationProcess{T,T}
51+
r::T
6552
end
66-
function SSMProblems.transition_logdensity(
67-
model::LinearSSM, prev_state, current_state, step
53+
54+
function SSMProblems.distribution(
55+
obs::LinearGaussianObservation, step::Int, state; kwargs...
6856
)
69-
return logpdf(f(model.θ, prev_state, step), current_state)
57+
return Normal(state, obs.r)
7058
end
7159

72-
# We need to think seriously about how the data is handled
73-
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ
60+
function LinearGaussianStateSpaceModel(θ::Parameters)
61+
dyn = LinearGaussianDynamics(θ.a, θ.q)
62+
obs = LinearGaussianObservation(θ.r)
63+
return SSMProblems.StateSpaceModel(dyn, obs)
64+
end
7465

7566
# Everything is now ready to simulate some data.
76-
a = 0.9 # Scale
77-
q = 0.32 # State variance
78-
r = 1 # Observation variance
79-
Tₘ = 200 # Number of observation
80-
Nₚ = 20 # Number of particles
81-
Nₛ = 500 # Number of samples
82-
seed = 1 # Reproduce everything
83-
84-
θ₀ = Parameters((a, q, r))
85-
rng = Random.MersenneTwister(seed)
86-
87-
x = zeros(Tₘ)
88-
y = zeros(Tₘ)
89-
x[1] = rand(rng, f₀(θ₀))
90-
for t in 1:Tₘ
91-
if t < Tₘ
92-
x[t + 1] = rand(rng, f(θ₀, x[t], t))
93-
end
94-
y[t] = rand(rng, g(θ₀, x[t], t))
95-
end
67+
rng = Random.MersenneTwister(1234)
68+
θ = Parameters(0.9, 0.32, 1.0)
69+
true_model = LinearGaussianStateSpaceModel(θ)
70+
_, x, y = sample(rng, true_model, 200);
9671

9772
# Here are the latent and observation timeseries
9873
plot(x; label="x", xlabel="t")
9974
plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)
10075

10176
# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
10277
# and a model interface.
103-
model = LinearSSM(y, θ₀)
104-
pgas = AdvancedPS.PGAS(Nₚ)
105-
chains = sample(rng, model, pgas, Nₛ; progress=false);
78+
pgas = AdvancedPS.PGAS(20)
79+
chains = sample(rng, true_model(y), pgas, 500; progress=false);
10680
#md nothing #hide
10781

10882
#
@@ -118,7 +92,7 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
11892
# We used a particle gibbs kernel with the ancestor updating step which should help with the particle
11993
# degeneracy problem and improve the mixing.
12094
# We can compute the update rate of $x_t$ vs $t$ defined as the proportion of times $t$ where $x_t$ gets updated:
121-
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
95+
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / length(chains)
12296
#md nothing #hide
12397

12498
# and compare it to the theoretical value of $1 - 1/Nₚ$.
@@ -130,4 +104,4 @@ plot(
130104
xlabel="Iteration",
131105
ylabel="Update rate",
132106
)
133-
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
107+
hline!([1 - 1 / pgas.nparticles]; label="N: $(pgas.nparticles)") # N = particle count

0 commit comments

Comments
 (0)