Skip to content

Commit c2c6c11

Browse files
authored
Merge branch 'main' into qqy/sghmc
2 parents 715aefa + 7e78371 commit c2c6c11

File tree

9 files changed

+105
-44
lines changed

9 files changed

+105
-44
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,29 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1818

1919
[weakdeps]
2020
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
21-
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2221
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2323
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2424
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2525

2626
[extensions]
2727
AdvancedHMCADTypesExt = "ADTypes"
28-
AdvancedHMCComponentArraysExt = "ComponentArrays"
2928
AdvancedHMCCUDAExt = "CUDA"
29+
AdvancedHMCComponentArraysExt = "ComponentArrays"
3030
AdvancedHMCMCMCChainsExt = "MCMCChains"
3131
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3232

3333
[compat]
3434
ADTypes = "1"
3535
AbstractMCMC = "5.6"
3636
ArgCheck = "1, 2"
37-
ComponentArrays = "0.15"
3837
CUDA = "3, 4, 5"
38+
ComponentArrays = "0.15"
3939
DocStringExtensions = "0.8, 0.9"
4040
LinearAlgebra = "<0.1, 1"
4141
LogDensityProblems = "2"
4242
LogDensityProblemsAD = "1"
43-
MCMCChains = "5, 6"
43+
MCMCChains = "5, 6, 7"
4444
OrdinaryDiffEq = "6"
4545
ProgressMeter = "1"
4646
Random = "<0.1, 1"

src/adaptation/Adaptation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct NaiveHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAd
3737
pc::M
3838
ssa::Tssa
3939
end
40-
function Base.show(io::IO, a::NaiveHMCAdaptor)
40+
function Base.show(io::IO, ::MIME"text/plain", a::NaiveHMCAdaptor)
4141
return print(io, "NaiveHMCAdaptor(pc=$(a.pc), ssa=$(a.ssa))")
4242
end
4343

src/adaptation/massmatrix.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ end
2323

2424
struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
2525

26-
Base.show(io::IO, ::UnitMassMatrix) = print(io, "UnitMassMatrix")
26+
function Base.show(io::IO, mime::MIME"text/plain", ::UnitMassMatrix{T}) where {T}
27+
return print(io, "UnitMassMatrix{$T} adaptor")
28+
end
2729

2830
UnitMassMatrix() = UnitMassMatrix{Float64}()
2931

@@ -91,7 +93,9 @@ mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVec
9193
end
9294
end
9395

94-
Base.show(io::IO, ::WelfordVar) = print(io, "WelfordVar")
96+
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordVar{T}) where {T}
97+
return print(io, "WelfordVar{$T} adaptor")
98+
end
9599

96100
function WelfordVar{T}(
97101
sz::Union{Tuple{Int},Tuple{Int,Int}}; n_min::Int=10, var=ones(T, sz)
@@ -190,7 +194,9 @@ mutable struct WelfordCov{F<:AbstractFloat,C<:AbstractMatrix{F}} <: DenseMatrixE
190194
cov::C
191195
end
192196

193-
Base.show(io::IO, ::WelfordCov) = print(io, "WelfordCov")
197+
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordCov{T}) where {T}
198+
return print(io, "WelfordCov{$T} adaptor")
199+
end
194200

195201
function WelfordCov{T}(
196202
sz::Tuple{Int}; n_min::Int=10, cov=LinearAlgebra.diagm(0 => ones(T, first(sz)))

src/adaptation/stan_adaptor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function initialize!(
4949
return nothing
5050
end
5151

52-
function Base.show(io::IO, state::StanHMCAdaptorState)
52+
function Base.show(io::IO, mime::MIME"text/plain", state::StanHMCAdaptorState)
5353
return print(
5454
io,
5555
"window($(state.window_start), $(state.window_end)), window_splits(" *
@@ -69,7 +69,7 @@ struct StanHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAda
6969
window_size::Int
7070
state::StanHMCAdaptorState
7171
end
72-
function Base.show(io::IO, a::StanHMCAdaptor)
72+
function Base.show(io::IO, mime::MIME"text/plain", a::StanHMCAdaptor)
7373
return print(
7474
io,
7575
"StanHMCAdaptor(\n pc=$(a.pc),\n ssa=$(a.ssa),\n init_buffer=$(a.init_buffer), term_buffer=$(a.term_buffer), window_size=$(a.window_size),\n state=$(a.state)\n)",

src/adaptation/stepsize.jl

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
### Mutable states
2+
"""
3+
$(TYPEDEF)
4+
5+
Dual Averaging state
6+
7+
Mutable state for storing the current iteration of the dual averaging algorithm.
28
9+
# Fields
10+
11+
$(TYPEDFIELDS)
12+
"""
313
mutable struct DAState{T<:AbstractScalarOrVec{<:AbstractFloat}}
14+
"Adaptation iteration"
415
m::Int
516
ϵ::T
17+
"Asymptotic mean of parameter"
618
μ::T
19+
"Moving average parameter"
720
x_bar::T
21+
"Moving average statistic"
822
H_bar::T
923
end
1024

@@ -63,48 +77,66 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
6377
struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
6478
ϵ::T
6579
end
66-
Base.show(io::IO, a::FixedStepSize) = print(io, "FixedStepSize(", a.ϵ, ")")
80+
function Base.show(io::IO, mime::MIME"text/plain", a::FixedStepSize)
81+
return print(io, "FixedStepSize adaptor with step size ", a.ϵ)
82+
end
6783

6884
getϵ(fss::FixedStepSize) = fss.ϵ
6985

7086
struct ManualSSAdaptor{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
7187
state::MSSState{T}
7288
end
73-
Base.show(io::IO, a::ManualSSAdaptor) = print(io, "ManualSSAdaptor()")
89+
function Base.show(io::IO, mime::MIME"text/plain", a::ManualSSAdaptor{T}) where {T}
90+
return print(io, "ManualSSAdaptor{$T} with step size of $(a.state.ϵ)")
91+
end
7492

7593
function ManualSSAdaptor(initϵ::T) where {T<:AbstractScalarOrVec{<:AbstractFloat}}
7694
return ManualSSAdaptor{T}(MSSState(initϵ))
7795
end
7896

7997
"""
98+
$(TYPEDEF)
99+
80100
An implementation of the Nesterov dual averaging algorithm to tune step size.
81101
82-
References
102+
# Fields
103+
104+
$(TYPEDFIELDS)
105+
106+
# References
83107
84108
Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623.
85109
Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems. Mathematical programming, 120(1), 221-259.
86110
"""
87111
struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: StepSizeAdaptor
112+
"Adaption scaling"
88113
γ::T
114+
"Effective starting iteration"
89115
t_0::T
116+
"Adaption shrinkage"
90117
κ::T
118+
"Target value of statistic"
91119
δ::T
92120
state::DAState{S}
93121
end
94-
function Base.show(io::IO, a::NesterovDualAveraging)
122+
function Base.show(io::IO, mime::MIME"text/plain", a::NesterovDualAveraging{T}) where {T}
95123
return print(
96124
io,
97-
"NesterovDualAveraging(γ=",
125+
"NesterovDualAveraging{$T} with\n",
126+
"Scaling γ=",
98127
a.γ,
99-
", t_0=",
128+
"\n",
129+
"Starting iter t_0=",
100130
a.t_0,
101-
", κ=",
131+
"\n",
132+
"Shrinkage κ=",
102133
a.κ,
103-
", δ=",
134+
"\n",
135+
"Target statistic δ=",
104136
a.δ,
105-
", state.ϵ=",
137+
"\n",
138+
"Curret ϵ=",
106139
getϵ(a),
107-
")",
108140
)
109141
end
110142

src/hamiltonian.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@ struct Hamiltonian{M<:AbstractMetric,K<:AbstractKinetic,Tlogπ,T∂logπ∂θ}
44
ℓπ::Tlogπ
55
∂ℓπ∂θ::T∂logπ∂θ
66
end
7-
function Base.show(io::IO, h::Hamiltonian)
8-
return print(io, "Hamiltonian(metric=$(h.metric), kinetic=$(h.kinetic))")
7+
function Base.show(io::IO, mime::MIME"text/plain", h::Hamiltonian)
8+
return print(
9+
io,
10+
"Hamiltonian with ",
11+
nameof(typeof(h.metric)),
12+
" and ",
13+
nameof(typeof(h.kinetic)),
14+
)
915
end
1016

1117
# By default we use Gaussian kinetic energy; also to ensure backward compatibility at the time this was introduced

src/integrator.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
7272
"Step size."
7373
ϵ::T
7474
end
75-
Base.show(io::IO, l::Leapfrog) = print(io, "Leapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)))")
75+
function Base.show(io::IO, mime::MIME"text/plain", l::Leapfrog)
76+
return print(io, "Leapfrog with step size ϵ=$(round.(l.ϵ; sigdigits=3))")
77+
end
7678
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T
7779

7880
### Jittering
@@ -118,10 +120,10 @@ end
118120

119121
JitteredLeapfrog(ϵ0, jitter) = JitteredLeapfrog(ϵ0, jitter, ϵ0)
120122

121-
function Base.show(io::IO, l::JitteredLeapfrog)
123+
function Base.show(io::IO, mime::MIME"text/plain", l::JitteredLeapfrog)
122124
return print(
123125
io,
124-
"JitteredLeapfrog(ϵ0=$(round.(l.ϵ0; sigdigits=3)), jitter=$(round.(l.jitter; sigdigits=3)), ϵ=$(round.(l.ϵ; sigdigits=3)))",
126+
"JitteredLeapfrog with step size $(round.(l.ϵ0; sigdigits=3)), jitter $(round.(l.jitter; sigdigits=3)), jittered step size $(round.(l.ϵ; sigdigits=3))",
125127
)
126128
end
127129

@@ -171,9 +173,10 @@ struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: Abstrac
171173
α::FT
172174
end
173175

174-
function Base.show(io::IO, l::TemperedLeapfrog)
176+
function Base.show(io::IO, mime::MIME"text/plain", l::TemperedLeapfrog)
175177
return print(
176-
io, "TemperedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), α=$(round.(l.α; sigdigits=3)))"
178+
io,
179+
"TemperedLeapfrog with step size ϵ=$(round.(l.ϵ; sigdigits=3)) and temperature parameter α=$(round.(l.α; sigdigits=3))",
177180
)
178181
end
179182

src/metric.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@ renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)
3333
Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
3434
Base.size(e::UnitEuclideanMetric) = e.size
3535
Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]
36-
function Base.show(io::IO, uem::UnitEuclideanMetric)
37-
return print(io, "UnitEuclideanMetric($(_string_M⁻¹(ones(uem.size))))")
36+
function Base.show(io::IO, ::MIME"text/plain", uem::UnitEuclideanMetric{T}) where {T}
37+
return print(
38+
io,
39+
"UnitEuclideanMetric{$T} with size $(size(uem)) mass matrix:\n",
40+
_string_M⁻¹(ones(uem.size)),
41+
)
3842
end
3943

4044
struct DiagEuclideanMetric{T,A<:AbstractVecOrMat{T}} <: AbstractMetric
@@ -58,8 +62,12 @@ renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)
5862

5963
Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
6064
Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)
61-
function Base.show(io::IO, dem::DiagEuclideanMetric)
62-
return print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")
65+
function Base.show(io::IO, ::MIME"text/plain", dem::DiagEuclideanMetric{T}) where {T}
66+
return print(
67+
io,
68+
"DiagEuclideanMetric{$T} with size $(size(dem)) mass matrix:\n",
69+
_string_M⁻¹(dem.M⁻¹),
70+
)
6371
end
6472

6573
struct DenseEuclideanMetric{
@@ -94,8 +102,12 @@ renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
94102

95103
Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
96104
Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
97-
function Base.show(io::IO, dem::DenseEuclideanMetric)
98-
return print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
105+
function Base.show(io::IO, ::MIME"text/plain", dem::DenseEuclideanMetric{T}) where {T}
106+
return print(
107+
io,
108+
"DenseEuclideanMetric{$T} with size $(size(dem)) mass matrix:\n",
109+
_string_M⁻¹(dem.M⁻¹),
110+
)
99111
end
100112

101113
# `rand` functions for `metric` types.

src/trajectory.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
108108
n::Int
109109
end
110110

111-
Base.show(io::IO, s::SliceTS) = print(io, "SliceTS(ℓu=$(s.ℓu), n=$(s.n))")
111+
function Base.show(io::IO, mime::MIME"text/plain", s::SliceTS)
112+
return print(
113+
io,
114+
"SliceTS with slice variable ℓu=$(s.ℓu) and number of acceptable candiadtes n=$(s.n)",
115+
)
116+
end
112117

113118
"""
114119
$(TYPEDEF)
@@ -217,9 +222,10 @@ end
217222

218223
ConstructionBase.constructorof(::Type{<:Trajectory{TS}}) where {TS} = Trajectory{TS}
219224

220-
function Base.show(io::IO, τ::Trajectory{TS}) where {TS}
225+
function Base.show(io::IO, mime::MIME"text/plain", τ::Trajectory{TS}) where {TS}
221226
return print(
222-
io, "Trajectory{$TS}(integrator=$(τ.integrator), tc=$(τ.termination_criterion))"
227+
io,
228+
"Trajectory{$TS} with $(τ.integrator) and termination criterion $(τ.termination_criterion)",
223229
)
224230
end
225231

@@ -468,8 +474,10 @@ struct Termination
468474
numerical::Bool
469475
end
470476

471-
function Base.show(io::IO, d::Termination)
472-
return print(io, "Termination(dynamic=$(d.dynamic), numerical=$(d.numerical))")
477+
function Base.show(io::IO, mime::MIME"text/plain", d::Termination)
478+
return print(
479+
io, "Termination reasons of (dynamic=$(d.dynamic), numerical=$(d.numerical))"
480+
)
473481
end
474482
function Base.:*(d1::Termination, d2::Termination)
475483
return Termination(d1.dynamic || d2.dynamic, d1.numerical || d2.numerical)
@@ -484,12 +492,6 @@ Check termination of a Hamiltonian trajectory.
484492
function Termination(s::SliceTS, nt::Trajectory, H0::F, H′::F) where {F<:AbstractFloat}
485493
return Termination(false, !(s.ℓu < nt.termination_criterion.Δ_max + -H′))
486494
end
487-
488-
"""
489-
$(SIGNATURES)
490-
491-
Check termination of a Hamiltonian trajectory.
492-
"""
493495
function Termination(
494496
s::MultinomialTS, nt::Trajectory, H0::F, H′::F
495497
) where {F<:AbstractFloat}

0 commit comments

Comments
 (0)