Implement vectorize_graph for XTensor Ops #1876

Merged

ricardoV94 merged 5 commits into pymc-devs:main from ricardoV94:xtensor_vectorize
Feb 19, 2026

Conversation

@ricardoV94
Member

@ricardoV94 ricardoV94 commented Feb 5, 2026

The biggest challenge is when vectorization would imply new dimensions (say xtensor_from_tensor). There is no trivial way for the user to tell us what those dimension labels should be.

So there's a new, more powerful vectorize_graph in xtensor/vectorization.py. If a regular graph_replace would fail because of a batched xtensor_from_tensor, an informative error is raised, pointing users to the new function.

I think it was a mistake to advertise vectorize_graph in a general graph.replace module, because the semantics of vectorization depend on the types used in the graph, and different types may not be mutually compatible. Long term, I would perhaps move the existing function to the tensor submodule.

Limitation-wise: sized operations (where numerical inputs determine output sizes) can't be batched. Note that the batch dimension is either redundant or implies a ragged output. This is equally tricky in tensor graphs: sometimes we fall back to a Blockwise and try to rewrite it away later when we can prove the input is homogeneous, or fail at runtime. We can revisit that option later; for now the methods raise a NotImplementedError accordingly.
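The ragged-output problem can be seen with a plain-NumPy sketch (not the PyTensor API): when a numerical input determines the output size, batching that input yields outputs of different lengths, so there is no rectangular array to stack them into.

```python
import numpy as np

def arange_core(n):
    # core operation: output size is determined by the *value* of n
    return np.arange(n)

batch_of_sizes = [2, 3]  # a batched size input
results = [arange_core(n) for n in batch_of_sizes]
shapes = [r.shape for r in results]
# shapes == [(2,), (3,)]: ragged, so no single batched output exists
assert shapes == [(2,), (3,)]
```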

Implementation detail: I decided to implement most methods as an Op property. I think this should be the default, as it keeps code organized. Users can still implement their custom dispatch if it's missing, and we can offer a two-level dispatch (so users don't accidentally override the error checking that is now done for all Ops before dispatching).
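A minimal sketch of the two-level dispatch described above. All names here (Op, register_vectorize, vectorize_node) are illustrative stand-ins, not the actual PyTensor API: the default implementation lives on the Op itself, a user-facing registry can override it, and shared error checking runs for every Op before dispatching.

```python
_user_dispatch = {}

class Op:
    def vectorize_node(self, node, *batched_inputs):
        raise NotImplementedError(f"{type(self).__name__} can't be vectorized")

def register_vectorize(op_type, fn):
    # second level: explicit user override for a given Op type
    _user_dispatch[op_type] = fn

def vectorize_node(op, node, *batched_inputs):
    # first level: error checking shared by all Ops happens before dispatch
    if not batched_inputs:
        raise ValueError("need at least one batched input")
    fn = _user_dispatch.get(type(op))
    if fn is not None:
        return fn(op, node, *batched_inputs)
    return op.vectorize_node(node, *batched_inputs)

class Exp(Op):
    def vectorize_node(self, node, *batched_inputs):
        return list(batched_inputs)  # elementwise: batched inputs pass through

assert vectorize_node(Exp(), None, "x_batched") == ["x_batched"]

# A user-registered override takes precedence over the Op method
register_vectorize(Exp, lambda op, node, *ins: ["custom"])
assert vectorize_node(Exp(), None, "x_batched") == ["custom"]
```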

Other cleanup: neither of the vectorize_graph functions requires the individual node functions to return an Apply. This was an artificial requirement, and sometimes we had to return identity nodes. For complex Ops that may be better vectorized by different operations, the old design would be even more absurd. Related to #902 (I wouldn't close it until we clean up pre-existing implementations and deprecate the option of returning an Apply).
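A sketch of the relaxed contract: the per-node vectorization function may return either an Apply (the old API) or a plain sequence of output variables (the new behavior), and the caller normalizes both. The Apply class here is a hypothetical stand-in for pytensor's Apply, and as_outputs is an illustrative name.

```python
from collections.abc import Sequence

class Apply:
    def __init__(self, outputs):
        self.outputs = outputs

def as_outputs(vect_node_or_outputs):
    if isinstance(vect_node_or_outputs, Apply):
        # Compatibility with the old API
        return list(vect_node_or_outputs.outputs)
    if not isinstance(vect_node_or_outputs, Sequence):
        raise TypeError("expected an Apply or a sequence of variables")
    return list(vect_node_or_outputs)

assert as_outputs(Apply(["y"])) == ["y"]  # old API: wrap outputs in an Apply
assert as_outputs(("y",)) == ["y"]        # new API: no identity node needed
```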

@ricardoV94
Member Author

I removed the public vectorize_node, because there's no point to it. If you want to vectorize a single node, just call vectorize_graph on the outputs. It was actually masking issues (one test is failing because of that), and forcing us to test cases that can't be generated from vectorize_graph at all.

@ricardoV94 ricardoV94 requested review from jessegrabowski and removed request for jessegrabowski February 10, 2026 16:56
@ricardoV94 ricardoV94 force-pushed the xtensor_vectorize branch 6 times, most recently from 2f984bc to 2ea456c on February 14, 2026 17:20
@ricardoV94 ricardoV94 marked this pull request as ready for review February 14, 2026 17:24
@ricardoV94 ricardoV94 changed the title from "Implement vectorize for XTensor Ops" to "Implement vectorize_graph for XTensor Ops" on Feb 14, 2026
Member

@jessegrabowski left a comment

Approved with comments. I don't think any of them rise to blocker level.


with pytest.raises(ValueError, match=msg):
    vectorize_graph(
        variable_shape_out, {variable_shape: non_variable_batch_shape}
Member

In this case it seems like we could have tried to cast to variable for the user. We do that in most user-facing functions, why not here?

Member Author

Because what do we cast it to? TensorVariable, SparseVariable, XTensorVariable? There's no utility for converting x to the same type (but potentially different ndim).

for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):

vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
# Compatibility with the old API
Member

This needs a bit more documentation; I wouldn't have understood the Apply change without our call yesterday.

Member Author

It's pretty internal, though; the type signature shows _vectorize_node returns Sequence[Variable] | Apply, and I was hoping vect_node_or_outputs -> vect_outputs further drives the point home. I want to deprecate the Apply eventually as well, but I don't want to do it in this PR.

)
)
old_x, *_ = node.inputs
new_x, *_ = broadcast(
Member

Does the exclude here mean that you can't vectorize on a dimension that already exists in one of the inputs, or is that automatically handled somewhere else?

I tried to write out a specific example a few times, but since it's an indexing Op I'm not clear enough on how broadcasting will work. If I have x = xtensor(dims=('a',)), can I index with idx = xtensor(dims=('b',)), or is that already nonsense? If it is possible, I am asking whether we can replace x with x_batch = xtensor(dims=('b', 'a')).

Member Author

@ricardoV94 ricardoV94 Feb 18, 2026

Your first question: yes, you can index with xtensor(dims=('b',)). The way indexing in xarray/xtensor works is: the input has some dimensions, and the index variables have other dimensions. It defines a map of input_dims -> dimensions_after_indexing. So xtensor(dims=("a",))[xtensor(dims=("b",), dtype=int)] has dims ("b",). Indexing converted (a subset of) the a dim into the b dim. There's special consideration if b already existed in the input or in another indexing variable, but let's not get there. Indexing gets hairy pretty fast.

On vectorization: no, you can't do x_batch = xtensor(dims=("b", "a")). The constraint is that you can never introduce a new dimension in the batch inputs that already existed in the graph (otherwise it would interact with the core graph, so it wouldn't be a true vectorization). The logic is, you should always be able to implement vectorization by looping over each entry of the batch inputs one at a time and then concatenating the results along the new dimension.
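The first point can be illustrated with plain NumPy (xtensor itself is not used here): x with dims ("a",) indexed by an integer idx with dims ("b",) yields a result whose entries are laid out along "b".

```python
import numpy as np

x = np.array([10.0, 20.0, 30.0])  # dims ("a",), a=3
idx = np.array([2, 0])            # dims ("b",), b=2
out = x[idx]                      # result has dims ("b",)
assert out.shape == (2,)
assert out.tolist() == [30.0, 10.0]
```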

Member Author

This is enforced in vectorize_x_node:

        # Or have new dimensions that were already in the graph
        if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
            raise ValueError(
                f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
            )

Comment on lines 305 to 308
new_extra_dim_lengths, new_params = (
    new_extra_dim_lengths_and_params[: len(self.extra_dims)],
    new_extra_dim_lengths_and_params[len(self.extra_dims) :],
)
Member

k = len(self.extra_dims)
new_extra_dim_lengths = new_extra_dim_lengths_and_params[:k]
new_params = new_extra_dim_lengths_and_params[k:]

Doing it all in one shot is a bit too cute imo, because of the lengths of the variable names.

Member

This way you could go straight for the squeeze listcomp too

) -> Sequence[Variable]:
"""Returns vectorized version of node with new batched inputs."""

all_old_dims_set = set(
Member

I feel like I've read this pattern a few times, might merit extraction into a helper.

Member Author

What pattern?

    isinstance(key, Variable) and isinstance(value, Variable)
    for key, value in replace.items()
):
    raise ValueError(f"Some of the replaced items are not Variables: {replace}")
Member

Filter only the offending items in replace for the error?

Member Author

I didn't do it because I'm iterating over the keys and values at once. What if it's a key from one item and the value from another that are not Variables?
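One way to report all offending entries while still iterating keys and values together, addressing the concern above, would be to collect each bad (key, value) pair whole. This is a hedged sketch; Variable here is a stand-in for pytensor's Variable type, and check_replace is an illustrative name.

```python
class Variable:
    pass

def check_replace(replace):
    bad = [
        (key, value)
        for key, value in replace.items()
        if not (isinstance(key, Variable) and isinstance(value, Variable))
    ]
    if bad:
        raise ValueError(f"Some of the replaced items are not Variables: {bad}")

v, w = Variable(), Variable()
check_replace({v: w})  # all Variables: passes silently

caught = False
try:
    check_replace({v: "not a variable"})
except ValueError:
    caught = True  # the offending pair is named in the error
assert caught
```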

@ricardoV94 ricardoV94 merged commit 8747006 into pymc-devs:main Feb 19, 2026
68 checks passed