All VarInfos are informative, but some VarInfos are more informative than others #2642

@penelopeysm

While editing the sampler code in Turing, there's a common pattern where:

  • we have a model + varinfo and construct a LogDensityFunction from it
  • the sampler takes the LDF, does things, and gives us back a vector of parameters + a single logp

After that, we take the returned parameters and construct a transition and a state from them. The two are handled slightly differently:

  • Transition: we stick the vector of parameters back in the varinfo and then re-evaluate the model (in the Turing.Inference.Transition constructor)
  • State: we stick the vector of parameters back in the varinfo and pass it on to the next step in MCMC.

In the transition case, it's clear that the varinfo does not need to have the correct logp because it will be re-evaluated anyway.

In the state case, it depends on whether the varinfo's logp is read on the next iteration. Some samplers do read it (MH); others don't (HMC).

Finally, there are some cases where we construct an LDF from a varinfo. In such cases, the parameters and logp in the varinfo are completely useless; the varinfo is purely there for structure.

This leaves us passing varinfos around where some have the wrong logp and some have the wrong values, and it's quite difficult to reason about when a varinfo really needs to have the right values and logp.

I would love it if we could make this clearer with some new types and/or wrappers. For example, consider the case where you do care about the parameters but not logp. Instead of carrying around a VarInfo with the correct parameters set inside it, we could carry a struct that contains both a VarInfo (whose parameters may be wrong) and the right parameters, and then call get_vi_with_correct_params() whenever we need it:

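using DynamicPPL: DynamicPPL, AbstractVarInfo

struct VIWithParams
    vi::AbstractVarInfo      # structure only; its stored values and logp may be stale
    params::AbstractVector   # the parameter values that are actually correct
end

# Rebuild a varinfo whose values match `params`
get_vi_with_correct_params(vwp::VIWithParams) = DynamicPPL.unflatten(vwp.vi, vwp.params)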

In this case vi only exists to provide structure, and this can be documented nicely.
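To make the "structure only" idea concrete, here is a minimal, package-free sketch. `ToyVarInfo`, `toy_unflatten`, and `ToyVIWithParams` are hypothetical stand-ins for `AbstractVarInfo`, `DynamicPPL.unflatten`, and the proposed `VIWithParams`; they are not DynamicPPL API, and the way the toy handles logp is an assumption made purely for illustration.

```julia
# Hypothetical stand-in for a varinfo: stored values plus a cached logp.
struct ToyVarInfo
    values::Vector{Float64}
    logp::Float64
end

# Hypothetical stand-in for DynamicPPL.unflatten: same structure, new values.
# In this toy the cached logp is copied over unchanged, so it is stale afterwards.
toy_unflatten(vi::ToyVarInfo, x::AbstractVector) = ToyVarInfo(collect(Float64, x), vi.logp)

# Toy version of the proposed wrapper.
struct ToyVIWithParams
    vi::ToyVarInfo            # structure only; values and logp are not trusted
    params::Vector{Float64}   # the parameters that are actually correct
end

get_vi_with_correct_params(vwp::ToyVIWithParams) = toy_unflatten(vwp.vi, vwp.params)

# The template varinfo holds placeholder values; the wrapper carries the real ones.
template = ToyVarInfo([0.0, 0.0], -1.0)
vwp = ToyVIWithParams(template, [0.3, -1.2])

# Only materialise a varinfo with the right values at the point where one is needed.
vi = get_vi_with_correct_params(vwp)
```

The template is never mutated, so nothing downstream can mistake its placeholder values for real ones; the only way to get a usable varinfo is through the accessor.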

Sometimes, though, the only thing the next step does is re-extract the params with vi[:]. In such a case, we could simply use vwp.params instead of going via unflatten + vi[:], which could potentially cause issues (TuringLang/DynamicPPL.jl#1001).
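The shortcut can be sketched with toy types as well. Everything here is hypothetical: `ToyVarInfo` and `ToyVIWithParams` stand in for the real varinfo and the proposed wrapper, and `toy_step` is an invented "next step" that only needs the parameter vector.

```julia
# Hypothetical stand-ins; not DynamicPPL API.
struct ToyVarInfo
    values::Vector{Float64}
    logp::Float64
end

struct ToyVIWithParams
    vi::ToyVarInfo            # structure only
    params::Vector{Float64}   # authoritative parameter values
end

# Round trip: build a new varinfo from the params, then flatten it straight back out.
# This mirrors the unflatten + vi[:] path.
params_via_roundtrip(vwp::ToyVIWithParams) =
    ToyVarInfo(copy(vwp.params), vwp.vi.logp).values

# Shortcut: the vector is already stored in the wrapper.
params_direct(vwp::ToyVIWithParams) = vwp.params

# Hypothetical next step that only needs the parameter vector.
function toy_step(vwp::ToyVIWithParams, delta::Vector{Float64})
    new_params = params_direct(vwp) .+ delta
    return ToyVIWithParams(vwp.vi, new_params)   # structure is reused as-is
end

vwp = ToyVIWithParams(ToyVarInfo([0.0, 0.0], -1.0), [0.5, 1.5])
next_vwp = toy_step(vwp, [1.0, -1.0])
```

In the toy both paths return the same numbers, so the direct path only saves work. With a real varinfo, skipping the round trip would also avoid whatever the unflatten + vi[:] path does to the values along the way.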
