@@ -8,74 +8,76 @@ using Distributions
8
8
using Libtask
9
9
using SSMProblems
10
10
11
- Parameters = @NamedTuple begin
12
- a:: Float64
13
- q:: Float64
14
- kernel
11
+ struct GaussianProcessDynamics{T<: Real ,KT<: Kernel } <: LatentDynamics{T,T}
12
+ proc:: GP{ZeroMean{T},KT}
13
+ function GaussianProcessDynamics (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
14
+ return new {T,KT} (GP (ZeroMean {T} (), kernel))
15
+ end
15
16
end
16
17
17
- mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18
- X:: Vector{Float64}
19
- observations:: Vector{Float64}
20
- θ:: Parameters
21
-
22
- GPSSM (params:: Parameters ) = new (Vector {Float64} (), params)
23
- GPSSM (y:: Vector{Float64} , params:: Parameters ) = new (Vector {Float64} (), y, params)
18
+ struct LinearGaussianDynamics{T<: Real } <: LatentDynamics{T,T}
19
+ a:: T
20
+ b:: T
21
+ q:: T
24
22
end
25
23
26
- seed = 1
27
- T = 100
28
- Nₚ = 20
29
- Nₛ = 250
30
- a = 0.9
31
- q = 0.5
32
-
33
- params = Parameters ((a, q, SqExponentialKernel ()))
24
+ function SSMProblems. distribution (proc:: LinearGaussianDynamics{T} ) where {T<: Real }
25
+ return Normal (zero (T), proc. q)
26
+ end
34
27
35
- f (θ :: Parameters , x, t) = Normal (θ . a * x, θ . q )
36
- h (θ :: Parameters ) = Normal (0 , θ . q)
37
- g (θ :: Parameters , x, t) = Normal ( 0 , exp ( 0.5 * x) ^ 2 )
28
+ function SSMProblems . distribution (proc :: LinearGaussianDynamics , :: Int , state )
29
+ return Normal (proc . a * state + proc . b, proc . q)
30
+ end
38
31
39
- rng = Random . MersenneTwister (seed)
32
+ struct StochasticVolatility{T <: Real } <: ObservationProcess{T,T} end
40
33
41
- x = zeros (T)
42
- y = similar (x)
43
- x[1 ] = rand (rng, h (params))
44
- for t in 1 : T
45
- if t < T
46
- x[t + 1 ] = rand (rng, f (params, x[t], t))
47
- end
48
- y[t] = rand (rng, g (params, x[t], t))
34
+ function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
35
+ return Normal (zero (T), exp ((1 / 2 ) * state))
49
36
end
50
37
51
- function gp_update (model:: GPSSM , state, step)
52
- gp = GP (model. θ. kernel)
53
- prior = gp (1 : (step - 1 ))
54
- post = posterior (prior, model. X[1 : (step - 1 )])
55
- μ, σ = mean_and_cov (post, [step])
56
- return Normal (μ[1 ], σ[1 ])
38
+ function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
39
+ dyn = LinearGaussianDynamics (a, zero (T), q)
40
+ obs = StochasticVolatility {T} ()
41
+ return SSMProblems. StateSpaceModel (dyn, obs)
57
42
end
58
43
59
- SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM ) = rand (rng, h (model. θ))
60
- function SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM , state, step)
61
- return rand (rng, gp_update (model, state, step))
44
+ function GaussianProcessStateSpaceModel (:: Type{T} , kernel:: KT ) where {T<: Real ,KT<: Kernel }
45
+ dyn = GaussianProcessDynamics (T, kernel)
46
+ obs = StochasticVolatility {T} ()
47
+ return SSMProblems. StateSpaceModel (dyn, obs)
62
48
end
63
49
64
- function SSMProblems. emission_logdensity (model:: GPSSM , state, step)
65
- return logpdf (g (model. θ, state, step), model. observations[step])
66
- end
67
- function SSMProblems. transition_logdensity (model:: GPSSM , prev_state, current_state, step)
68
- return logpdf (gp_update (model, prev_state, step), current_state)
50
+ const GPSSM{T,KT<: Kernel } = SSMProblems. StateSpaceModel{
51
+ T,
52
+ GaussianProcessDynamics{T,KT},
53
+ StochasticVolatility{T}
54
+ };
55
+
56
+ # for non-markovian models, we can redefine dynamics to reference the trajectory
57
+ function AdvancedPS. dynamics (
58
+ ssm:: AdvancedPS.TracedSSM{<:GPSSM{T},T,T} , step:: Int
59
+ ) where {T<: Real }
60
+ prior = ssm. model. dyn. proc (1 : (step - 1 ))
61
+ post = posterior (prior, ssm. X[1 : (step - 1 )])
62
+ μ, σ = mean_and_cov (post, [step])
63
+ return LinearGaussianDynamics (zero (T), μ[1 ], sqrt (σ[1 ]))
69
64
end
70
65
71
- AdvancedPS. isdone (:: GPSSM , step) = step > T
66
+ # Everything is now ready to simulate some data.
67
+ rng = MersenneTwister (1234 );
68
+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 );
69
+ _, x, y = sample (rng, true_model, 100 );
72
70
73
- model = GPSSM (y, params)
74
- pg = AdvancedPS. PGAS (Nₚ)
75
- chains = sample (rng, model, pg, Nₛ)
71
+ # Create the model and run the sampler
72
+ gpssm = GaussianProcessStateSpaceModel (Float64, SqExponentialKernel ());
73
+ model = gpssm (y);
74
+ pg = AdvancedPS. PGAS (20 );
75
+ chains = sample (rng, model, pg, 250 ; progress= false );
76
+ # md nothing #hide
76
77
77
- particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
78
+ particles = hcat ([chain. trajectory. model. X for chain in chains]. .. );
78
79
mean_trajectory = mean (particles; dims= 2 );
80
+ # md nothing #hide
79
81
80
82
scatter (particles; label= false , opacity= 0.01 , color= :black , xlabel= " t" , ylabel= " state" )
81
83
plot! (x; color= :darkorange , label= " Original Trajectory" )
0 commit comments