Skip to content

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Sep 8, 2025

julia> using MCMCChains

julia> chn = Chains(
           rand(10, 3, 1),
           ["a", "b", "c"],
                   info = (
               start_time = 1,
               stop_time = 2,
               samplerstate = "state1",
               otherinfo = "info1",
           ),
       )

julia> # Current version
       chainscat(chn).info
(otherinfo = "info1", start_time = 1, stop_time = 2, samplerstate = "state1")

julia> # This PR
       chainscat(chn).info
(otherinfo = "info1", start_time = [1], stop_time = [2], samplerstate = ["state1"])

Firstly, this is consistent with standard Julia concatenation behaviour which creates singleton arrays:

julia> vcat(1)
1-element Vector{Int64}:
 1

More importantly, this is needed to differentiate the output of sample(model, spl, N) from sample(model, spl, MCMCThreads(), N, 1) so that the save/resume interface in Turing.jl is coherent.

With this PR + DynamicPPL 0.37.2 + TuringLang/Turing.jl#2670 the following all works correctly now:

using Turing, Test

@model f() = x ~ Normal()

chn1 = sample(f(), MH(), 10; save_state=true)
@test chn1.info.samplerstate isa Turing.Inference.MHState
chn2 = sample(f(), MH(), 10; resume_from=chn1)
chn3 = sample(f(), MH(), 10; initial_state=chn1.info.samplerstate)

chn1 = sample(f(), MH(), MCMCSerial(), 10, 1; save_state=true)
@test chn1.info.samplerstate isa AbstractVector{<:Turing.Inference.MHState} && length(chn1.info.samplerstate) == 1
chn2 = sample(f(), MH(), MCMCSerial(), 10, 1; resume_from=chn1)
chn3 = sample(f(), MH(), MCMCSerial(), 10, 1; initial_state=chn1.info.samplerstate)

chn1 = sample(f(), MH(), MCMCSerial(), 10, 3; save_state=true)
@test chn1.info.samplerstate isa AbstractVector{<:Turing.Inference.MHState} && length(chn1.info.samplerstate) == 3
chn2 = sample(f(), MH(), MCMCSerial(), 10, 3; resume_from=chn1)
chn3 = sample(f(), MH(), MCMCSerial(), 10, 3; initial_state=chn1.info.samplerstate)

Copy link
Contributor

github-actions bot commented Sep 8, 2025

MCMCChains.jl documentation for PR #492 is available at:
https://TuringLang.github.io/MCMCChains.jl/previews/PR492/

Copy link

codecov bot commented Sep 8, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 86.13%. Comparing base (1a03a99) to head (ba56ff7).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #492      +/-   ##
==========================================
+ Coverage   86.12%   86.13%   +0.01%     
==========================================
  Files          20       20              
  Lines        1146     1147       +1     
==========================================
+ Hits          987      988       +1     
  Misses        159      159              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link

@joelkandiah joelkandiah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am happy that this works as intended and I have run some additional checks for the single chain cases locally that would be good to have tested in Turing.jl as discussed with @penelopeysm .

Any further feedback would just be to make sure that we are happy with the defaults in AbstractMCMC to return the object c if chainsstack(c) is called where c is not a vector going forward....

@penelopeysm penelopeysm merged commit 6909f74 into main Sep 8, 2025
10 checks passed
@penelopeysm penelopeysm deleted the py/vector-in-cat3 branch September 8, 2025 12:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants