-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add vectorize_over_posterior
to pymc.sampling.forward
#7841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
cade4b6
to
88699ec
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, some minor suggestions and complaints. This PR reminds me that I don't think we have vectorize_node
implemented for SymbolicRandomVariables. Shoud be easy to implement, at least for those we have a signature (it's how we know how to resize them automatically)
pymc/pymc/distributions/distribution.py
Lines 395 to 414 in 3ae5095
@_change_dist_size.register(SymbolicRandomVariable) | |
def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> TensorVariable: | |
extended_signature = op.extended_signature | |
if extended_signature is None: | |
raise NotImplementedError( | |
f"SymbolicRandomVariable {op} without signature requires custom `_change_dist_size` implementation." | |
) | |
size = op.size_param(rv.owner) | |
if size is None: | |
raise NotImplementedError( | |
f"SymbolicRandomVariable {op} without [size] in extended_signature requires custom `_change_dist_size` implementation." | |
) | |
params = op.dist_params(rv.owner) | |
if expand and not rv_size_is_none(size): | |
new_size = tuple(new_size) + tuple(size) | |
return op.rebuild_rv(*params, size=new_size) |
I'm not suggesting we should do it in this PR, just reminded me |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7841 +/- ##
==========================================
+ Coverage 92.91% 92.92% +0.01%
==========================================
Files 107 107
Lines 18286 18313 +27
==========================================
+ Hits 16991 17018 +27
Misses 1295 1295
🚀 New features to boost your workflow:
|
20972c9
to
b2af12f
Compare
b2af12f
to
d610a0f
Compare
Looks great, failing test is unrelated I've seen it fail sporadically |
This PR adds the function
pymc.sampling.forward.vectorize_over_posterior
.This function basically vectorizes the computation of a list of output variables over possible values of random variables or deterministics that are stored in an
xarray.Dataset
. Any extra random variable remaining in the graph will be vectorized as well, either by changing its size explicitly to match the batch ofchain * draw
dimensions, or implicitly given its inputs.Type of change
📚 Documentation preview 📚: https://pymc--7841.org.readthedocs.build/en/7841/