Skip to content

Add vectorize_over_posterior to pymc.sampling.forward #7841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This submodule contains functions for MCMC and forward sampling.
sample_posterior_predictive
draw
compute_deterministics
vectorize_over_posterior
init_nuts
sampling.jax.sample_blackjax_nuts
sampling.jax.sample_numpyro_nuts
Expand Down
103 changes: 102 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@

import numpy as np
import xarray
import xarray as xr

from arviz import InferenceData
from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import (
Apply,
Constant,
Expand All @@ -42,7 +44,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
Expand All @@ -52,6 +54,8 @@
from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list
from pymc.backends.base import MultiTrace
from pymc.blocking import PointType
from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.utils import rvs_in_graph
from pymc.model import Model, modelcontext
from pymc.pytensorf import compile
from pymc.util import (
Expand All @@ -68,6 +72,7 @@
"draw",
"sample_posterior_predictive",
"sample_prior_predictive",
"vectorize_over_posterior",
)

ArrayLike: TypeAlias = np.ndarray | list[float]
Expand Down Expand Up @@ -984,3 +989,99 @@ def sample_posterior_predictive(
idata.extend(idata_pp)
return idata
return idata_pp


def vectorize_over_posterior(
    outputs: list[Variable],
    posterior: xr.Dataset,
    input_rvs: list[Variable],
    allow_rvs_in_graph: bool = True,
    sample_dims: tuple[str, ...] = ("chain", "draw"),
) -> list[Variable]:
    """Vectorize outputs over posterior samples of subset of input rvs.

    This function creates a new graph for the supplied outputs, where the required
    subset of input rvs are replaced by their posterior samples. The sample
    dimensions are *not* flattened: the posterior samples are inserted with
    ``sample_dims`` as their leading dimensions, and the remaining random variables
    are resized with the same batch shape. The other input tensors are kept as is.

    Parameters
    ----------
    outputs : list[Variable]
        The list of variables to vectorize over the posterior samples.
    posterior : xr.Dataset
        The posterior samples to use as replacements for the `input_rvs`. Each
        needed random variable is looked up by its ``name``.
    input_rvs : list[Variable]
        The list of random variables to replace with their posterior samples.
        Only those that are actually needed to compute `outputs` are replaced.
    allow_rvs_in_graph : bool
        Whether to allow random variables to be present in the graph. If False,
        an error will be raised if any random variables are found in the graph. If
        True, the remaining random variables will be resized to match the number of
        draws from the posterior.
    sample_dims : tuple[str, ...]
        The dimensions of the posterior samples to use for vectorizing the `input_rvs`.
        The posterior arrays are transposed so that these dimensions come first, in
        the given order, before being placed in the graph.

    Returns
    -------
    vectorized_outputs : list[Variable]
        The vectorized variables, named after the corresponding entries of `outputs`.

    Raises
    ------
    RuntimeError
        If random variables are found in the graph and `allow_rvs_in_graph` is False
    """
    # Identify which free RVs are needed to compute `outputs`.
    # The membership set is built once instead of once per visited ancestor.
    input_rv_set = set(input_rvs)
    needed_rvs: list[TensorVariable] = [
        cast(TensorVariable, rv)
        for rv in ancestors(outputs, blockers=input_rvs)
        if rv in input_rv_set
    ]

    # Replace needed_rvs with actual posterior samples
    batch_shape = tuple(len(posterior.coords[dim]) for dim in sample_dims)
    replace_dict: dict[Variable, Variable] = {}
    for rv in needed_rvs:
        # Put the sample dimensions first so that the constant's leading axes line up
        # with `batch_shape`. This is a no-op for the usual (chain, draw, ...) layout
        # but matters when the sample dims are stored elsewhere (e.g. after stacking).
        posterior_samples = posterior[rv.name].transpose(*sample_dims, ...).data

        replace_dict[rv] = pt.constant(posterior_samples.astype(rv.dtype), name=rv.name)

    # Replace the rvs that remain in the graph with resized versions
    all_rvs = rvs_in_graph(outputs)

    # Once we give values for the needed_rvs (setting them to their posterior samples),
    # we need to identify the rvs that only depend on these conditioned values, and
    # don't depend on any other rvs or output nodes.
    # These variables need to be resized because they won't be resized implicitly by
    # the replacement of the needed_rvs or other random variables in the graph when we
    # later call vectorize_graph.
    toposorted_rvs = [
        rv
        for rv in general_toposort(  # type: ignore[call-overload]
            all_rvs, lambda x: x.owner.inputs if x.owner is not None else None
        )
        if rv in all_rvs
    ]
    needed_rv_set = set(needed_rvs)
    independent_rvs: list[TensorVariable] = []
    for rv in toposorted_rvs:
        if rv in needed_rv_set:
            continue
        # `ancestors` yields lazily; materialize it once because it is inspected
        # twice below (a second pass over the bare iterator would see nothing).
        rv_ancestors = set(ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs]))
        depends_on_output_or_independent = bool({*outputs, *independent_rvs} & rv_ancestors)
        upstream_rvs = {var for var in rv_ancestors if var in all_rvs}
        if not depends_on_output_or_independent and upstream_rvs <= {rv, *needed_rvs}:
            independent_rvs.append(rv)
    for rv in independent_rvs:
        replace_dict[rv] = change_dist_size(rv, new_size=batch_shape, expand=True)

    # Vectorize across samples
    vectorized_outputs = list(vectorize_graph(outputs, replace=replace_dict))
    for vectorized_output, output in zip(vectorized_outputs, outputs):
        vectorized_output.name = output.name

    if not allow_rvs_in_graph:
        remaining_rvs = rvs_in_graph(vectorized_outputs)
        if remaining_rvs:
            raise RuntimeError(
                f"The following random variables found in the extracted graph: {remaining_rvs}"
            )
    return vectorized_outputs
158 changes: 158 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@
from pytensor import Mode, shared
from pytensor.compile import SharedVariable
from pytensor.graph import graph_inputs
from pytensor.graph.basic import get_var_by_name, variable_depends_on
from pytensor.tensor.variable import TensorConstant
from scipy import stats

import pymc as pm

from pymc.backends.base import MultiTrace
from pymc.logprob.utils import rvs_in_graph
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import compile
from pymc.sampling.forward import (
compile_forward_sampling_function,
get_constant_coords,
get_vars_in_point_list,
observed_dependent_deterministics,
vectorize_over_posterior,
)
from pymc.testing import fast_unstable_sampling_mode

Expand Down Expand Up @@ -1801,3 +1806,156 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
match = "The samples argument has been deprecated"
with pytest.warns(DeprecationWarning, match=match):
pm.sample_prior_predictive(model=m, samples=10)


@pytest.fixture(params=["deterministic", "observed", "conditioned_on_observed"])
def variable_to_vectorize(request):
    """Return the names of the model variables to vectorize for each scenario."""
    names_by_scenario = {
        "deterministic": ["y"],
        "conditioned_on_observed": ["z", "z_downstream"],
    }
    # Any other scenario (i.e. "observed") vectorizes the observed variable alone.
    return names_by_scenario.get(request.param, ["z"])


@pytest.fixture(params=["allow_rvs_in_graph", "disallow_rvs_in_graph"])
def allow_rvs_in_graph(request):
    """Return True for the "allow_rvs_in_graph" parametrization, False otherwise."""
    return request.param == "allow_rvs_in_graph"


@pytest.fixture(scope="module", params=["nested_random_variables", "no_nested_random_variables"])
def has_nested_random_variables(request):
    """Return whether the test model should give ``x`` a random-variable parent."""
    scenario = request.param
    return scenario == "nested_random_variables"


@pytest.fixture(scope="module")
def model_to_vectorize(has_nested_random_variables):
    """Build the model under test and an InferenceData with a "posterior" group.

    The "posterior" group is simply the prior draws re-registered under that name,
    which is enough to exercise vectorization without running a sampler.
    Returns the model after ``freeze_dims_and_data`` together with the idata.
    """
    with pm.Model() as model:
        # With nesting, `x` has a random-variable parent; otherwise a constant mean.
        x_parent = pm.Normal("x_parent") if has_nested_random_variables else 0.0
        x = pm.Normal("x", mu=x_parent)
        d = pm.Data("d", np.array([1, 2, 3]))
        obs = pm.Data("obs", np.ones_like(d.get_value()))
        y = pm.Deterministic("y", x * d)
        z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs)
        pm.Deterministic("z_downstream", z * 2)

    with model:
        idata = pm.sample_prior_predictive(100)
    idata.add_groups({"posterior": idata.prior})

    frozen_model = freeze_dims_and_data(model)
    return frozen_model, idata


@pytest.fixture(params=["rv_from_posterior", "resample_rv"])
def input_rv_names(request, has_nested_random_variables):
    """Return the names of the rvs to take from the posterior (empty means resample)."""
    if request.param != "rv_from_posterior":
        return []
    return ["x_parent", "x"] if has_nested_random_variables else ["x"]


def test_vectorize_over_posterior(
    variable_to_vectorize,
    input_rv_names,
    allow_rvs_in_graph,
    model_to_vectorize,
):
    """Check the structure of the graph produced by ``vectorize_over_posterior``.

    When random variables are disallowed but must remain in the graph, the call has
    to raise. Otherwise the outputs must be new variables with the batch shape
    prepended, must still depend on the data ``d``, and the input rvs must have been
    replaced by posterior constants (or, with no input rvs, resized).
    """
    model, idata = model_to_vectorize

    def run_vectorize():
        return vectorize_over_posterior(
            outputs=[model[name] for name in variable_to_vectorize],
            posterior=idata.posterior,
            input_rvs=[model[name] for name in input_rv_names],
            allow_rvs_in_graph=allow_rvs_in_graph,
        )

    rvs_must_remain = len(input_rv_names) == 0 or "z" in variable_to_vectorize
    if not allow_rvs_in_graph and rvs_must_remain:
        with pytest.raises(
            RuntimeError,
            match="The following random variables found in the extracted graph",
        ):
            run_vectorize()
        return

    vectorized = run_vectorize()
    for vectorized_var, name in zip(vectorized, variable_to_vectorize):
        assert vectorized_var is not model[name]
    for vectorized_var in vectorized:
        assert vectorized_var.type.shape == (1, 100, 3)
    for vectorized_var in vectorized:
        assert variable_depends_on(vectorized_var, model["d"])

    if len(vectorized) == 2:
        downstream = vectorized[variable_to_vectorize.index("z_downstream")]
        upstream = vectorized[variable_to_vectorize.index("z")]
        assert variable_depends_on(downstream, upstream)

    if len(input_rv_names) == 0:
        # Nothing was taken from the posterior: every rv must have been resized.
        batch_shape = (
            len(idata.posterior.coords["chain"]),
            len(idata.posterior.coords["draw"]),
        )
        original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
        expected_rv_shapes = {(*batch_shape, *rv.type.shape) for rv in original_rvs}
        resized_rvs = rvs_in_graph(vectorized)
        assert {rv.type.shape for rv in resized_rvs} == expected_rv_shapes
        return

    for input_rv_name in input_rv_names:
        found = get_var_by_name(vectorized, input_rv_name)
        if input_rv_name == "x_parent":
            # `x` is replaced by a constant, so its parent drops out of the graph.
            assert len(found) == 0
            continue
        [vectorized_rv] = found
        rv_posterior = idata.posterior[input_rv_name].data
        assert isinstance(vectorized_rv, TensorConstant)
        assert np.all(vectorized_rv.value == rv_posterior)


def test_vectorize_over_posterior_matches_sample():
    """Compare the vectorized graph against ``sample_posterior_predictive`` draws.

    The posterior of ``x`` is 0 for the first chain and 100 for the second, so with
    ``sigma = 0.1`` every draw of ``obs`` must stay within 1 of its chain's value.
    """
    rng = np.random.default_rng(1234)
    with pm.Model() as model:
        x = pm.Normal("x")
        sigma = 0.1
        obs = pm.Normal("obs", x, sigma, observed=rng.normal(size=10))
        det = pm.Deterministic("det", obs + 1)

    chains = 2
    draws = 100
    chain_offsets = 100 * np.arange(chains)[..., None]
    x_posterior = np.broadcast_to(chain_offsets, (chains, draws))
    with model:
        x_samples = xr.DataArray(
            x_posterior,
            dims=("chain", "draw"),
            coords={"chain": np.arange(chains), "draw": np.arange(draws)},
        )
        posterior = xr.Dataset({"x": x_samples})
        idata = InferenceData(posterior=posterior)
    with model:
        pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234)

    vectorized = vectorize_over_posterior(
        outputs=[obs, det],
        posterior=posterior,
        input_rvs=[x],
        allow_rvs_in_graph=True,
    )
    draw_fn = compile(inputs=[], outputs=vectorized, random_seed=1234)
    [vect_obs, vect_det] = draw_fn()

    pp_obs = pp.posterior_predictive["obs"]
    pp_det = pp.posterior_predictive["det"]
    assert pp_obs.shape == vect_obs.shape
    assert pp_det.shape == vect_det.shape
    np.testing.assert_allclose(vect_obs + 1, vect_det)
    np.testing.assert_allclose(
        pp_obs.mean(dim=("chain", "draw")),
        vect_obs.mean(axis=(0, 1)),
        atol=0.6 / np.sqrt(10000),
    )
    assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
Loading