Skip to content

Commit 296f654

Browse files
authored
Fix multiple-chain method ambiguity (#2670)
* Fix multiple-chain method ambiguity * Remove dead code in RepeatSampler * Fix test * Update changelog * Add tests for chain save/resume * Changelog * fix missing import
1 parent 0a6a10b commit 296f654

File tree

7 files changed

+64
-35
lines changed

7 files changed

+64
-35
lines changed

HISTORY.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# 0.40.3
2+
3+
This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains.
4+
5+
In the process this also fixes a method ambiguity caused by a bugfix in DynamicPPL 0.37.2.
6+
7+
This patch means that if you are using `RepeatSampler()` to sample from a model, and you want to obtain `MCMCChains.Chains` from it, you need to specify `sample(...; chain_type=MCMCChains.Chains)`.
8+
This only applies if the sampler itself is a `RepeatSampler`; it doesn't apply if you are using `RepeatSampler` _within_ another sampler like Gibbs.
9+
110
# 0.40.2
211

312
`sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.40.2"
3+
version = "0.40.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
6464
DistributionsAD = "0.6"
6565
DocStringExtensions = "0.8, 0.9"
6666
DynamicHMC = "3.4"
67-
DynamicPPL = "0.37"
67+
DynamicPPL = "0.37.2"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
7070
Libtask = "0.9.3"

src/mcmc/abstractmcmc.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,3 @@ function AbstractMCMC.sample(
5757
check_model && _check_model(model, alg)
5858
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
5959
end
60-
61-
function AbstractMCMC.sample(
62-
rng::AbstractRNG,
63-
model::AbstractModel,
64-
sampler::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
65-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
66-
N::Integer,
67-
n_chains::Integer;
68-
chain_type=MCMCChains.Chains,
69-
progress=PROGRESS[],
70-
kwargs...,
71-
)
72-
return AbstractMCMC.mcmcsample(
73-
rng,
74-
model,
75-
sampler,
76-
ensemble,
77-
N,
78-
n_chains;
79-
chain_type=chain_type,
80-
progress=progress,
81-
kwargs...,
82-
)
83-
end

src/mcmc/repeat_sampler.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@ function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int)
2828
return RepeatSampler(Sampler(alg), num_repeat)
2929
end
3030

31-
getADType(spl::RepeatSampler) = getADType(spl.sampler)
32-
DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler)
33-
# TODO(mhauru) Remove the below once DynamicPPL has removed all its Selector stuff.
34-
DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler)
35-
3631
function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params)
3732
return setparams_varinfo!!(model, sampler.sampler, state, params)
3833
end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ Combinatorics = "1"
5353
Distributions = "0.25"
5454
DistributionsAD = "0.6.3"
5555
DynamicHMC = "2.1.6, 3.0"
56-
DynamicPPL = "0.37"
56+
DynamicPPL = "0.37.2"
5757
FiniteDifferences = "0.10.8, 0.11, 0.12"
5858
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
5959
HypothesisTests = "0.11"
6060
LinearAlgebra = "1"
6161
LogDensityProblems = "2"
6262
LogDensityProblemsAD = "1.4"
63-
MCMCChains = "5, 6, 7"
63+
MCMCChains = "7.3.0"
6464
NamedArrays = "0.9.4, 0.10"
6565
Optim = "1"
6666
Optimization = "3, 4"

test/mcmc/Inference.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
55
using Distributions: Bernoulli, Beta, InverseGamma, Normal
66
using Distributions: sample
7+
using AbstractMCMC: AbstractMCMC
78
import DynamicPPL
89
using DynamicPPL: Sampler
910
import ForwardDiff
@@ -72,7 +73,48 @@ using Turing
7273
end
7374
end
7475

75-
@testset "chain save/resume" begin
76+
@testset "save/resume correctly reloads state" begin
77+
struct StaticSampler <: Turing.Inference.InferenceAlgorithm end
78+
function DynamicPPL.initialstep(
79+
rng, model, ::DynamicPPL.Sampler{<:StaticSampler}, vi; kwargs...
80+
)
81+
return Turing.Inference.Transition(model, vi, nothing), vi
82+
end
83+
function AbstractMCMC.step(
84+
rng,
85+
model,
86+
::DynamicPPL.Sampler{<:StaticSampler},
87+
vi::DynamicPPL.AbstractVarInfo;
88+
kwargs...,
89+
)
90+
return Turing.Inference.Transition(model, vi, nothing), vi
91+
end
92+
93+
@model demo() = x ~ Normal()
94+
95+
@testset "single-chain" begin
96+
chn1 = sample(demo(), StaticSampler(), 10; save_state=true)
97+
@test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo
98+
chn2 = sample(demo(), StaticSampler(), 10; resume_from=chn1)
99+
xval = chn1[:x][1]
100+
@test all(chn2[:x] .== xval)
101+
end
102+
103+
@testset "multiple-chain" for nchains in [1, 3]
104+
chn1 = sample(
105+
demo(), StaticSampler(), MCMCThreads(), 10, nchains; save_state=true
106+
)
107+
@test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo}
108+
@test length(chn1.info.samplerstate) == nchains
109+
chn2 = sample(
110+
demo(), StaticSampler(), MCMCThreads(), 10, nchains; resume_from=chn1
111+
)
112+
xval = chn1[:x][1, :]
113+
@test all(i -> chn2[:x][i, :] == xval, 1:10)
114+
end
115+
end
116+
117+
@testset "single-chain save/resume numerical accuracy" begin
76118
alg1 = HMCDA(1000, 0.65, 0.15)
77119
alg2 = PG(20)
78120
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4))

test/mcmc/repeat_sampler.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module RepeatSamplerTests
22

33
using ..Models: gdemo_default
44
using DynamicPPL: Sampler
5+
using MCMCChains: Chains
56
using StableRNGs: StableRNG
67
using Test: @test, @testset
78
using Turing
@@ -26,7 +27,13 @@ using Turing
2627
)
2728
repeat_sampler = RepeatSampler(sampler, num_repeats)
2829
chn2 = sample(
29-
copy(rng), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, num_chains
30+
copy(rng),
31+
gdemo_default,
32+
repeat_sampler,
33+
MCMCThreads(),
34+
num_samples,
35+
num_chains;
36+
chain_type=Chains,
3037
)
3138
@test chn1.value == chn2.value
3239
end

0 commit comments

Comments
 (0)