diff --git a/HISTORY.md b/HISTORY.md
index 038968ef..2c7e4e01 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,9 @@
 # AdvancedHMC Changelog
 
+## 0.9.0
+
+ - The stochastic gradient based methods `SGHMC` and `SGLD` are now supported in AdvancedHMC.jl. Note that Turing.jl exports methods with the same names; when using the two packages together, qualify the method with the package that exports it (e.g. `AdvancedHMC.SGHMC`).
+
 ## 0.8.0
 
 - To make an MCMC transition from phasepoint `z` using trajectory `τ` (or `HMCKernel` `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)` (if using an `HMCKernel`, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
diff --git a/Project.toml b/Project.toml
index 599dae3c..77e85396 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedHMC"
 uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
-version = "0.8.0"
+version = "0.9.0"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/docs/Project.toml b/docs/Project.toml
index c48bd544..ba4848b5 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 
 [compat]
-AdvancedHMC = "0.8"
+AdvancedHMC = "0.9"
 Documenter = "1"
-DocumenterCitations = "1"
\ No newline at end of file
+DocumenterCitations = "1"
diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl
index b25710d5..96f20a19 100644
--- a/src/AdvancedHMC.jl
+++ b/src/AdvancedHMC.jl
@@ -125,7 +125,7 @@ include("sampler.jl")
 export sample
 
 include("constructors.jl")
-export HMCSampler, HMC, NUTS, HMCDA
+export HMCSampler, HMC, NUTS, HMCDA, SGHMC
 
 include("abstractmcmc.jl")
diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl
index 1ae71a72..a59e8002 100644
--- a/src/abstractmcmc.jl
+++ b/src/abstractmcmc.jl
@@ -205,6 +205,120 @@ function AbstractMCMC.step(
     return Transition(t.z, tstat), newstate
 end
 
+struct SGHMCState{
+    TTrans<:Transition,
+    TMetric<:AbstractMetric,
+    TKernel<:AbstractMCMCKernel,
+    TAdapt<:Adaptation.AbstractAdaptor,
+    T<:AbstractVector{<:Real},
+}
+    "Index of current iteration."
+    i::Int
+    "Current [`Transition`](@ref)."
+    transition::TTrans
+    "Current [`AbstractMetric`](@ref), possibly adapted."
+    metric::TMetric
+    "Current [`AbstractMCMCKernel`](@ref)."
+    κ::TKernel
+    "Current [`AbstractAdaptor`](@ref)."
+    adaptor::TAdapt
+    "Current velocity."
+    velocity::T
+end
+getadaptor(state::SGHMCState) = state.adaptor
+getmetric(state::SGHMCState) = state.metric
+getintegrator(state::SGHMCState) = state.κ.τ.integrator
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    spl::SGHMC;
+    initial_params=nothing,
+    kwargs...,
+)
+    # Unpack model
+    logdensity = model.logdensity
+
+    # Define metric
+    metric = make_metric(spl, logdensity)
+
+    # Construct the hamiltonian using the initial metric
+    hamiltonian = Hamiltonian(metric, model)
+
+    # Compute initial sample and state.
+    initial_params = make_initial_params(rng, spl, logdensity, initial_params)
+    ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
+    integrator = make_integrator(spl, ϵ)
+
+    # Make kernel
+    κ = make_kernel(spl, integrator)
+
+    # Make adaptor
+    adaptor = make_adaptor(spl, metric, integrator)
+
+    # Get an initial sample.
+    h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)
+
+    state = SGHMCState(0, t, metric, κ, adaptor, initial_params)
+
+    return AbstractMCMC.step(rng, model, spl, state; kwargs...)
+end
+
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    spl::SGHMC,
+    state::SGHMCState;
+    n_adapts::Int=0,
+    kwargs...,
+)
+    if haskey(kwargs, :nadapts)
+        throw(
+            ArgumentError(
+                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
+            ),
+        )
+    end
+
+    i = state.i + 1
+    t_old = state.transition
+    adaptor = state.adaptor
+    κ = state.κ
+    metric = state.metric
+
+    # Reconstruct hamiltonian.
+    h = Hamiltonian(metric, model)
+
+    # Compute gradient of log density.
+    logdensity_and_gradient = Base.Fix1(
+        LogDensityProblems.logdensity_and_gradient, model.logdensity
+    )
+    θ = copy(t_old.z.θ)
+    grad = last(logdensity_and_gradient(θ))
+
+    # Update latent variables and velocity according to
+    # equation (15) of Chen et al. (2014)
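+    # With learning rate η and momentum decay α, the lines below implement
+    #   v ← (1 - α) v + η ∇log p(θ) + N(0, 2ηα I)
+    #   θ ← θ + v
+    # where the injected Gaussian noise compensates for the error of the
+    # stochastic gradient estimate.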
+    v = state.velocity
+    η = spl.learning_rate
+    α = spl.momentum_decay
+    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
+    θ .+= newv
+
+    # Make new transition.
+    z = phasepoint(h, θ, v)
+    t = transition(rng, h, κ, z)
+
+    # Adapt h and spl.
+    tstat = stat(t)
+    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
+    tstat = merge(tstat, (is_adapt=isadapted,))
+
+    # Compute next sample and state.
+    sample = Transition(t.z, tstat)
+    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)
+
+    return sample, newstate
+end
+
 ################
 ### Callback ###
 ################
@@ -392,6 +506,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractInte
     return NoAdaptation()
 end
 
+function make_adaptor(spl::SGHMC, metric::AbstractMetric, integrator::AbstractIntegrator)
+    return NoAdaptation()
+end
+
 function make_adaptor(
     spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator
 )
@@ -417,3 +535,7 @@ end
 function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator)
     return spl.κ
 end
+
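+# SGHMC reuses the static-HMC machinery: the kernel runs a fixed number of
+# leapfrog steps (`n_leapfrog`) and takes the trajectory's end point as the
+# proposal.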
+function make_kernel(spl::SGHMC, integrator::AbstractIntegrator)
+    return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(spl.n_leapfrog)))
+end
diff --git a/src/constructors.jl b/src/constructors.jl
index 4d60fdef..69b903ba 100644
--- a/src/constructors.jl
+++ b/src/constructors.jl
@@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal)
 end
 
 sampler_eltype(::HMCDA{T}) where {T} = T
+
+#############
+### SGHMC ###
+#############
+"""
+    SGHMC(learning_rate::Real, momentum_decay::Real, n_leapfrog::Int; integrator = :leapfrog, metric = :diagonal)
+
+Stochastic Gradient Hamiltonian Monte Carlo sampler.
+
+# Fields
+
+$(FIELDS)
+
+# Notes
+
+For more information, see the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)):
+
+- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic Gradient Hamiltonian Monte Carlo." International Conference on Machine Learning. PMLR, 2014.
+"""
+struct SGHMC{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
+       AbstractHMCSampler
+    "Learning rate for the gradient update."
+    learning_rate::T
+    "Momentum decay rate."
+    momentum_decay::T
+    "Number of leapfrog steps."
+    n_leapfrog::Int
+    "Choice of integrator, specified either using a `Symbol` or an [`AbstractIntegrator`](@ref)."
+    integrator::I
+    "Choice of initial metric; a `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaptation."
+    metric::M
+end
+
+function SGHMC(
+    learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal
+)
+    T = determine_sampler_eltype(
+        learning_rate, momentum_decay, n_leapfrog, integrator, metric
+    )
+    return SGHMC(T(learning_rate), T(momentum_decay), n_leapfrog, integrator, metric)
+end
+
+sampler_eltype(::SGHMC{T}) where {T} = T
diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl
index da448f78..50207cb0 100644
--- a/test/abstractmcmc.jl
+++ b/test/abstractmcmc.jl
@@ -10,6 +10,7 @@ using Statistics: mean
     nuts = NUTS(0.8)
     hmc = HMC(100; integrator=Leapfrog(0.05))
     hmcda = HMCDA(0.8, 0.1)
+    sghmc = SGHMC(0.01, 0.1, 100)
 
     integrator = Leapfrog(1e-3)
     κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -111,6 +112,29 @@
 
     @test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL
 
+    samples_sghmc = AbstractMCMC.sample(
+        rng,
+        model,
+        sghmc,
+        n_adapts + n_samples;
+        n_adapts=n_adapts,
+        initial_params=θ_init,
+        progress=false,
+        verbose=false,
+    )
+
+    # Transform back to original space.
+    # NOTE: We're not correcting for the `logabsdetjac` here since we're only
+    # interested in the mean, so it doesn't matter.
+    for t in samples_sghmc
+        t.z.θ .= invlink_gdemo(t.z.θ)
+    end
+    m_est_sghmc = mean(samples_sghmc) do t
+        t.z.θ
+    end
+
+    @test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL
+
     samples_custom = AbstractMCMC.sample(
         rng,
         model,
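For reviewers who want to exercise the new interface end to end, here is a minimal, self-contained sketch mirroring the test above. It is not part of the patch; the toy `StandardGaussian` target, the hyperparameter values, and the chain length are made up for illustration:

```julia
using AbstractMCMC, AdvancedHMC, LogDensityProblems, Random
using Statistics: mean

# Toy target: a standard bivariate Gaussian implementing the first-order
# LogDensityProblems interface (SGHMC needs gradients).
struct StandardGaussian end
LogDensityProblems.logdensity(::StandardGaussian, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(::StandardGaussian) = 2
function LogDensityProblems.capabilities(::Type{StandardGaussian})
    return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity_and_gradient(::StandardGaussian, θ)
    return -sum(abs2, θ) / 2, -θ
end

model = AbstractMCMC.LogDensityModel(StandardGaussian())
sghmc = SGHMC(0.01, 0.1, 100)  # learning_rate, momentum_decay, n_leapfrog

samples = AbstractMCMC.sample(
    Random.default_rng(), model, sghmc, 2_000; initial_params=zeros(2), progress=false
)
mean(t -> t.z.θ, samples)  # should be close to [0, 0] for a long enough chain
```

Note that `n_adapts` defaults to `0` in the new `step` method and `make_adaptor` returns `NoAdaptation()` for `SGHMC`, so a plain `sample` call like this runs without adaptation.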