Add torch/jax array ingest support to OMEArrow #56
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings.
📝 Walkthrough
Adds direct PyTorch/JAX tensor ingest via public `from_torch_array`/`from_jax_array` helpers with default `dim_order` inference, delegating to the existing `from_numpy` path.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant OME as OMEArrow.__init__
    participant Detect as TypeDetection
    participant Torch as from_torch_array
    participant JAX as from_jax_array
    participant NumPy as from_numpy
    participant Result as OMEArrowInstance

    User->>OME: OMEArrow(input_tensor, dim_order?)
    OME->>Detect: inspect input type
    alt torch.Tensor
        Detect->>Torch: route to from_torch_array
        Torch->>Torch: normalize (dense, CPU), detach, convert to NumPy
        Torch->>Torch: infer dim_order if None
        Torch->>NumPy: delegate to from_numpy(...)
    else jax.Array
        Detect->>JAX: route to from_jax_array
        JAX->>JAX: convert to host NumPy
        JAX->>JAX: infer dim_order if None
        JAX->>NumPy: delegate to from_numpy(...)
    else other
        Detect->>NumPy: existing from_numpy/from_source flow
    end
    NumPy->>Result: build flattened NumPy buffers, return StructScalar / OMEArrow instance
    Result->>User: constructed OMEArrow object
```
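The type-routing step at the top of the diagram can be done without importing torch or jax at all. One common pattern, shown here as an illustrative sketch rather than the actual `ome_arrow.ingest` implementation (which may instead use `isinstance` checks behind lazy imports), inspects the module of the input's class:

```python
def detect_backend(obj: object) -> str:
    """Best-effort backend routing without importing optional dependencies.

    Hypothetical sketch: checks the module of the object's class, so torch
    and jax are only ever touched if the caller already imported them.
    """
    mod = type(obj).__module__
    if mod.startswith("torch"):
        return "torch"
    if mod.startswith(("jax", "jaxlib")):
        return "jax"
    # Fall through to the existing from_numpy / from_source flow.
    return "numpy"
```

A plain list or NumPy array falls through to the default branch, so the optional backends never become hard dependencies.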
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Pull request overview
This PR adds first-class ingestion of PyTorch tensors and JAX arrays into the OMEArrow workflow, allowing users to construct OME-Arrow records directly from in-memory deep learning tensors (with optional dim_order overrides), and documents the new functionality.
Changes:
- Add `from_torch_array`/`from_jax_array` ingest helpers plus tensor-rank-based default `dim_order` inference.
- Extend the `OMEArrow(...)` constructor to accept torch/jax inputs and an optional `dim_order` parameter (restricted to array/tensor inputs).
- Add tests and documentation covering tensor ingest and the new APIs.
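The rank-based default `dim_order` inference mentioned above could look roughly like the sketch below. The rank-to-axes mapping here is hypothetical except for the 3-D case, which is implied by the tests (a `(2, 3, 4)` input exports with TCZYX shape `(1, 1, 2, 3, 4)`):

```python
# Hypothetical default mapping; the actual table in ome_arrow.ingest may
# differ for ranks other than 3.
_DEFAULT_DIM_ORDERS = {
    2: "YX",     # assumption: plain 2-D image
    3: "ZYX",    # implied by the PR's tests
    4: "CZYX",   # assumption
    5: "TCZYX",  # assumption
}


def infer_dim_order(ndim: int) -> str:
    """Return a default axis-label string for an array of the given rank."""
    try:
        return _DEFAULT_DIM_ORDERS[ndim]
    except KeyError:
        raise ValueError(
            f"cannot infer dim_order for a {ndim}-D input; pass dim_order explicitly"
        ) from None
```

Raising for unsupported ranks (rather than guessing) matches the PR's design of only accepting an explicit `dim_order` for array/tensor inputs.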
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/ome_arrow/ingest.py` | Adds torch/jax ingest helpers, backend detection, and rank-based default `dim_order` inference. |
| `src/ome_arrow/core.py` | Extends the `OMEArrow` constructor to accept torch/jax inputs and optional `dim_order`. |
| `src/ome_arrow/__init__.py` | Re-exports new ingest helpers at package top level. |
| `tests/test_core.py` | Adds unit tests for torch/jax constructor and explicit `dim_order` behavior. |
| `README.md` | Documents direct tensor ingest and helper-based ingest. |
| `docs/src/dlpack.md` | Adds examples for direct constructor ingest for torch/jax. |
| `docs/src/python-api.md` | Adds autodoc sections for core and ingest modules. |
Comments suppressed due to low confidence (1)
src/ome_arrow/core.py:105
- The constructor now supports `torch.Tensor` and `jax.Array`, but the `data` type annotation still only advertises `str | dict | pa.StructScalar | np.ndarray`. This makes the public API signature inconsistent with the documented/implemented behavior and hurts static type checking. Consider broadening the annotation (e.g., to include `Any`/protocols for torch/jax) while keeping optional deps out of import-time typing.
```python
def __init__(
    self,
    data: str | dict | pa.StructScalar | "np.ndarray",
    dim_order: str | None = None,
    tcz: Tuple[int, int, int] = (0, 0, 0),
    column_name: str = "ome_arrow",
    row_index: int = 0,
    image_type: str | None = None,
    lazy: bool = False,
) -> None:
```
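One way to broaden the annotation while keeping torch/jax out of import-time typing is a `TYPE_CHECKING`-guarded alias. This is a sketch of the pattern only (the alias name `ArrayInput` and the demo function are hypothetical; the project might prefer a structural protocol instead):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:  # evaluated only by static type checkers, never at runtime
    import jax
    import numpy as np
    import torch

# String forward references keep the optional imports out of runtime.
ArrayInput = Union["np.ndarray", "torch.Tensor", "jax.Array"]


def describe(data: "str | dict | ArrayInput") -> str:
    """Tiny demo: the annotation is usable even without torch/jax installed."""
    return type(data).__name__
```

Static checkers resolve the forward references against the guarded imports, while at runtime the module imports cleanly in environments without either backend.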
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ome_arrow/ingest.py`:
- Around line 929-1003: The project requires JAX >=0.4.1 because from_jax_array
uses the jax.Array type; update the dependency specification from "jax>=0.4" to
"jax>=0.4.1" in the project dependency manifests (e.g., pyproject.toml /
requirements files) so that callers installing the package get the minimum JAX
version that provides jax.Array referenced in from_jax_array; no code changes to
the from_jax_array function itself are needed.
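If the pin lands in `pyproject.toml`, the change might look like the fragment below. The optional-dependency group name is a guess based on common convention; only the `jax>=0.4.1` bound itself comes from the review comment:

```toml
[project.optional-dependencies]
# jax.Array is referenced by from_jax_array, so the lower bound
# must be at least the jax release that provides it.
jax = ["jax>=0.4.1"]
```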
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 79fb9752-2635-42d2-bd41-4f7b23196c04
📒 Files selected for processing (7)
- README.md
- docs/src/dlpack.md
- docs/src/python-api.md
- src/ome_arrow/__init__.py
- src/ome_arrow/core.py
- src/ome_arrow/ingest.py
- tests/test_core.py
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
🧹 Nitpick comments (1)
tests/test_core.py (1)
398-469: Optional: parameterize Torch/JAX parity tests to reduce duplication. The torch/jax test pairs are structurally identical. A backend-parameterized helper would keep this easier to extend and maintain.
♻️ Example refactor sketch
```diff
+@pytest.mark.parametrize("backend", ["torch", "jax"])
+def test_constructor_accepts_array_backend(backend: str) -> None:
+    if backend == "torch":
+        torch = pytest.importorskip("torch")
+        arr = torch.arange(2 * 3 * 4).reshape(2, 3, 4).to(dtype=torch.uint16)
+        as_numpy = arr.numpy()
+    else:
+        jnp = pytest.importorskip("jax.numpy")
+        arr = jnp.arange(2 * 3 * 4, dtype=jnp.uint16).reshape(2, 3, 4)
+        as_numpy = np.asarray(arr)
+
+    oa = OMEArrow(arr)
+    exported = oa.export(how="numpy")
+    assert exported.shape == (1, 1, 2, 3, 4)
+    np.testing.assert_array_equal(exported[0, 0], as_numpy)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/test_core.py` around lines 398-469: Several test pairs (test_constructor_accepts_torch_tensor / test_constructor_accepts_jax_array, test_from_torch_array_explicit_dim_order / test_from_jax_array_explicit_dim_order, test_constructor_dim_order_override_torch_tensor / test_constructor_dim_order_override_jax_array) are duplicated; parameterize them over backend to remove the duplication. Create a single parameterized test using pytest.mark.parametrize with a backend param ("torch" or "jax"); inside the test, call pytest.importorskip for the backend, map the backend to the appropriate constructors and helpers (e.g., OMEArrow, ingest.from_torch_array vs ingest.from_jax_array) and to the conversion method (tensor.numpy() vs np.asarray(arr)), then replace the six duplicate test functions with the parameterized versions that branch using those mapped symbols so behavior is identical while removing duplicated code.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/test_core.py`:
- Around line 398-469: Several test pairs (test_constructor_accepts_torch_tensor
/ test_constructor_accepts_jax_array, test_from_torch_array_explicit_dim_order /
test_from_jax_array_explicit_dim_order,
test_constructor_dim_order_override_torch_tensor /
test_constructor_dim_order_override_jax_array) are duplicated; parameterize them
over backend to remove duplication. Create a single parameterized test using
pytest.mark.parametrize with a backend param ("torch" or "jax"), inside the test
call pytest.importorskip for the backend, map backend to the appropriate
constructors and helpers (e.g., OMEArrow, ingest.from_torch_array vs
ingest.from_jax_array) and to the conversion method (tensor.numpy() vs
np.asarray(arr)), then replace the duplicate test functions (the six named
tests) with the parameterized versions that branch using those mapped symbols
(OMEArrow, ingest.from_* and conversion) so behavior is identical while removing
duplicated code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b43d19b5-3126-4edc-93da-fc53bba92fb5
📒 Files selected for processing (4)
- README.md
- docs/src/dlpack.md
- src/ome_arrow/ingest.py
- tests/test_core.py
✅ Files skipped from review due to trivial changes (2)
- docs/src/dlpack.md
- README.md
```text
Args:
    data: Input source or record payload.
    dim_order: Axis labels used only for array/tensor ingest
        (NumPy, torch, JAX). Ignored inputs are rejected to prevent
```
Are these the only arrays/tensors that can use dim_order? If so, I think you should mention the arrays that can use dim_order in the init function docstring above, except in the numpy array comment (line 114)
Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ome_arrow/core.py`:
- Around line 690-693: The view method currently allows unsupported render modes
to fall through and return None; update the OMEArrow.view implementation to
validate the how parameter and raise a ValueError for unsupported values instead
of returning None — add a final raise ValueError (with a clear message including
the invalid how) after the existing "pyvista" branch (and/or check against the
allowed modes at the top of the method) so calling OMEArrow(...).view(how="foo")
raises the documented ValueError.
- Around line 96-100: The constructor signature change shifted positional args
so calls like OMEArrow(arr, (0,0,1)) now bind the tuple to dim_order instead of
tcz; restore the original positional ABI by ensuring tcz remains the second
positional parameter and make dim_order keyword-only. Update __init__ so tcz
appears before dim_order (e.g. def __init__(self, data, tcz=(0,0,0), *,
dim_order=None, ...)) or otherwise place a bare * before dim_order; adjust
references inside __init__ accordingly (symbols: __init__, tcz, dim_order,
OMEArrow) so existing positional call sites keep working.
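The positional-compatibility fix above can be sketched with a minimal stand-in class (not the real `OMEArrow`; the extra constructor parameters are omitted here). A bare `*` makes `dim_order` keyword-only while `tcz` keeps its original positional slot:

```python
class OMEArrowSketch:
    """Stand-in showing tcz kept positional and dim_order keyword-only."""

    def __init__(self, data, tcz=(0, 0, 0), *, dim_order=None):
        # tcz remains the second positional parameter, so pre-existing
        # call sites like OMEArrowSketch(arr, (0, 0, 1)) still bind correctly.
        self.data = data
        self.tcz = tcz
        self.dim_order = dim_order
```

With this shape, `OMEArrowSketch(arr, (0, 0, 1))` binds the tuple to `tcz` as before, and passing `dim_order` positionally raises a `TypeError` instead of silently shifting arguments.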
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9921484b-9d60-481a-8577-16fc76b742aa
📒 Files selected for processing (4)
- README.md
- docs/src/dlpack.md
- src/ome_arrow/core.py
- tests/test_core.py
✅ Files skipped from review due to trivial changes (2)
- README.md
- docs/src/dlpack.md
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/test_core.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ome_arrow/core.py`:
- Around line 687-689: The docstring for view() incorrectly states it returns a
bare matplotlib.figure.Figure while the matplotlib branch returns whatever
view_matplotlib(...) yields (which tests index with [0]); update the Returns
section to accurately describe the actual object returned by the matplotlib
branch (e.g., a tuple like (matplotlib.figure.Figure, matplotlib.axes.Axes) or
the specific object signature returned by view_matplotlib), referencing view()
and view_matplotlib(...) so callers/tests know to index [0], or alternatively
modify view() to unwrap view_matplotlib(...) and return only the Figure—choose
one consistent approach and update the docstring and any callers/tests
accordingly.
- Around line 767-769: Validate the 'how' argument before triggering
materialization: inside the method that calls self._ensure_materialized() (the
view/display method that accepts the how parameter), add an early check that how
is one of the allowed values ('matplotlib' or 'pyvista') and raise the
ValueError there if not, so the check runs before calling
self._ensure_materialized(); keep the existing error message text but perform
this validation prior to invoking self._ensure_materialized().
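Validating `how` before materialization might look like the stand-in sketch below (a hypothetical class; the real `view()` then dispatches to the matplotlib/pyvista backends after the check):

```python
class ViewerSketch:
    """Stand-in illustrating early validation of the `how` argument."""

    _ALLOWED = ("matplotlib", "pyvista")

    def __init__(self):
        self.materialized = False

    def _ensure_materialized(self):
        # In the real class this step may be expensive (decoding buffers).
        self.materialized = True

    def view(self, how="matplotlib"):
        # Reject unsupported modes *before* paying the materialization cost.
        if how not in self._ALLOWED:
            raise ValueError(
                f"unsupported how={how!r}; expected one of {self._ALLOWED}"
            )
        self._ensure_materialized()
        return how
```

Because the check precedes `_ensure_materialized()`, a bad mode fails fast without touching any lazily loaded data.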
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 48747856-e5f2-4546-88e0-31535ca2a1fc
📒 Files selected for processing (2)
- src/ome_arrow/core.py
- tests/test_core.py
…e-arrow into from-torch-or-jax-array
Thank you @gwaybio and @MattsonCam for your reviews! After addressing these points, I'll now merge this in.
Description
This PR enables OMEArrow to ingest torch or jax arrays directly. It was inspired by comments from @MattsonCam.
What kind of change(s) are included?
Checklist
Please ensure that all boxes are checked before indicating that this pull request is ready for review.
Summary by CodeRabbit
New Features
Documentation
Performance
Tests
Chores