Implement vectorize_graph for XTensor Ops #1876
Conversation
I removed the public `vectorize_node`, because there's no point to it. If you want to vectorize a single node, just call …
jessegrabowski left a comment:
Approved with comments. I don't think any of them rise to blocker level.
```python
with pytest.raises(ValueError, match=msg):
    vectorize_graph(
        variable_shape_out, {variable_shape: non_variable_batch_shape}
    )
```
In this case it seems like we could have tried to cast to a variable for the user. We do that in most user-facing functions; why not here?
Because what do we cast it to? TensorVariable, SparseVariable, XTensorVariable? There's no utility to convert x to the same type (but potentially different ndim).
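A hypothetical sketch of why such a helper is awkward (nothing here is from the PR; the name and branches are illustrative): the converter would have to dispatch on the key's type family, and even then it couldn't infer the batched dims/ndim the value is supposed to have.

```python
# Hypothetical sketch, not code from this PR.
from pytensor.graph.basic import Variable
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.type import TensorType


def as_variable_like(key: Variable, value):
    """Convert `value` to the same type family as `key`."""
    if isinstance(key.type, TensorType):
        return as_tensor_variable(value)
    # SparseVariable / XTensorVariable would each need their own converter,
    # and the XTensor case also needs dim labels that aren't available here.
    raise NotImplementedError(f"No generic converter for {key.type}")
```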
```python
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
```

```python
vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
# Compatibility with the old API
```
This needs a bit more documentation; I wouldn't have understood the Apply change without our call yesterday.
It's pretty internal though; the type signature shows `_vectorize_node` returns `Sequence[Variable] | Apply`, and I was hoping `vect_node_or_outputs` -> `vect_outputs` further drives the point home. I want to deprecate the Apply return eventually as well, but don't want to do it in this PR.
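For readers without the context of that call, a minimal sketch of what the compatibility handling amounts to, assuming only the `Sequence[Variable] | Apply` signature mentioned above (the actual code may differ):

```python
from collections.abc import Sequence

from pytensor.graph.basic import Apply, Variable


def _as_vect_outputs(
    vect_node_or_outputs: Sequence[Variable] | Apply,
) -> Sequence[Variable]:
    if isinstance(vect_node_or_outputs, Apply):
        # Old API: the dispatcher returned an Apply; unwrap its outputs.
        return vect_node_or_outputs.outputs
    # New API: the dispatcher returns the output Variables directly.
    return vect_node_or_outputs
```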
```python
)
)
old_x, *_ = node.inputs
new_x, *_ = broadcast(
```
Does the exclude here mean that you can't vectorize on a dimension that already exists in one of the inputs, or is that automatically handled somewhere else?

I tried to write out a specific example a few times, but since it's an indexing Op I'm not clear enough on how broadcasting will work. If I have `x = xtensor(dims=('a',))`, can I index with `idx = xtensor(dims=('b',))`, or is that already nonsense? If it is possible, I am asking whether we can replace `x` with `x_batch = xtensor(dims=('b', 'a'))`.
Your first question: yes, you can index with `xtensor(dims=('b',))`. The way indexing in xarray/xtensor works is: the input has some dimensions, and the index variables have other dimensions. It defines a map of input_dims -> dimensions_after_indexing. So `xtensor(dims=("a",))[xtensor(dims=("b",), dtype=int)]` has dims `("b",)`. Indexing converted a subset of the `a` dim into the `b` dim. There's special consideration if `b` already existed in the input or in another indexing variable, but let's not get there. Indexing gets hairy pretty fast.
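A runnable xarray analogue of that dim map (variable names are ours):

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(4), dims=("a",))
idx = xr.DataArray(np.array([0, 2]), dims=("b",))
print(x[idx].dims)  # ('b',): indexing consumed "a" and introduced "b"
```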
On vectorization: no, you can't do `x_batch = xtensor(dims=("b", "a"))`. The constraint is that you can never introduce a new dimension in the batch inputs that already existed in the graph (otherwise it would interact with the core graph, so it wouldn't be a true vectorization). The logic is: you should always be able to implement vectorization by doing a loop over each entry of the batch inputs, one at a time, and then concatenating the results along the new dimension.
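In other words, any valid vectorization must match this loop-and-concatenate recipe, sketched here with xarray (`core_fn` and the dim names are illustrative):

```python
import numpy as np
import xarray as xr


def core_fn(x):
    # Stand-in for the core graph being vectorized
    return x.cumsum("a")


x_batch = xr.DataArray(np.arange(6).reshape(2, 3), dims=("batch", "a"))
slices = [core_fn(x_batch.isel(batch=i)) for i in range(x_batch.sizes["batch"])]
out = xr.concat(slices, dim="batch")  # what a true vectorization must equal
```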
This is enforced in `vectorize_x_node`:

```python
# Or have new dimensions that were already in the graph
if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
    raise ValueError(
        f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
    )
```
```python
new_extra_dim_lengths, new_params = (
    new_extra_dim_lengths_and_params[: len(self.extra_dims)],
    new_extra_dim_lengths_and_params[len(self.extra_dims) :],
)
```
```python
k = len(self.extra_dims)
new_extra_dim_lengths = new_extra_dim_lengths_and_params[:k]
new_params = new_extra_dim_lengths_and_params[k:]
```

Doing it all in one shot is a bit too cute imo, because of the lengths of the variable names.
This way you could go straight for the squeeze listcomp too
```python
) -> Sequence[Variable]:
    """Returns vectorized version of node with new batched inputs."""
```

```python
all_old_dims_set = set(
```
I feel like I've read this pattern a few times; it might merit extraction into a helper.
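Something like this hypothetical helper, assuming xtensor variables expose their dim labels on `var.type.dims`:

```python
from collections.abc import Iterable

from pytensor.graph.basic import Variable


def union_of_dims(variables: Iterable[Variable]) -> set[str]:
    # Union of dim labels across a collection of xtensor variables
    return {dim for var in variables for dim in var.type.dims}
```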
```python
    isinstance(key, Variable) and isinstance(value, Variable)
    for key, value in replace.items()
):
    raise ValueError(f"Some of the replaced items are not Variables: {replace}")
```
Filter only the offending items in `replace` for the error?
I didn't do it, because I'm iterating over the keys and values at once. What if it's a key in one pair and a value in another pair that are not Variables? They would be silently ignored otherwise.
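For what it's worth, reporting whole offending pairs would sidestep that concern; a sketch (the function name is ours, `replace` is the user-supplied mapping):

```python
from pytensor.graph.basic import Variable


def check_replacements(replace: dict) -> None:
    bad_pairs = [
        (key, value)
        for key, value in replace.items()
        if not (isinstance(key, Variable) and isinstance(value, Variable))
    ]
    if bad_pairs:
        raise ValueError(f"Some of the replaced items are not Variables: {bad_pairs}")
```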
The biggest challenge is when vectorization would imply new dimensions (say `xtensor_from_tensor`). There is no trivial way to allow the user to tell us what those dimension labels should be. So there's a new, more powerful `vectorize_graph` in `xtensor/vectorization.py`. If a regular `graph_replace` would fail because of a batched `xtensor_from_tensor`, an informative error is raised, pointing users to the new function.
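To make the challenge concrete, a hedged sketch (the import path is assumed, not taken from the PR): batching the tensor input of an `xtensor_from_tensor` introduces a dimension whose label nothing in the graph can supply.

```python
import pytensor.tensor as pt
from pytensor.xtensor.basic import xtensor_from_tensor  # import path assumed

t = pt.vector("t")
x = xtensor_from_tensor(t, dims=("a",))  # core graph: one labeled dim
t_batch = pt.matrix("t_batch")           # batched replacement for t
# Vectorizing x with {t: t_batch} needs a label for the new leading dim of
# the xtensor output, and only the user can say what that label should be.
```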
I think it was a mistake to advertise `vectorize_graph` in a general `graph.replace` module, because the semantics of vectorization depend on the types used in the graph, and different types may not be mutually compatible. Long term I would perhaps move the existing function to the tensor submodule.

Limitation-wise: sized operations (where numerical inputs determine output sizes) can't be batched. Note that the batch dimension is either redundant or implies a ragged output. This is equally tricky in tensor graphs. Sometimes we fall back to a Blockwise and try to rewrite it away later when we can prove the input is homogeneous, or fail at runtime. We can revisit that option later; the methods raise a `NotImplementedError` accordingly.
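A small numpy illustration of why a batched size input is a dead end (names are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = np.array([2, 3])  # a "batched" size input
draws = [rng.normal(size=s) for s in sizes]
# draws[0].shape == (2,) while draws[1].shape == (3,): the batched output
# is ragged, so there is no rectangular array to return
```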
Implementation detail: I decided to implement most methods as an Op property. I think this should be the default, as it keeps code organized. Users can still implement their custom dispatch if it's missing, and we can offer a two-level dispatch (so as not to accidentally override the error checking that is now done for all Ops before dispatching).
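A hedged sketch of what that two-level dispatch could look like (all names are illustrative, not the PR's API):

```python
from functools import singledispatch


@singledispatch
def _vectorize_x_node_inner(op, node, *batch_inputs):
    # Second level: per-Op implementations register here (or the Op
    # property described above is looked up). Default: not vectorizable.
    raise NotImplementedError(f"No XTensor vectorization for {type(op).__name__}")


def vectorize_x_node(op, node, *batch_inputs):
    # First level: shared validation that registrations cannot override,
    # e.g. rejecting batch inputs that reintroduce existing graph dims.
    ...
    return _vectorize_x_node_inner(op, node, *batch_inputs)
```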
Other cleanup: neither of the `vectorize_graph` functions requires the individual node functions to return an Apply. This was an artificial requirement, and sometimes we had to return identity nodes. For complex Ops that may be better vectorized by different operations, the old design would be even more absurd. Related to #902 (I wouldn't close it until we clean up the pre-existing implementations and deprecate the option of returning an Apply).