Description
While editing the sampler code in Turing, there's a common pattern where:
- we have a model + varinfo and construct a LogDensityFunction (LDF) from it
- the sampler takes the LDF, does things, and gives us back a vector of parameters + a single logp
After that we take the returned parameters and construct a transition and a state with it. These two things are slightly different:
- Transition: we stick the vector of parameters back in the varinfo and then re-evaluate the model (in the `Turing.Inference.Transition` constructor)
- State: we stick the vector of parameters back in the varinfo and pass it on to the next step in MCMC.
In the transition case, it's clear that the varinfo does not need to have the correct logp because it will be re-evaluated anyway.
In the state case, it depends on whether varinfo's logp is accessed on the next iteration. Some samplers do (MH). Some don't (HMC).
Finally, there are some cases where we construct an LDF from a varinfo. In such cases, the parameters and logp in the varinfo are completely useless; the varinfo is purely there for structure.
This creates a case where we're passing varinfos around, some have the wrong logp, and some have the wrong values, and it's quite difficult to reason about when a varinfo really needs to have the right values and logp.
I would love it if we could make this clearer with some new types and/or wrappers. For example, consider the case where you do care about the parameters but not logp. Instead of carrying around a VarInfo with the correct parameters set inside it, we should carry a struct that contains both a VarInfo (with wrong parameters) and the right parameters, and then call `get_vi_with_correct_params()` whenever we need it:
struct VIWithParams
    vi::AbstractVarInfo
    params::AbstractVector
end
get_vi_with_correct_params(vwp::VIWithParams) = DynamicPPL.unflatten(vwp.vi, vwp.params)
In this case `vi` only exists to provide structure, and this can be documented nicely.
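To make the idea concrete, here is a minimal, self-contained sketch that runs without DynamicPPL. `ToyVarInfo` and `toy_unflatten` are hypothetical stand-ins invented for this example (they are not DynamicPPL types or functions), and the type parameters on `VIWithParams` are one possible way to keep the fields concretely typed, not part of the proposal above. Setting `logp` to `NaN` after unflattening is a choice made in this toy to make stale use visible; it is not a claim about what `DynamicPPL.unflatten` does.

```julia
# Toy stand-in for a varinfo (hypothetical, NOT DynamicPPL's type).
struct ToyVarInfo
    names::Vector{Symbol}    # structure: which variables exist, in order
    values::Vector{Float64}  # flattened parameter values (possibly stale)
    logp::Float64            # accumulated log density (possibly stale)
end

# Hypothetical analogue of `DynamicPPL.unflatten`: same structure, new values.
# This toy marks logp as NaN so that accidental use of a stale logp is obvious.
function toy_unflatten(vi::ToyVarInfo, params::AbstractVector{<:Real})
    length(params) == length(vi.values) ||
        throw(DimensionMismatch("expected $(length(vi.values)) parameters"))
    return ToyVarInfo(vi.names, collect(Float64, params), NaN)
end

# The wrapper: `vi` supplies structure only; `params` is the source of truth.
struct VIWithParams{V,P<:AbstractVector{<:Real}}
    vi::V
    params::P
end

get_vi_with_correct_params(vwp::VIWithParams) = toy_unflatten(vwp.vi, vwp.params)

# Usage: the template's values and logp are never trusted.
template = ToyVarInfo([:mu, :sigma], [0.0, 1.0], -3.2)
vwp = VIWithParams(template, [0.5, 2.0])
fresh = get_vi_with_correct_params(vwp)
```

The point of the sketch is that the type itself documents which part is trustworthy: anything read from `vwp.vi` directly is structure only, and the correct values are only obtained through the accessor.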
But sometimes the only thing the next step does is re-extract the params with `vi[:]`. In such a case, we could simply use `vwp.params` instead of going via `unflatten` + `vi[:]`, which could potentially cause issues (TuringLang/DynamicPPL.jl#1001).
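The shortcut can be sketched as follows, again with hypothetical stand-ins so the snippet runs on its own: `MiniVarInfo`, `mini_unflatten`, and `mini_flatten` play the roles of a varinfo, `DynamicPPL.unflatten`, and `vi[:]` respectively, and `get_params` is an assumed accessor name that does not appear in the proposal.

```julia
# Minimal toy varinfo (hypothetical): holds only flattened values.
struct MiniVarInfo
    values::Vector{Float64}
end

# Stand-ins for `DynamicPPL.unflatten(vi, x)` and `vi[:]`.
mini_unflatten(vi::MiniVarInfo, params::AbstractVector{<:Real}) =
    MiniVarInfo(collect(Float64, params))
mini_flatten(vi::MiniVarInfo) = copy(vi.values)

struct VIWithParams{V,P<:AbstractVector{<:Real}}
    vi::V
    params::P
end

# Direct access: no varinfo is rebuilt and no copy is made.
get_params(vwp::VIWithParams) = vwp.params

# The round trip that the direct accessor avoids.
round_trip_params(vwp::VIWithParams) = mini_flatten(mini_unflatten(vwp.vi, vwp.params))

vwp = VIWithParams(MiniVarInfo([0.0, 0.0]), [1.5, -0.25])
direct = get_params(vwp)
via_vi = round_trip_params(vwp)
```

In this toy both paths yield equal values, but the direct path returns the stored vector itself, while the round trip allocates a new varinfo and a new vector and relies on the unflatten and flatten steps being exact inverses.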