# pymc/sampling/forward.py
#
# Imports used by `vectorize_over_posterior` that are visible in this excerpt; the
# rest of the module's imports lie outside it. The function is also exported through
# the module's ``__all__`` and listed in docs/source/api/samplers.rst.
import xarray as xr

from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor.variable import TensorConstant, TensorVariable

from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.utils import rvs_in_graph


def vectorize_over_posterior(
    outputs: list[Variable],
    posterior: xr.Dataset,
    input_rvs: list[Variable],
    allow_rvs_in_graph: bool = True,
    sample_dims: tuple[str, ...] = ("chain", "draw"),
) -> list[Variable]:
    """Vectorize outputs over posterior samples of a subset of input rvs.

    This function creates a new graph for the supplied outputs, where the required
    subset of input rvs are replaced by constants holding their posterior samples.
    The sample dimensions are *not* flattened: they are kept, in the order given by
    `sample_dims`, as the leading axes of every replacement, so outputs that depend
    on a replaced or resized variable gain those leading batch dimensions. The other
    input tensors are kept as is.

    Parameters
    ----------
    outputs : list[Variable]
        The list of variables to vectorize over the posterior samples.
    posterior : xr.Dataset
        The posterior samples to use as replacements for the `input_rvs`. It must
        contain an entry named after every input rv that `outputs` depend on.
    input_rvs : list[Variable]
        The list of random variables to replace with their posterior samples. Only
        those that are ancestors of `outputs` are actually replaced.
    allow_rvs_in_graph : bool
        Whether to allow random variables to be present in the graph. If False,
        an error will be raised if any random variables are found in the vectorized
        graph. If True, the remaining random variables will be resized to match the
        number of draws from the posterior.
    sample_dims : tuple[str, ...]
        The dimensions of the posterior samples to use for vectorizing the
        `input_rvs`. Their sizes define the batch shape, and each posterior array is
        transposed so that these dimensions come first before it is substituted.

    Returns
    -------
    vectorized_outputs : list[Variable]
        The vectorized variables, in the same order and with the same names as
        `outputs`.

    Raises
    ------
    RuntimeError
        If random variables are found in the vectorized graph and
        `allow_rvs_in_graph` is False
    """
    # Identify which of the `input_rvs` are needed to compute `outputs`. The lookup
    # set is built once instead of once per visited ancestor.
    input_rv_set = set(input_rvs)
    needed_rvs: list[TensorVariable] = [
        cast(TensorVariable, rv)
        for rv in ancestors(outputs, blockers=input_rvs)
        if rv in input_rv_set
    ]

    # Replace needed_rvs with actual posterior samples.
    # `sizes` is keyed by dimension name, so it also works for sample dimensions that
    # carry no coordinate values.
    batch_shape = tuple(posterior.sizes[dim] for dim in sample_dims)
    replace_dict: dict[Variable, Variable] = {}
    for rv in needed_rvs:
        # Move the sample dims to the front, in `sample_dims` order, so the leading
        # axes of every replacement line up with `batch_shape`. This is a no-op for
        # arrays that are already laid out that way; dims an array does not have are
        # skipped rather than rejected, as no check existed before.
        posterior_samples = (
            posterior[rv.name].transpose(*sample_dims, ..., missing_dims="ignore").data
        )
        replace_dict[rv] = pt.constant(posterior_samples.astype(rv.dtype), name=rv.name)

    # Replace the rvs that remain in the graph with resized versions
    all_rvs = rvs_in_graph(outputs)

    # Once we give values for the needed_rvs (setting them to their posterior samples),
    # we need to identify the rvs that only depend on these conditioned values, and
    # don't depend on any other rvs or output nodes.
    # These variables need to be resized because they won't be resized implicitly by
    # the replacement of the needed_rvs or other random variables in the graph when we
    # later call vectorize_graph.
    toposorted_rvs = [
        var
        for var in general_toposort(  # type: ignore[call-overload]
            all_rvs, lambda x: x.owner.inputs if x.owner is not None else None
        )
        if var in all_rvs
    ]
    independent_rvs: list[TensorVariable] = []
    for rv in toposorted_rvs:
        if rv in needed_rvs:
            continue
        # `ancestors` may hand back a one-shot iterator; materialise it once so that
        # both checks below inspect the complete ancestor set.
        rv_ancestors = set(ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs]))
        touches_handled_vars = bool({*outputs, *independent_rvs} & rv_ancestors)
        upstream_rvs = {var for var in rv_ancestors if var in all_rvs}
        if not touches_handled_vars and upstream_rvs <= {rv, *needed_rvs}:
            independent_rvs.append(rv)
    for rv in independent_rvs:
        replace_dict[rv] = change_dist_size(rv, new_size=batch_shape, expand=True)

    # Vectorize across samples
    vectorized_outputs = list(vectorize_graph(outputs, replace=replace_dict))
    for vectorized_output, output in zip(vectorized_outputs, outputs):
        vectorized_output.name = output.name

    if not allow_rvs_in_graph:
        remaining_rvs = rvs_in_graph(vectorized_outputs)
        if remaining_rvs:
            raise RuntimeError(
                f"The following random variables found in the extracted graph: {remaining_rvs}"
            )
    return vectorized_outputs
# tests/sampling/test_forward.py
#
# Imports used by the tests below that are visible in this excerpt; the test
# module's other imports lie outside it.
from pytensor.graph.basic import get_var_by_name, variable_depends_on
from pytensor.tensor.variable import TensorConstant

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import compile
from pymc.sampling.forward import vectorize_over_posterior


@pytest.fixture(params=["deterministic", "observed", "conditioned_on_observed"])
def variable_to_vectorize(request):
    """Names of the model variables that a given parametrization vectorizes."""
    names_by_case = {
        "deterministic": ["y"],
        "conditioned_on_observed": ["z", "z_downstream"],
    }
    # Any other case (i.e. "observed") targets the observed variable on its own.
    return names_by_case.get(request.param, ["z"])


@pytest.fixture(params=["allow_rvs_in_graph", "disallow_rvs_in_graph"])
def allow_rvs_in_graph(request):
    """Boolean flag forwarded to `vectorize_over_posterior`."""
    return request.param == "allow_rvs_in_graph"


@pytest.fixture(scope="module", params=["nested_random_variables", "no_nested_random_variables"])
def has_nested_random_variables(request):
    """Whether the model's `x` gets a random variable as its mean."""
    return request.param == "nested_random_variables"


@pytest.fixture(scope="module")
def model_to_vectorize(has_nested_random_variables):
    """Build the model under test and an InferenceData whose posterior is its prior."""
    with pm.Model() as model:
        x_parent = pm.Normal("x_parent") if has_nested_random_variables else 0.0
        x = pm.Normal("x", mu=x_parent)
        d = pm.Data("d", np.array([1, 2, 3]))
        obs = pm.Data("obs", np.ones_like(d.get_value()))
        y = pm.Deterministic("y", x * d)
        z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs)
        pm.Deterministic("z_downstream", z * 2)

    with model:
        idata = pm.sample_prior_predictive(100)
    # Reuse the prior draws as a stand-in posterior group.
    idata.add_groups({"posterior": idata.prior})
    return freeze_dims_and_data(model), idata


@pytest.fixture(params=["rv_from_posterior", "resample_rv"])
def input_rv_names(request, has_nested_random_variables):
    """Names of the rvs to take from the posterior; empty means resample everything."""
    if request.param != "rv_from_posterior":
        return []
    return ["x_parent", "x"] if has_nested_random_variables else ["x"]


def test_vectorize_over_posterior(
    variable_to_vectorize,
    input_rv_names,
    allow_rvs_in_graph,
    model_to_vectorize,
):
    model, idata = model_to_vectorize
    call_kwargs = {
        "outputs": [model[name] for name in variable_to_vectorize],
        "posterior": idata.posterior,
        "input_rvs": [model[name] for name in input_rv_names],
        "allow_rvs_in_graph": allow_rvs_in_graph,
    }

    # Random variables stay in the graph when nothing is conditioned on, or when the
    # observed rv `z` is itself one of the requested outputs.
    rvs_remain = not input_rv_names or "z" in variable_to_vectorize
    if not allow_rvs_in_graph and rvs_remain:
        with pytest.raises(
            RuntimeError,
            match="The following random variables found in the extracted graph",
        ):
            vectorize_over_posterior(**call_kwargs)
        return

    vectorized = vectorize_over_posterior(**call_kwargs)

    for vectorized_var, name in zip(vectorized, variable_to_vectorize):
        assert vectorized_var is not model[name]
    for vectorized_var in vectorized:
        assert vectorized_var.type.shape == (1, 100, 3)
        assert variable_depends_on(vectorized_var, model["d"])

    if len(vectorized) == 2:
        downstream = vectorized[variable_to_vectorize.index("z_downstream")]
        upstream = vectorized[variable_to_vectorize.index("z")]
        assert variable_depends_on(downstream, upstream)

    if input_rv_names:
        for input_rv_name in input_rv_names:
            matches = get_var_by_name(vectorized, input_rv_name)
            if input_rv_name == "x_parent":
                assert len(matches) == 0
                continue
            [vectorized_rv] = matches
            rv_posterior = idata.posterior[input_rv_name].data
            assert isinstance(vectorized_rv, TensorConstant)
            assert np.all(vectorized_rv.value == rv_posterior)
    else:
        batch_shape = (
            len(idata.posterior.coords["chain"]),
            len(idata.posterior.coords["draw"]),
        )
        original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
        expected_rv_shapes = {(*batch_shape, *rv.type.shape) for rv in original_rvs}
        resized_shapes = {rv.type.shape for rv in rvs_in_graph(vectorized)}
        assert resized_shapes == expected_rv_shapes


def test_vectorize_over_posterior_matches_sample():
    rng = np.random.default_rng(1234)
    with pm.Model() as model:
        x = pm.Normal("x")
        sigma = 0.1
        obs = pm.Normal("obs", x, sigma, observed=rng.normal(size=10))
        det = pm.Deterministic("det", obs + 1)

    chains = 2
    draws = 100
    # Each chain holds one constant value (0 and 100), far apart relative to sigma.
    x_posterior = np.broadcast_to(100 * np.arange(chains)[..., None], (chains, draws))
    x_samples = xr.DataArray(
        x_posterior,
        dims=("chain", "draw"),
        coords={"chain": np.arange(chains), "draw": np.arange(draws)},
    )
    posterior = xr.Dataset({"x": x_samples})
    idata = InferenceData(posterior=posterior)

    with model:
        pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234)
        vectorized = vectorize_over_posterior(
            outputs=[obs, det],
            posterior=posterior,
            input_rvs=[x],
            allow_rvs_in_graph=True,
        )
    sampler = compile(inputs=[], outputs=vectorized, random_seed=1234)
    [vect_obs, vect_det] = sampler()

    predictive = pp.posterior_predictive
    assert predictive["obs"].shape == vect_obs.shape
    assert predictive["det"].shape == vect_det.shape
    np.testing.assert_allclose(vect_obs + 1, vect_det)
    np.testing.assert_allclose(
        predictive["obs"].mean(dim=("chain", "draw")),
        vect_obs.mean(axis=(0, 1)),
        atol=0.6 / np.sqrt(10000),
    )
    assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)