Add torch/jax array ingest support to OMEArrow#56

Merged
d33bs merged 20 commits into WayScience:main from d33bs:from-torch-or-jax-array
Apr 8, 2026

Conversation

@d33bs
Member

@d33bs d33bs commented Apr 6, 2026

Description

This PR enables OMEArrow to ingest torch or jax arrays directly. It was inspired by comments made by @MattsonCam.

What kind of change(s) are included?

  • Documentation (changes docs or other related content)
  • Bug fix (fixes an issue).
  • Enhancement (adds functionality).
  • Breaking change (these changes would cause existing functionality to not work as expected).

Checklist

Please ensure that all boxes are checked before indicating that this pull request is ready for review.

  • I have read and followed the CONTRIBUTING.md guidelines.
  • I have searched for existing content to ensure this is not a duplicate.
  • I have performed a self-review of these additions (including spelling, grammar, and related).
  • These changes pass all pre-commit checks.
  • I have added comments to my code to help provide understanding.
  • I have added a test which covers the code changes found within this PR.
  • I have deleted all non-relevant text in this pull request template.

Summary by CodeRabbit

  • New Features

    • Direct ingestion of PyTorch tensors and JAX arrays with rank‑based default dim_order and an explicit dim_order override; new helper ingest APIs for torch/jax
  • Documentation

    • README and docs expanded with examples, usage notes, batch/time guidance, and optional backend install instructions
  • Performance

    • More efficient buffering: flattened NumPy pixel buffers used instead of materializing per‑plane Python lists
  • Tests

    • Added tests covering Torch/JAX ingestion and dim_order behaviors
  • Chores

    • Tightened optional JAX version constraint in project metadata
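
The Performance bullet above describes replacing per-plane Python lists with flattened NumPy buffers. A minimal sketch of the idea (not the project's actual code; `plane_to_buffer` is a hypothetical helper name):

```python
import numpy as np

def plane_to_buffer(plane: np.ndarray) -> np.ndarray:
    """Flatten one 2D plane into a contiguous 1D NumPy buffer.

    Hypothetical helper; the point is avoiding plane.tolist(),
    which materializes one Python object per pixel.
    """
    return np.ascontiguousarray(plane).ravel()

plane = np.arange(6, dtype=np.uint16).reshape(2, 3)
buf = plane_to_buffer(plane)

assert buf.shape == (6,)
# Same pixel values as the list-of-lists form, without the Python objects.
assert buf.tolist() == [v for row in plane.tolist() for v in row]
```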

Copilot AI review requested due to automatic review settings April 6, 2026 21:05
@coderabbitai

coderabbitai bot commented Apr 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

Adds direct PyTorch/JAX tensor ingest: public from_torch_array / from_jax_array entrypoints; OMEArrow.__init__ accepts dim_order and detects torch/jax inputs; tensors are normalized to host NumPy with rank-based dim_order inference or an explicit override; ingestion uses flattened NumPy buffers (no per-plane Python lists). Docs, tests, and exports are updated accordingly.

Changes

Cohort / File(s) Summary
Documentation
README.md, docs/src/dlpack.md, docs/src/python-api.md, docs/src/index.md
Documented direct tensor ingestion, rank-based dim_order defaults and overrides, examples for PyTorch/JAX, added Sphinx automodule blocks and TOC entry.
Core API
src/ome_arrow/core.py
OMEArrow.__init__ adds `dim_order: str | None = None`, detects torch/jax inputs, and routes them through the new ingest helpers.
Ingest helpers
src/ome_arrow/ingest.py
Switched plane/chunk payloads from Python lists to flattened NumPy arrays, changed default empty-planes to np.zeros(...), added _is_torch_array, _is_jax_array, _infer_dim_order_for_tensor_rank, _from_array_via_numpy, and public from_torch_array(...) / from_jax_array(...) that normalize/convert to host NumPy, infer/apply dim_order, and delegate to from_numpy.
Module exports
src/ome_arrow/__init__.py
Re-exported from_torch_array and from_jax_array.
Tests
tests/test_core.py
Added backend-parameterized tests for Torch/JAX ingestion, dim_order inference and explicit overrides, export shape checks, negative test for dim_order on non-array sources, and view(how="foo") error test.
Build config
pyproject.toml
Tightened optional dependency spec: jax>=0.4.1 (was >=0.4).
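
The rank-based dim_order inference listed for the ingest helpers can be sketched as follows. This is a hypothetical reimplementation: only the rank-3 default ("ZYX") is confirmed by the PR's tests, where a (2, 3, 4) tensor exports with shape (1, 1, 2, 3, 4); the other entries are assumptions.

```python
def infer_dim_order_for_rank(rank: int) -> str:
    """Map tensor rank to a default dim_order using OME's TCZYX axis labels.

    Hypothetical sketch: only rank 3 -> "ZYX" is confirmed by the PR's
    tests; the rank-2, rank-4, and rank-5 defaults are assumptions.
    """
    defaults = {2: "YX", 3: "ZYX", 4: "CZYX", 5: "TCZYX"}
    if rank not in defaults:
        raise ValueError(f"unsupported tensor rank: {rank}")
    return defaults[rank]

assert infer_dim_order_for_rank(3) == "ZYX"
```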

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant OME as OMEArrow.__init__
    participant Detect as TypeDetection
    participant Torch as from_torch_array
    participant JAX as from_jax_array
    participant NumPy as from_numpy
    participant Result as OMEArrowInstance

    User->>OME: OMEArrow(input_tensor, dim_order?)
    OME->>Detect: inspect input type
    alt torch.Tensor
        Detect->>Torch: route to from_torch_array
        Torch->>Torch: normalize (dense, CPU), detach, convert to NumPy
        Torch->>Torch: infer dim_order if None
        Torch->>NumPy: delegate to from_numpy(...)
    else jax.Array
        Detect->>JAX: route to from_jax_array
        JAX->>JAX: convert to host NumPy
        JAX->>JAX: infer dim_order if None
        JAX->>NumPy: delegate to from_numpy(...)
    else other
        Detect->>NumPy: existing from_numpy/from_source flow
    end
    NumPy->>Result: build flattened NumPy buffers, return StructScalar / OMEArrow instance
    Result->>User: constructed OMEArrow object

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • MikeLippincott

Poem

🐇 I hopped from GPU to CPU with glee,
Torch and JAX brought shapes to me,
Rank by rank I mapped each core,
From tensors to Arrow I hopped once more,
Now pixels hum and buffers leap — whee!

🚥 Pre-merge checks — ✅ 3 passed

  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'Add torch/jax array ingest support to OMEArrow' directly summarizes the main feature addition across the changeset: enabling direct ingestion of torch and JAX arrays into OMEArrow with rank-based dimension order inference.
  • Docstring Coverage — ✅ Passed: docstring coverage is 100.00%, above the required threshold of 80.00%.



Copilot AI left a comment


Pull request overview

This PR adds first-class ingestion of PyTorch tensors and JAX arrays into the OMEArrow workflow, allowing users to construct OME-Arrow records directly from in-memory deep learning tensors (with optional dim_order overrides), and documents the new functionality.

Changes:

  • Add from_torch_array / from_jax_array ingest helpers plus tensor-rank-based default dim_order inference.
  • Extend OMEArrow(...) constructor to accept torch/jax inputs and an optional dim_order parameter (restricted to array/tensor inputs).
  • Add tests and documentation covering tensor ingest and the new APIs.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
src/ome_arrow/ingest.py Adds torch/jax ingest helpers, backend detection, and rank-based default dim_order inference.
src/ome_arrow/core.py Extends OMEArrow constructor to accept torch/jax inputs and optional dim_order.
src/ome_arrow/__init__.py Re-exports new ingest helpers at package top-level.
tests/test_core.py Adds unit tests for torch/jax constructor and explicit dim_order behavior.
README.md Documents direct tensor ingest and helper-based ingest.
docs/src/dlpack.md Adds examples for direct constructor ingest for torch/jax.
docs/src/python-api.md Adds autodoc sections for core and ingest modules.
Comments suppressed due to low confidence (1)

src/ome_arrow/core.py:105

  • The constructor now supports torch.Tensor and jax.Array, but the data type annotation still only advertises str | dict | pa.StructScalar | np.ndarray. This makes the public API signature inconsistent with the documented/implemented behavior and hurts static type checking. Consider broadening the annotation (e.g., to include Any/protocols for torch/jax) while keeping optional deps out of import-time typing.
    def __init__(
        self,
        data: str | dict | pa.StructScalar | "np.ndarray",
        dim_order: str | None = None,
        tcz: Tuple[int, int, int] = (0, 0, 0),
        column_name: str = "ome_arrow",
        row_index: int = 0,
        image_type: str | None = None,
        lazy: bool = False,
    ) -> None:
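
One way to follow this suggestion while keeping the optional backends out of runtime imports is the `from __future__ import annotations` plus TYPE_CHECKING pattern. A sketch (simplified signature; `pa.StructScalar` omitted and `OMEArrowAnnotated` used as a stand-in name to stay self-contained):

```python
from __future__ import annotations  # all annotations become lazy strings

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:  # type checkers see the backends; runtime never imports them
    import jax
    import torch

class OMEArrowAnnotated:
    """Sketch of a broadened annotation without import-time torch/jax deps."""

    def __init__(
        self,
        data: str | dict | np.ndarray | torch.Tensor | jax.Array,
        dim_order: str | None = None,
    ) -> None:
        self.data = data
        self.dim_order = dim_order

# Works at runtime even if torch/jax are absent: annotations stay unevaluated.
obj = OMEArrowAnnotated(np.zeros((2, 3)), dim_order="YX")
assert obj.dim_order == "YX"
```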


Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/ingest.py Outdated
Comment thread docs/src/python-api.md

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ome_arrow/ingest.py`:
- Around line 929-1003: The project requires JAX >=0.4.1 because from_jax_array
uses the jax.Array type; update the dependency specification from "jax>=0.4" to
"jax>=0.4.1" in the project dependency manifests (e.g., pyproject.toml /
requirements files) so that callers installing the package get the minimum JAX
version that provides jax.Array referenced in from_jax_array; no code changes to
the from_jax_array function itself are needed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 79fb9752-2635-42d2-bd41-4f7b23196c04

📥 Commits

Reviewing files that changed from the base of the PR and between d0c1ce0 and a732343.

📒 Files selected for processing (7)
  • README.md
  • docs/src/dlpack.md
  • docs/src/python-api.md
  • src/ome_arrow/__init__.py
  • src/ome_arrow/core.py
  • src/ome_arrow/ingest.py
  • tests/test_core.py

Comment thread src/ome_arrow/ingest.py
@d33bs d33bs requested a review from MattsonCam April 6, 2026 21:25
Comment thread docs/src/dlpack.md Outdated
Comment thread docs/src/dlpack.md Outdated
Comment thread src/ome_arrow/ingest.py Outdated
Comment thread src/ome_arrow/ingest.py
Comment thread src/ome_arrow/ingest.py
Comment thread README.md Outdated
Comment thread README.md Outdated
d33bs and others added 4 commits April 7, 2026 06:10
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/test_core.py (1)

398-469: Optional: parameterize Torch/JAX parity tests to reduce duplication.

The torch/jax test pairs are structurally identical. A backend-parameterized helper would keep this easier to extend and maintain.

♻️ Example refactor sketch
+@pytest.mark.parametrize("backend", ["torch", "jax"])
+def test_constructor_accepts_array_backend(backend: str) -> None:
+    if backend == "torch":
+        torch = pytest.importorskip("torch")
+        arr = torch.arange(2 * 3 * 4).reshape(2, 3, 4).to(dtype=torch.uint16)
+        as_numpy = arr.numpy()
+    else:
+        jnp = pytest.importorskip("jax.numpy")
+        arr = jnp.arange(2 * 3 * 4, dtype=jnp.uint16).reshape(2, 3, 4)
+        as_numpy = np.asarray(arr)
+
+    oa = OMEArrow(arr)
+    exported = oa.export(how="numpy")
+    assert exported.shape == (1, 1, 2, 3, 4)
+    np.testing.assert_array_equal(exported[0, 0], as_numpy)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_core.py` around lines 398 - 469, Several test pairs
(test_constructor_accepts_torch_tensor / test_constructor_accepts_jax_array,
test_from_torch_array_explicit_dim_order /
test_from_jax_array_explicit_dim_order,
test_constructor_dim_order_override_torch_tensor /
test_constructor_dim_order_override_jax_array) are duplicate; parameterize them
over backend to remove duplication. Create a single parameterized test using
pytest.mark.parametrize with a backend param ("torch" or "jax"), inside the test
call pytest.importorskip for the backend, map backend to the appropriate
constructors and helpers (e.g., OMEArrow, ingest.from_torch_array vs
ingest.from_jax_array) and to the conversion method (tensor.numpy() vs
np.asarray(arr)), then replace the duplicate test functions (the six named
tests) with the parameterized versions that branch using those mapped symbols
(OMEArrow, ingest.from_* and conversion) so behavior is identical while removing
duplicated code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b43d19b5-3126-4edc-93da-fc53bba92fb5

📥 Commits

Reviewing files that changed from the base of the PR and between 04a93cf and e1f55c8.

📒 Files selected for processing (4)
  • README.md
  • docs/src/dlpack.md
  • src/ome_arrow/ingest.py
  • tests/test_core.py
✅ Files skipped from review due to trivial changes (2)
  • docs/src/dlpack.md
  • README.md

Member

@MattsonCam MattsonCam left a comment


LGTM @d33bs , good job!

Comment thread docs/src/dlpack.md Outdated
Comment thread src/ome_arrow/core.py Outdated
Quoted docstring from src/ome_arrow/core.py (truncated in the review thread):

    Args:
        data: Input source or record payload.
        dim_order: Axis labels used only for array/tensor ingest
            (NumPy, torch, JAX). Ignored inputs are rejected to prevent
Member


Are these the only arrays/tensors that can use dim_order? If so, I think you could mention which arrays can use dim_order in the __init__ docstring above; currently it is noted everywhere except the NumPy array comment (line 114).

Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/core.py
Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/ingest.py
Comment thread src/ome_arrow/ingest.py
Comment thread src/ome_arrow/core.py
Comment thread src/ome_arrow/core.py
Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/core.py Outdated
Comment thread README.md Outdated
Comment thread docs/src/dlpack.md
d33bs and others added 8 commits April 7, 2026 10:59
Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com>
Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com>
Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>
Co-Authored-By: Gregory Way <gregory.way@gmail.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ome_arrow/core.py`:
- Around line 690-693: The view method currently allows unsupported render modes
to fall through and return None; update the OMEArrow.view implementation to
validate the how parameter and raise a ValueError for unsupported values instead
of returning None — add a final raise ValueError (with a clear message including
the invalid how) after the existing "pyvista" branch (and/or check against the
allowed modes at the top of the method) so calling OMEArrow(...).view(how="foo")
raises the documented ValueError.
- Around line 96-100: The constructor signature change shifted positional args
so calls like OMEArrow(arr, (0,0,1)) now bind the tuple to dim_order instead of
tcz; restore the original positional ABI by ensuring tcz remains the second
positional parameter and make dim_order keyword-only. Update __init__ so tcz
appears before dim_order (e.g. def __init__(self, data, tcz=(0,0,0), *,
dim_order=None, ...)) or otherwise place a bare * before dim_order; adjust
references inside __init__ accordingly (symbols: __init__, tcz, dim_order,
OMEArrow) so existing positional call sites keep working.
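
The positional-ABI fix described above can be sketched as follows, with `tcz` kept as the second positional parameter and `dim_order` made keyword-only (`OMEArrowSketch` is a stand-in, not the project's class):

```python
class OMEArrowSketch:
    """Signature sketch: tcz stays second positional, dim_order is keyword-only."""

    def __init__(self, data, tcz=(0, 0, 0), *, dim_order=None):
        self.data = data
        self.tcz = tcz
        self.dim_order = dim_order

# Pre-existing positional call sites keep binding the tuple to tcz.
obj = OMEArrowSketch([1, 2], (0, 0, 1))
assert obj.tcz == (0, 0, 1) and obj.dim_order is None

# dim_order can now only be supplied by keyword.
obj2 = OMEArrowSketch([1, 2], dim_order="YX")
assert obj2.dim_order == "YX"
```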

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9921484b-9d60-481a-8577-16fc76b742aa

📥 Commits

Reviewing files that changed from the base of the PR and between e1f55c8 and 9e75d9c.

📒 Files selected for processing (4)
  • README.md
  • docs/src/dlpack.md
  • src/ome_arrow/core.py
  • tests/test_core.py
✅ Files skipped from review due to trivial changes (2)
  • README.md
  • docs/src/dlpack.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_core.py

Comment thread src/ome_arrow/core.py
Comment thread src/ome_arrow/core.py

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ome_arrow/core.py`:
- Around line 687-689: The docstring for view() incorrectly states it returns a
bare matplotlib.figure.Figure while the matplotlib branch returns whatever
view_matplotlib(...) yields (which tests index with [0]); update the Returns
section to accurately describe the actual object returned by the matplotlib
branch (e.g., a tuple like (matplotlib.figure.Figure, matplotlib.axes.Axes) or
the specific object signature returned by view_matplotlib), referencing view()
and view_matplotlib(...) so callers/tests know to index [0], or alternatively
modify view() to unwrap view_matplotlib(...) and return only the Figure—choose
one consistent approach and update the docstring and any callers/tests
accordingly.
- Around line 767-769: Validate the 'how' argument before triggering
materialization: inside the method that calls self._ensure_materialized() (the
view/display method that accepts the how parameter), add an early check that how
is one of the allowed values ('matplotlib' or 'pyvista') and raise the
ValueError there if not, so the check runs before calling
self._ensure_materialized(); keep the existing error message text but perform
this validation prior to invoking self._ensure_materialized().
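
Validating `how` before materialization, as the comment suggests, can look like this sketch (`ViewSketch` and its `_ensure_materialized` are stand-ins mirroring the names in the comment, not the project's implementation):

```python
class ViewSketch:
    """Sketch of validating `how` before any expensive materialization."""

    _ALLOWED = ("matplotlib", "pyvista")

    def __init__(self):
        self.materialized = False

    def _ensure_materialized(self):
        self.materialized = True  # stands in for expensive lazy loading

    def view(self, how: str = "matplotlib"):
        if how not in self._ALLOWED:  # fail fast, before materializing
            raise ValueError(f"Unsupported view mode: {how!r}")
        self._ensure_materialized()
        return how

v = ViewSketch()
try:
    v.view(how="foo")
except ValueError:
    pass
assert v.materialized is False  # nothing was materialized on bad input
```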

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 48747856-e5f2-4546-88e0-31535ca2a1fc

📥 Commits

Reviewing files that changed from the base of the PR and between 9e75d9c and 40d8491.

📒 Files selected for processing (2)
  • src/ome_arrow/core.py
  • tests/test_core.py

Comment thread src/ome_arrow/core.py Outdated
Comment thread src/ome_arrow/core.py
@d33bs
Member Author

d33bs commented Apr 8, 2026

Thank you @gwaybio and @MattsonCam for your reviews! After addressing things I'll now merge this in.

@d33bs d33bs merged commit c33e70e into WayScience:main Apr 8, 2026
13 checks passed
@d33bs d33bs deleted the from-torch-or-jax-array branch April 8, 2026 16:35