BUG: pymc.sample raises SamplingError when using ExpQuad with HSGP. #7881

@Armatron44

Description

Describe the issue:

Hi all,

I've just updated PyMC to 5.25.1 and am finding that attempting to sample an HSGP with the ExpQuad covariance function raises `SamplingError: Initial evaluation of model at starting point failed!`. This appears to relate to the logp value of the HSGP at initialisation, which is -inf according to the traceback (see below). There doesn't appear to be an issue when I swap ExpQuad for something like Matern32. I've reproduced this on Google Colab with pymc 5.25.1. I was recently working with pymc 5.23 on a different computer, where ExpQuad HSGPs sampled fine, although I can't test that now.

Reproducible code example:

import pymc as pm
import numpy as np

# fake some data
x = np.sort(np.random.uniform(-1, 1, 101))
y = 3*np.cos(x*0.9) - 1
y += np.random.normal(scale=0.05, size=101)

with pm.Model(coords={"basis_coeffs": np.arange(200), "obs_id": np.arange(y.size)}) as model:
    ell = pm.Exponential("ell", scale=1) # dont @ me for these priors...
    eta = pm.Exponential("eta", scale=1.0)
    cov_func = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell) # fails with pymc sampler, nutpie, numpyro, blackjax (so possibly not to do with nuts_sampler kwarg in pm.sample...)
    #cov_func = eta**2 * pm.gp.cov.Matern32(input_dim=1, ls=ell) # this works with all

    m, c = 200, 1.5
    gp = pm.gp.HSGP(m=[m], c=c, parametrization="centered", cov_func=cov_func)
    f = gp.prior("f", X=x[:, None], hsgp_coeffs_dims="basis_coeffs", gp_dims="obs_id")

    sigma = pm.Exponential("sigma", scale=1.0)
    pm.Normal("y_obs", mu=f, sigma=sigma, observed=y, dims="obs_id")

    idata = pm.sample()

Error message:

---------------------------------------------------------------------------
SamplingError                             Traceback (most recent call last)
Cell In[16], line 16
     13 sigma = pm.Exponential("sigma", scale=1.0)
     14 pm.Normal("y_obs", mu=f, sigma=sigma, observed=y, dims="obs_id")
---> 16 idata = pm.sample()

File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:825, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    823         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    824     with joined_blas_limiter():
--> 825         initial_points, step = init_nuts(
    826             init=init,
    827             chains=chains,
    828             n_init=n_init,
    829             model=model,
    830             random_seed=random_seed_list,
    831             progressbar=progress_bool,
    832             jitter_max_retries=jitter_max_retries,
    833             tune=tune,
    834             initvals=initvals,
    835             compile_kwargs=compile_kwargs,
    836             **kwargs,
    837         )
    838 else:
    839     # Get initial points
    840     ipfns = make_initial_point_fns_per_chain(
    841         model=model,
    842         overrides=initvals,
    843         jitter_rvs=set(),
    844         chains=chains,
    845     )

File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:1598, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
   1595     q, _ = DictToArrayBijection.map(ip)
   1596     return logp_dlogp_func([q], extra_vars={})[0]
-> 1598 initial_points = _init_jitter(
   1599     model,
   1600     initvals,
   1601     seeds=random_seed_list,
   1602     jitter="jitter" in init,
   1603     jitter_max_retries=jitter_max_retries,
   1604     logp_fn=model_logp_fn,
   1605 )
   1607 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   1608 apoints_data = [apoint.data for apoint in apoints]

File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:1479, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
   1476 if not np.isfinite(point_logp):
   1477     if i == jitter_max_retries:
   1478         # Print informative message on last attempted point
-> 1479         model.check_start_vals(point)
   1480     # Retry with a new seed
   1481     seed = rng.integers(2**30, dtype=np.int64)

File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\model\core.py:1761, in Model.check_start_vals(self, start, **kwargs)
   1758 initial_eval = self.point_logps(point=elem, **kwargs)
   1760 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1761     raise SamplingError(
   1762         "Initial evaluation of model at starting point failed!\n"
   1763         f"Starting values:\n{elem}\n\n"
   1764         f"Logp initial evaluation results:\n{initial_eval}\n"
   1765         "You can call `model.debug()` for more details."
   1766     )

SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'ell_log__': array(0.06154412), 'eta_log__': array(-0.07390129), 'f_hsgp_coeffs': array([ 0.8964528 ,  0.49850679,  0.77766117, -0.21469201, -0.83229329,
       -0.2429649 , -0.1551994 , -0.16094197, -0.66134098,  0.08932747,
        0.85003664, -0.96698846, -0.24095418,  0.95928666,  0.40584951,
        0.65621531,  0.34217628, -0.13498463,  0.6735691 , -0.91265536,
       -0.05148727,  0.37550022, -0.66848646, -0.237498  ,  0.13401424,
        0.07616318, -0.76062642,  0.42537257, -0.78619399,  0.71017075,
        0.62232671,  0.76061207, -0.25878416, -0.71957469,  0.75449224,
       -0.2458921 , -0.64380881, -0.88398595, -0.73363227, -0.72695346,
       -0.26828684,  0.64891776, -0.68961931,  0.72908515, -0.42343627,
        0.24523088, -0.50362676, -0.80204453, -0.47411123,  0.06919655,
       -0.85278136, -0.6872726 , -0.7074239 , -0.97904535, -0.2096503 ,
        0.41902197,  0.25750279,  0.16304053, -0.37161017, -0.36869419,
        0.87463671, -0.99804548, -0.2472362 , -0.99437107,  0.17233818,
       -0.53704303, -0.70933562, -0.6216585 , -0.74211035, -0.11780913,
       -0.33046545, -0.10765366,  0.09696944, -0.68235125, -0.78363395,
       -0.53045928, -0.17417613,  0.691059  , -0.05228136, -0.38724882,
        0.35066208, -0.5149922 , -0.77655213,  0.45167872,  0.96291537,
       -0.74180878,  0.47324007,  0.07420529,  0.45694168, -0.19554454,
       -0.08631478,  0.40328765, -0.82952522,  0.33224662,  0.06260759,
       -0.54895729,  0.75930369, -0.3085233 , -0.79609509,  0.82898824,
        0.53739623,  0.30328473, -0.90124674, -0.64246727, -0.21607528,
       -0.04892372,  0.90662235,  0.71510085, -0.22509855, -0.26623875,
       -0.641338  , -0.75124308, -0.8214267 ,  0.5451419 ,  0.02570617,
        0.4018908 , -0.1126687 , -0.31593296,  0.0362656 ,  0.76238948,
        0.3919529 , -0.27760741, -0.10068226, -0.04583653,  0.74203014,
        0.75065354,  0.54871431, -0.64430454,  0.53359048, -0.97495406,
       -0.73663779,  0.33514719,  0.69741655, -0.53137909,  0.78693164,
        0.17234047,  0.74777694, -0.1744733 ,  0.7607344 , -0.86257238,
       -0.17365085, -0.82280093,  0.7484344 ,  0.88597422, -0.91898113,
       -0.77001598, -0.95169786, -0.38264231, -0.81062648, -0.47486829,
        0.64959473,  0.24822373, -0.02866922,  0.4009877 ,  0.95769761,
        0.49782551, -0.99525773, -0.53563913,  0.77753391, -0.09806273,
       -0.20661118,  0.65396797, -0.97223146,  0.96095015, -0.66162998,
        0.07860522, -0.2993302 ,  0.03866622,  0.18522487,  0.54755709,
       -0.08540239,  0.18920188,  0.55557113,  0.26725513,  0.54615184,
       -0.18610811, -0.38184134, -0.4754358 , -0.54684329, -0.50546527,
        0.85735608, -0.89882684,  0.16445413, -0.86161286, -0.13556939,
       -0.8191506 ,  0.95764276,  0.55347055, -0.23243567, -0.35112965,
        0.34498401,  0.93611154,  0.51646511, -0.95735037,  0.28280915,
        0.18386817, -0.54182348,  0.48228643, -0.53230718, -0.82028976]), 'sigma_log__': array(0.69110407)}

Logp initial evaluation results:
{'ell': np.float64(-1.0), 'eta': np.float64(-1.0), 'f_hsgp_coeffs': np.float64(-inf), 'sigma': np.float64(-1.3), 'y_obs': np.float64(-534.46)}
You can call `model.debug()` for more details.
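For what it's worth, the -inf logp on `f_hsgp_coeffs` would be consistent with the ExpQuad power spectral density underflowing to exactly zero at the higher HSGP basis frequencies (in the centered parametrization, sqrt(PSD) sets the prior sd of the coefficients, and sd = 0 gives -inf logp at any nonzero coefficient). This is just my guess at the cause, not confirmed against the PyMC internals. A numpy-only sketch, using the standard 1D spectral densities and the HSGP eigenfrequencies sqrt(lambda_j) = j*pi/(2L) with the m=200, c=1.5 setup from the repro above:

```python
import numpy as np

def expquad_psd_1d(omega, ls=1.0):
    # 1D spectral density of the ExpQuad (RBF) kernel:
    # S(omega) = sqrt(2*pi) * ls * exp(-ls**2 * omega**2 / 2)
    # -> decays like a Gaussian in omega, so it underflows quickly
    return np.sqrt(2 * np.pi) * ls * np.exp(-(ls**2) * omega**2 / 2)

def matern32_psd_1d(omega, ls=1.0):
    # 1D spectral density of the Matern(3/2) kernel:
    # S(omega) = 4 * lam**3 / (lam**2 + omega**2)**2, lam = sqrt(3)/ls
    # -> decays only polynomially, so it stays strictly positive
    lam = np.sqrt(3.0) / ls
    return 4 * lam**3 / (lam**2 + omega**2) ** 2

# HSGP setup from the repro: m=200 basis functions, c=1.5, X in [-1, 1]
m, c, half_range = 200, 1.5, 1.0
L = c * half_range
j = np.arange(1, m + 1)
omega = j * np.pi / (2 * L)  # eigenfrequencies sqrt(lambda_j)

print(expquad_psd_1d(omega).min())   # 0.0 -- underflows for large j
print(matern32_psd_1d(omega).min())  # tiny, but strictly positive
```

If this is right, it would also explain why Matern32 samples fine while ExpQuad fails only once m is large enough for the tail frequencies to underflow.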

PyMC version information:

python 3.13.5
pymc 5.25.1
pytensor 2.31.7
numpy 2.2.6

Context for the issue:

ExpQuad (the RBF kernel) is the textbook covariance function and a go-to choice for exploring GP models of data.

Labels: GP (Gaussian Process), docs