diff --git a/docs/user-guide/remapping.ipynb b/docs/user-guide/remapping.ipynb index d4f6b0770..cd187b2e5 100644 --- a/docs/user-guide/remapping.ipynb +++ b/docs/user-guide/remapping.ipynb @@ -1,8 +1,9 @@ { "cells": [ { - "metadata": {}, "cell_type": "markdown", + "id": "7eec39631eeeb6f8", + "metadata": {}, "source": [ "# Remapping\n", "\n", @@ -15,14 +16,14 @@ "- **Nearest Neighbor** \n", "- **Inverse Distance Weighted**\n", "- **Bilinear**\n" - ], - "id": "7eec39631eeeb6f8" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "bc895fe997d10515", + "metadata": {}, + "outputs": [], "source": [ "import os\n", "import urllib.request\n", @@ -39,8 +40,7 @@ "hv.extension(\"matplotlib\")\n", "\n", "common_kwargs = {\"cmap\": cmocean.cm.deep, \"features\": [\"coastline\"]}" - ], - "id": "bc895fe997d10515" + ] }, { "cell_type": "markdown", @@ -54,8 +54,10 @@ }, { "cell_type": "code", + "execution_count": null, "id": "4d73a380-349d-473d-8e57-10c52102adca", "metadata": {}, + "outputs": [], "source": [ "data_var = \"bottomDepth\"\n", "\n", @@ -80,22 +82,20 @@ "}\n", "uxds_480 = ux.open_dataset(*file_path_dict[\"480km\"])\n", "uxds_120 = ux.open_dataset(*file_path_dict[\"120km\"])" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "da0b1ff8-da1a-4c6c-9031-749b34bfad7a", "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_480[\"bottomDepth\"].plot(title=\"Bottom Depth (480km)\", **common_kwargs)\n", " + uxds_120[\"bottomDepth\"].plot(title=\"Bottom Depth (120km)\", **common_kwargs)\n", ").cols(1).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -107,13 +107,13 @@ }, { "cell_type": "code", + "execution_count": null, "id": "2453895d-b41d-47fe-bc2b-42358a9acbe5", "metadata": {}, + "outputs": [], "source": [ "uxds_120.remap" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -151,20 +151,22 @@ }, { "cell_type": "code", + "execution_count": null, "id": "d4550735-053a-4542-b259-fb7d8c2e6fae", "metadata": {}, + "outputs": [], "source": [ "upsampling = uxds_480[\"bottomDepth\"].remap.nearest_neighbor(\n", " destination_grid=uxds_120.uxgrid, remap_to=\"faces\"\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "cf2cf918-62f8-4aa4-9fa1-122bc06862ce", "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_480[\"bottomDepth\"].plot(title=\"Bottom Depth (480km)\", **common_kwargs)\n", @@ -179,9 +181,7 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -203,20 +203,22 @@ }, { "cell_type": "code", + "execution_count": null, "id": "40094ba0-0dad-48d7-af70-040f088d7be5", "metadata": {}, + "outputs": [], "source": [ "downsampling = uxds_120[\"bottomDepth\"].remap.nearest_neighbor(\n", " destination_grid=uxds_480.uxgrid, remap_to=\"face centers\"\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "65d3f7db-9820-4c46-8d78-571a8ea48ef1", "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_120[\"bottomDepth\"].plot(title=\"Bottom Depth (120km)\", **common_kwargs)\n", @@ -231,9 +233,7 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -280,20 +280,22 @@ }, { "cell_type": "code", + "execution_count": null, "id": "400398d5-5cc0-4790-9a95-cb0f88cc1ca8", "metadata": {}, + "outputs": [], "source": [ "upsampling_idw = uxds_480[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " destination_grid=uxds_120.uxgrid, remap_to=\"faces\"\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "8dd1d8e3-54e2-4710-b132-f69f0ba950fd", "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_480[\"bottomDepth\"].plot(title=\"Bottom Depth (480km)\", **common_kwargs)\n", @@ -310,9 +312,7 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -331,20 +331,22 @@ }, { "cell_type": "code", + "execution_count": null, "id": "7d21c42d-d368-4a34-8dcb-6643f0a0a7d1", "metadata": {}, + "outputs": [], "source": [ "downsampling_idw = uxds_120[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " destination_grid=uxds_480.uxgrid, remap_to=\"faces\"\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "935a7a74-7d88-49ae-bc20-7b18d76795c9", "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_120[\"bottomDepth\"].plot(title=\"Bottom Depth (120km)\", **common_kwargs)\n", @@ -361,9 +363,7 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -383,8 +383,10 @@ }, { "cell_type": "code", + "execution_count": null, "id": "9e31c8ec-75a0-4898-96e7-35a4b4853ad0", "metadata": {}, + "outputs": [], "source": [ "downsampling_idw_low = uxds_120[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " uxds_480.uxgrid, remap_to=\"faces\", power=1, k=2\n", @@ -392,14 +394,14 @@ "downsampling_idw_high = uxds_120[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " uxds_480.uxgrid, remap_to=\"faces\", power=5, k=128\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "88756342-64e0-42f4-96d9-b9822e002bd7", "metadata": {}, + "outputs": [], "source": [ "(\n", " downsampling_idw_low.plot(\n", @@ -415,14 +417,14 @@ " **common_kwargs,\n", " )\n", ").cols(1).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "439ac480-c4f3-4080-a18c-8f670367c194", "metadata": {}, + "outputs": [], "source": [ "upsampling_idw_low = uxds_480[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " uxds_120.uxgrid, remap_to=\"faces\", power=1, k=2\n", @@ -430,14 +432,14 @@ "upsampling_idw_high = uxds_480[\"bottomDepth\"].remap.inverse_distance_weighted(\n", " uxds_120.uxgrid, remap_to=\"faces\", power=5, k=128\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "id": "0dc0497d-13bd-4b5b-9792-ad6ce26aded5", "metadata": {}, + "outputs": [], "source": [ "(\n", " upsampling_idw_low.plot(\n", @@ -453,9 +455,7 @@ " **common_kwargs,\n", " )\n", ").cols(1).opts(fig_size=200)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -471,7 +471,9 @@ "cell_type": "markdown", "id": "f1f33631-19b7-4b73-8452-7dc1e3fa48a2", "metadata": {}, - "source": "## Bilinear" + "source": [ + "## Bilinear" + ] }, { "cell_type": "markdown", @@ -493,23 +495,28 @@ "cell_type": "markdown", "id": "467252cb-9e07-42bd-8734-15666f612387", "metadata": {}, - "source": "### Upsampling" + "source": [ + "### Upsampling" + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "c807a61bbc5bffff", + "metadata": {}, + "outputs": [], "source": [ "upsampling_bl = uxds_480[\"bottomDepth\"].remap.bilinear(\n", " destination_grid=uxds_120.uxgrid, remap_to=\"faces\"\n", ")" - ], - "id": "c807a61bbc5bffff", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "1429cc143fcd6694", + "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_480[\"bottomDepth\"].plot(title=\"Bottom Depth (480km)\", **common_kwargs)\n", @@ -526,32 +533,34 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "id": "1429cc143fcd6694", - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", "id": "9ae38e84", "metadata": {}, - "source": "### Downsampling" + "source": [ + "### Downsampling" + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "c3e12c15307c5042", + "metadata": {}, + "outputs": [], "source": [ "downsampling_bl = uxds_120[\"bottomDepth\"].remap.bilinear(\n", " destination_grid=uxds_480.uxgrid, remap_to=\"faces\"\n", ")" - ], - "id": "c3e12c15307c5042", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "bec1b9d05d0a40ba", + "metadata": {}, + "outputs": [], "source": [ "(\n", " uxds_120[\"bottomDepth\"].plot(title=\"Bottom Depth (120km)\", **common_kwargs)\n", @@ -568,10 +577,141 @@ " **common_kwargs,\n", " )\n", ").cols(2).opts(fig_size=200)" - ], - "id": "bec1b9d05d0a40ba", + ] + }, + { + "cell_type": "markdown", + "id": "938d5532-d93d-4abe-83db-098879b2cd93", + "metadata": {}, + "source": [ + "## Coordinate Handling\n", + "\n", + "The source data that is being remapped may have spatial coordinate variables present, and they need to be handled properly during the remapping operation. It may include swapping of coordinate values and renaming of some of the coordinates with respect to the dimensions of source data and the `remap_to` selection. This logic works as follows:\n", + "\n", + "1. If `remap_to` matches the `source` dimension (e.g. `source` on face centers` and `remap_to=\"faces\"` etc.)\n", + " - Swap values of the spatial coordinates with values of the corresponding coordinates from `destination_grid`.\n", + "\n", + "2. Else (if `remap_to` does not match `source` dimension, e.g. `source` on face centers but `remap_to=\"nodes\"` etc.)\n", + " - Swap values of the spatial coordinates with values of the coordinates from `destination_grid` that are defined on the `remap_to` dimension.\n", + " - Rename these coordinates to reflect the new element type (e.g. 'face_x' → 'node_x')\n", + "\n", + "Let us showcase both of these below:\n", + "\n", + "It would be helpful to recall the data array contents again:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41910c69-e9eb-4b47-b6ee-0250a5e044e7", + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "uxds_120[\"bottomDepth\"]" + ] + }, + { + "cell_type": "markdown", + "id": "d2032ad6-533b-474d-b87d-234e7f64f7c6", + "metadata": {}, + "source": [ + "Since the source data does not have any coordinate variables, let us add some arbitrarily named lat/lon and x/y coords into it and define it as a new data variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e11989a-f768-47ec-b925-d2d8fce94094", + "metadata": {}, + "outputs": [], + "source": [ + "uxda_with_coords = ux.core.UxDataArray(\n", + " data=uxds_120[\"bottomDepth\"],\n", + " uxgrid=uxds_120.uxgrid,\n", + " coords={\n", + " \"Mesh2_face_lat\": uxds_120.uxgrid.face_lat,\n", + " \"Mesh_coord_lon\": uxds_120.uxgrid.face_lon,\n", + " \"Mesh_Faces_x\": uxds_120.uxgrid.face_x,\n", + " \"Mesh_FACES_y\": uxds_120.uxgrid.face_y,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70a87e0a-6349-49c7-b0cb-3812bc8a604f", + "metadata": {}, + "outputs": [], + "source": [ + "uxda_with_coords" + ] + }, + { + "cell_type": "markdown", + "id": "d9dbe0f5-90aa-44c7-933b-099736500605", + "metadata": {}, + "source": [ + "### Source and remapped output dimensions match" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a420a376-afcb-4f4f-8bac-4f18e6aa2bcf", + "metadata": {}, + "outputs": [], + "source": [ + "remapped_same_dims = uxda_with_coords.remap.bilinear(\n", + " destination_grid=uxds_480.uxgrid, remap_to=\"faces\"\n", + ")\n", + "remapped_same_dims" + ] + }, + { + "cell_type": "markdown", + "id": "6eb42d78-c0bc-4d68-bed9-4cae5b185793", + "metadata": {}, + "source": [ + "Note the values of the coordinate variables in the source data have been swapped with the ones from the destination grid, but all the names have been kept as is since this is same-dimension remapping." + ] + }, + { + "cell_type": "markdown", + "id": "5c816670-f17d-481d-b8e1-da29efc75b62", + "metadata": {}, + "source": [ + "### Source and remapped output dimensions DO NOT match" + ] + }, + { + "cell_type": "markdown", + "id": "a10c24e4-1c9c-4612-8ada-f44b32a7df41", + "metadata": {}, + "source": [ + "We are now looking into the case of remapping `face`-centered source data to `node`s from the destination grids." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f584cb5-4710-45e0-882a-0d477b6fcbce", + "metadata": {}, + "outputs": [], + "source": [ + "remapped_different_dims = uxda_with_coords.remap.bilinear(\n", + " destination_grid=uxds_480.uxgrid, remap_to=\"nodes\"\n", + ")\n", + "remapped_different_dims" + ] + }, + { + "cell_type": "markdown", + "id": "2dc34687-ca2d-4f36-8efd-77627cb9a7c2", + "metadata": {}, + "source": [ + "Note the values of the `face`-related coordinate variables in the source data have been swapped with the `node`-related coordinates from the destination grid, and the names of those that had a form of \"face\" in them have been renamed to have \"node\" but the one that that did not have any \"face\" sting has been kept as is." + ] } ], "metadata": { @@ -590,7 +730,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.13.11" } }, "nbformat": 4, diff --git a/test/test_remap.py b/test/test_remap.py index 186faae0d..836fd517e 100644 --- a/test/test_remap.py +++ b/test/test_remap.py @@ -265,3 +265,45 @@ def test_b_quadrilateral(gridpath, datasetpath): out = uxds['var2'].remap.bilinear(destination_grid=dest) assert out.size > 0 + +def test_b_coords_remap_to_faces(gridpath): + """Bilinear remap should change the array when remap_to != source.""" + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(gridpath("ugrid", "geoflow-small", "grid.nc")) + + uxda_with_coords = ux.core.UxDataArray( + data=uxds["latCell"], + uxgrid=uxds.uxgrid, + coords={"Mesh2_face_lat": uxds.uxgrid.face_lat, + "Mesh_Face_lon": uxds.uxgrid.face_lon, + } + ) + + da_remap_b = uxda_with_coords.remap.bilinear( + destination_grid=dest, remap_to="faces" + ) + + assert (da_remap_b.Mesh_Face_lon.size == dest.face_lon.size) + assert np.array_equal(da_remap_b.Mesh_Face_lon.values, dest.face_lon.values) + +def test_b_coords_remap_to_nodes(gridpath): + """Bilinear remap should change the array when remap_to != source.""" + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(gridpath("ugrid", "geoflow-small", "grid.nc")) + + uxda_with_coords = ux.core.UxDataArray( + data=uxds["latCell"], + uxgrid=uxds.uxgrid, + coords={"Mesh2_face_lat": uxds.uxgrid.face_lat, + "Mesh_Face_lon": uxds.uxgrid.face_lon, + } + ) + + da_remap_b = uxda_with_coords.remap.bilinear( + destination_grid=dest, remap_to="nodes" + ) + + assert (da_remap_b.Mesh_Node_lon.size == dest.node_lon.size) + assert np.array_equal(da_remap_b.Mesh_Node_lon.values, dest.node_lon.values) diff --git a/uxarray/conventions/ugrid.py b/uxarray/conventions/ugrid.py index 707c1dd79..f0f19de0b 100644 --- a/uxarray/conventions/ugrid.py +++ b/uxarray/conventions/ugrid.py @@ -84,12 +84,14 @@ "standard_name": "x", "long name": "Cartesian x location of the corner nodes of each face", "units": "meters", + "axis": "X", } NODE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the corner nodes of each face", "units": "meters", + "axis": "Y", } NODE_Z_ATTRS = { @@ -104,12 +106,14 @@ "standard_name": "x", "long name": "Cartesian x location of the center of each edge", "units": "meters", + "axis": "X", } EDGE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the center of each edge", "units": "meters", + "axis": "Y", } EDGE_Z_ATTRS = { @@ -124,12 +128,14 @@ "standard_name": "x", "long name": "Cartesian x location of the center of each face", "units": "meters", + "axis": "X", } FACE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the center of each face", "units": "meters", + "axis": "Y", } FACE_Z_ATTRS = { diff --git a/uxarray/remap/bilinear.py b/uxarray/remap/bilinear.py index 1ebfd0fb2..52261d44b 100644 --- a/uxarray/remap/bilinear.py +++ b/uxarray/remap/bilinear.py @@ -29,7 +29,7 @@ def _bilinear( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", ) -> np.ndarray: """Bilinear Remapping between two grids, mapping data that resides on the corner nodes, edge centers, or face centers on the source grid to the @@ -39,8 +39,8 @@ def _bilinear( --------- source_uxda : UxDataArray Source UxDataArray - remap_to : str, default="nodes" - Location of where to map data, either "nodes", "edge centers", or "face centers" + remap_to : str, default="faces" + Which grid element receives the remapped values, either "nodes", "edges", or "faces" Returns ------- @@ -49,7 +49,7 @@ def _bilinear( """ # ensure array is a np.ndarray - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Ensure the destination grid is normalized destination_grid.normalize_cartesian_coordinates() @@ -70,12 +70,12 @@ def _bilinear( dual = source.uxgrid.get_dual() # get destination coordinate pairs - point_xyz = _prepare_points(destination_grid, destination_dim) + point_xyz = _prepare_points(destination_grid, remap_to) weights, indices = _barycentric_weights( point_xyz=point_xyz, dual=dual, - data_size=getattr(destination_grid, f"n_{KDTREE_DIM_MAP[destination_dim]}"), + data_size=getattr(destination_grid, f"n_{KDTREE_DIM_MAP[remap_to]}"), source_grid=ds.uxgrid, ) @@ -87,8 +87,8 @@ def _bilinear( inds, w = indices, weights # pack indices & weights into tiny DataArrays: - indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[destination_dim], "k"]) - weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[destination_dim], "k"]) + indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[remap_to], "k"]) + weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[remap_to], "k"]) # gather the k neighbor values: da_k = da.isel({source_dim: indexer}, ignore_grid=True) @@ -103,7 +103,7 @@ def _bilinear( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/inverse_distance_weighted.py b/uxarray/remap/inverse_distance_weighted.py index 487cfe09f..e1dd6e14f 100644 --- a/uxarray/remap/inverse_distance_weighted.py +++ b/uxarray/remap/inverse_distance_weighted.py @@ -52,7 +52,7 @@ def _idw_weights(distances, power): def _inverse_distance_weighted_remap( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", power: int = 2, k: int = 8, ): @@ -68,8 +68,8 @@ def _inverse_distance_weighted_remap( The data to be remapped. destination_grid : Grid The UXarray grid instance on which to interpolate data. - destination_dim : {'n_node', 'n_edge', 'n_face'}, default='n_face' - The spatial dimension on `destination_grid` to receive interpolated values. + remap_to : {'nodes', 'edges', 'faces'}, default='faces' + Which grid element receives the remapped values, either "nodes", "edges", or "faces" power : int, default=2 Exponent in the inverse-distance weighting function. Larger values emphasize closer neighbors. @@ -88,9 +88,9 @@ def _inverse_distance_weighted_remap( """ # Fall back onto nearest neighbor if k == 1: - return _nearest_neighbor_remap(source, destination_grid, destination_dim) + return _nearest_neighbor_remap(source, destination_grid, remap_to) - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Perform remapping on a UxDataset ds, is_da, name = _to_dataset(source) @@ -106,7 +106,7 @@ def _inverse_distance_weighted_remap( ds.uxgrid, destination_grid, src_dim, - destination_dim, + remap_to, k=k, return_distances=True, ) @@ -123,8 +123,8 @@ def _inverse_distance_weighted_remap( inds, w = indices_weights_map[source_dim] # pack indices & weights into tiny DataArrays: - indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[destination_dim], "k"]) - weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[destination_dim], "k"]) + indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[remap_to], "k"]) + weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[remap_to], "k"]) # gather the k neighbor values: da_k = da.isel({source_dim: indexer}, ignore_grid=True) @@ -139,7 +139,7 @@ def _inverse_distance_weighted_remap( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/nearest_neighbor.py b/uxarray/remap/nearest_neighbor.py index 7e4fc197e..30fa16c75 100644 --- a/uxarray/remap/nearest_neighbor.py +++ b/uxarray/remap/nearest_neighbor.py @@ -75,7 +75,7 @@ def _nearest_neighbor_query( def _nearest_neighbor_remap( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", ): """ Apply nearest-neighbor remapping from a UXarray object onto another grid. @@ -88,15 +88,15 @@ def _nearest_neighbor_remap( The data array or dataset to be remapped. destination_grid : Grid The UXarray Grid instance to which data will be remapped. - destination_dim : str, default='n_face' - The spatial dimension on the destination grid ('n_node', 'n_edge', 'n_face'). + remap_to : str, default='faces' + Which grid element receives the remapped values, either 'nodes', 'edges', 'faces'). Returns ------- UxDataArray or UxDataset A new UXarray object with data values assigned to `destination_grid`. """ - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Perform remapping on a UxDataset ds, is_da, name = _to_dataset(source) @@ -106,9 +106,7 @@ def _nearest_neighbor_remap( # Build Nearest Neighbor Index Arrays indices_map: dict[str, np.ndarray] = { - src_dim: _nearest_neighbor_query( - ds.uxgrid, destination_grid, src_dim, destination_dim - ) + src_dim: _nearest_neighbor_query(ds.uxgrid, destination_grid, src_dim, remap_to) for src_dim in dims_to_remap } remapped_vars = {} @@ -122,7 +120,7 @@ def _nearest_neighbor_remap( indexer = xr.DataArray( indices, dims=[ - LABEL_TO_COORD[destination_dim], + LABEL_TO_COORD[remap_to], ], ) @@ -134,7 +132,7 @@ def _nearest_neighbor_remap( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/spatial_coords_remap.py b/uxarray/remap/spatial_coords_remap.py new file mode 100644 index 000000000..70f32867f --- /dev/null +++ b/uxarray/remap/spatial_coords_remap.py @@ -0,0 +1,350 @@ +import warnings +from typing import Dict, Literal, Optional, Tuple + +import xarray as xr + +from uxarray.core.dataarray import UxDataArray +from uxarray.grid.grid import Grid + +COORD_TYPES = { + "LON": "lon", + "LAT": "lat", + "CART_X": "X", + "CART_Y": "Y", +} + +# CF attributes that indicate coordinate type +CF_LAT_ATTRS = ["latitude", "projection_y_coordinate"] +CF_LON_ATTRS = ["longitude", "projection_x_coordinate"] + +# CF units that indicate coordinate type +CF_LAT_UNITS = ["degrees_north", "degree_north", "degree_n"] +CF_LON_UNITS = ["degrees_east", "degree_east", "degree_e"] + + +class SpatialCoordsRemapper: + """Ensures remapping spatial coordinates between the source and destination grid for the remapping functions. + It may include swapping of values and renaming of some of the coordinates with respect to the dimensions of + source data and the `remap_to` selection.""" + + def __init__( + self, + source: UxDataArray, + destination_grid: Grid, + remap_to: Literal["nodes", "faces", "edges"], + ): + """ + Initialize spatial coordinate remapper for UXarray's remapping functions. + + Parameters + ---------- + source : UxDataArray + Source data array that is being remapped to the `destination_grid`. + destination_grid : Grid + Destination grid that `source` is being remapped to. + remap_to : str + Which grid element receives the remapped values, either 'nodes', 'faces', or 'edges'. + """ + + if source is None: + raise ValueError( + "`source` must be provided for spatial coordinates remapping." + ) + + if destination_grid is None: + raise ValueError( + "`destination_grid` must be provided for spatial coordinates remapping." + ) + + self.destination_grid = destination_grid + self.source = source + self.remap_to = remap_to + + def _get_destination_grid_coords(self) -> Dict[str, xr.DataArray]: + """ + Get the spatial coordinates of the destination grid corresponding to `remap_to`. + + Returns + ------- + Dict[str, xr.DataArray] + Dictionary with 'lon' and 'lat' coordinate arrays + """ + if self.remap_to == "nodes": + return { + COORD_TYPES["LON"]: self.destination_grid.node_lon, + COORD_TYPES["LAT"]: self.destination_grid.node_lat, + COORD_TYPES["CART_X"]: self.destination_grid.node_x, + COORD_TYPES["CART_Y"]: self.destination_grid.node_y, + } + elif self.remap_to == "faces": + return { + COORD_TYPES["LON"]: self.destination_grid.face_lon, + COORD_TYPES["LAT"]: self.destination_grid.face_lat, + COORD_TYPES["CART_X"]: self.destination_grid.face_x, + COORD_TYPES["CART_Y"]: self.destination_grid.face_y, + } + elif self.remap_to == "edges": + return { + COORD_TYPES["LON"]: self.destination_grid.edge_lon, + COORD_TYPES["LAT"]: self.destination_grid.edge_lat, + COORD_TYPES["CART_X"]: self.destination_grid.edge_x, + COORD_TYPES["CART_Y"]: self.destination_grid.edge_y, + } + else: + raise ValueError( + f"Unknown `remap_to`: {self.remap_to}. Must be either 'nodes', 'faces', or 'edges'." + ) + + def _find_source_coords(self) -> Dict[str, Tuple[str, str]]: + """ + Find spatial coordinate variables in `source` by checking their attributes, units, and axes. + + Returns + ------- + Dict[str, Tuple[str, str]] + Dictionary with keys as spatial identifiers ('lat' or 'lon') and values as + (coord_var_name, standard_name) tuples + + Example output would look like: + { + 'lat': ('Mesh2_face_y', 'latitude'), + 'lon': ('Mesh2_face_x', 'longitude') + } + """ + + source_coords = {} + + # Check all coordinates in `source` + for coord_name in self.source.coords: + coord = self.source.coords[coord_name] + + # Skip if in rare case this coordinate doesn't have dimensions or has multiple dimensions + if not hasattr(coord, "dims") or len(coord.dims) != 1: + continue + + # Determine if this is a spatial coordinate by checking attributes + is_spatial = False + coord_type = None # will be 'lat' or 'lon' later + + if hasattr(coord, "attrs"): + # Check `standard_name` first + if "standard_name" in coord.attrs: + std_name = coord.attrs["standard_name"].lower() + if std_name in CF_LAT_ATTRS: + is_spatial = True + coord_type = COORD_TYPES["LAT"] + elif std_name in CF_LON_ATTRS: + is_spatial = True + coord_type = COORD_TYPES["LON"] + + # Check units if standard_name didn't work + if not is_spatial and "units" in coord.attrs: + units = coord.attrs["units"].lower() + if any(u in units for u in CF_LAT_UNITS): + is_spatial = True + coord_type = COORD_TYPES["LAT"] + elif any(u in units for u in CF_LON_UNITS): + is_spatial = True + coord_type = COORD_TYPES["LON"] + + # Check axis attribute as last chance + if not is_spatial and "axis" in coord.attrs: + axis = coord.attrs["axis"].upper() + if axis == COORD_TYPES["CART_Y"]: + is_spatial = True + coord_type = COORD_TYPES["CART_Y"] + elif axis == COORD_TYPES["CART_X"]: + is_spatial = True + coord_type = COORD_TYPES["CART_X"] + + # If a spatial coord is found and `coord_type` is identified in `source` + if is_spatial and coord_type: + # Store the coordinate variable + standard_name = coord.attrs.get("standard_name", coord_type) + source_coords[coord_type] = (coord_name, standard_name) + + return source_coords + + def _get_element_type_from_dimension(self, dim_name: str) -> Optional[str]: + """ + Determine element type (i.e. 'nodes', 'faces', or 'edges') from dimension name. + + Parameters + ---------- + dim_name : str + Dimension name (e.g., 'n_face', 'nMesh2_face', etc.) + + Returns + ------- + Optional[str] + Element type ('nodes', 'faces', 'edges') or None + """ + dim_lower = dim_name.lower() + if "face" in dim_lower: + return "faces" + elif "node" in dim_lower: + return "nodes" + elif "edge" in dim_lower: + return "edges" + return None + + def _rename_coord_for_new_dimension( + self, original_name: str, old_element: str, new_element: str + ) -> str: + """ + Rename a coordinate variable when changing from one element type to another, which occurs when the `remap_to` + element does not match the `source` dimension. + + Parameters + ---------- + original_name : str + Original coordinate variable name + old_element : str + Old element type ('nodes', 'faces', 'edges') + new_element : str + New element type ('nodes', 'faces', 'edges') + + Returns + ------- + str + New coordinate name with element type updated + """ + # Map plural to singular + element_type_to_coord_name_string = { + "nodes": "node", + "faces": "face", + "edges": "edge", + } + + old_coord_name_string = element_type_to_coord_name_string[old_element] + new_coord_name_string = element_type_to_coord_name_string[new_element] + + # Try to replace the old element name in the coordinate name + # Handle both singular and plural forms + new_name = original_name + + # Case-sensitive replacements + # e.g. "*face*" -> "*node*" + new_name = new_name.replace(old_coord_name_string, new_coord_name_string) + # e.g. "*faces*" -> "*nodes*" + new_name = new_name.replace(old_element, new_element) + # e.g. "*FACE*" -> "*NODE*" + new_name = new_name.replace( + old_coord_name_string.upper(), new_coord_name_string.upper() + ) + # e.g. "*FACES*" -> "*NODES*" + new_name = new_name.replace(old_element.upper(), new_element.upper()) + # e.g. "*Face*" -> "*Node*" + new_name = new_name.replace( + old_coord_name_string.capitalize(), new_coord_name_string.capitalize() + ) + # e.g. "*Faces*" -> "*Nodes*" + new_name = new_name.replace(old_element.capitalize(), new_element.capitalize()) + + return new_name + + def construct_output_coords(self) -> Dict[str, xr.DataArray]: + """ + Construct spatial coordinates for the remapped output by finding spatial coordinate variables, if any, + in `source` and employing a logic as follows: + + Logic: + ------ + If `remap_to` matches the `source` dimension (e.g. `source` on face centers` and `remap_to="faces"` etc.) + - Swap values of spatial coords with values of the corresponding coords from `destination_grid` + + Else (if `remap_to` doesn't match `source` dim (e.g. `source` on face centers but `remap_to="nodes"` etc.)) + - Swap values of spatial coords with values of the coords from `destination_grid` that are + defined on the `remap_to` dimension. + - Rename these coords to reflect new element type (e.g. 'face_x' → 'node_x') + + Returns + ------- + Dict[str, xr.DataArray] + Dictionary mapping output coordinate variables to their new values + """ + + # Find spatial coordinate variables in `source` by checking their attributes + source_coords = self._find_source_coords() + + if not source_coords: + warnings.warn( + "No spatial coordinate variables found in `source`.", + UserWarning, + stacklevel=2, + ) + return {} + + # Get the dimension that `source` is defined on + source_dims = list(self.source.dims) + if len(source_dims) == 0: + raise ValueError("Source data has no dimensions") + + # Find the primary spatial dimension (should be n_face, n_node, or n_edge) + source_spatial_dim = None + for dim in source_dims: + if self._get_element_type_from_dimension(dim) is not None: + source_spatial_dim = dim + break + + if source_spatial_dim is None: + raise ValueError( + f"Could not identify spatial dimension in `source` dims: {source_dims}" + ) + + source_element_type = self._get_element_type_from_dimension(source_spatial_dim) + + # Get destination grid values for the remap_to element + dest_grid_coords = self._get_destination_grid_coords() + + output_coords = {} + + # Logic for the remapped spatial coords construction starts here + # If `remap_to` matches `source` dimension + if source_element_type == self.remap_to: + # Swap coords on matching dimension + for coord_type in COORD_TYPES.values(): + if coord_type in source_coords: + source_coord_name, std_name = source_coords[coord_type] + out_name = source_coord_name + + # Assign destination grid values + output_coords[out_name] = dest_grid_coords[coord_type].variable + + # `remap_to` differs from `source` dimension + else: + warnings.warn( + f"Coordinates handling as part of remapping: `source` has the dimension:" + f"('{source_spatial_dim}') but is being remapped to ('{self.remap_to}'). Therefore, " + f"coordinate values will be swapped to the '{self.remap_to}' coordinates from " + f"`destination_grid` and renamed accordingly.", + UserWarning, + stacklevel=2, + ) + + renamed_coords = [] + + # Swap and rename (as needed) coords from source dimension + for coord_type in COORD_TYPES.values(): + if coord_type in source_coords: + source_coord_name, std_name = source_coords[coord_type] + + # Rename to reflect new element type + out_name = self._rename_coord_for_new_dimension( + source_coord_name, source_element_type, self.remap_to + ) + if out_name != source_coord_name: + renamed_coords.append((source_coord_name, out_name)) + + # Assign destination grid values on remap_to dimension + output_coords[out_name] = dest_grid_coords[coord_type].variable + + if renamed_coords: + for old, new in renamed_coords: + warnings.warn( + f"Renamed coordinate '{old}' → '{new}' due to dimension change.", + UserWarning, + stacklevel=2, + ) + + return output_coords diff --git a/uxarray/remap/utils.py b/uxarray/remap/utils.py index c60a9c517..cefcb606f 100644 --- a/uxarray/remap/utils.py +++ b/uxarray/remap/utils.py @@ -1,5 +1,3 @@ -from copy import deepcopy - import numpy as np import uxarray.core.dataset @@ -57,7 +55,7 @@ def _assert_dimension(dim): raise ValueError(f"Invalid spatial dimension: {dim!r}") -def _construct_remapped_ds(source, remapped_vars, destination_grid, destination_dim): +def _construct_remapped_ds(source, remapped_vars, destination_grid, remap_to): """ Construct a new UxDataset from remapped data variables and updated coordinates. @@ -69,22 +67,29 @@ def _construct_remapped_ds(source, remapped_vars, destination_grid, destination_ Mapping of variable names to their remapped DataArrays. destination_grid : Grid The UXarray grid instance representing the new topology. - destination_dim : str - The spatial dimension name (e.g., 'n_face') for the destination grid. + remap_to : str + Which grid element receives the remapped values, either "nodes", "edges", or "faces" Returns ------- UxDataset A new dataset containing only the remapped variables and retained coordinates. """ - destination_coords = deepcopy(source.coords) - if destination_dim in destination_coords: - del destination_coords[destination_dim] + + from uxarray.remap.spatial_coords_remap import SpatialCoordsRemapper + + # Ensure handling of spatial coordinates between `source` and `destination_grid` for the remapped output + # with respect to the source dimension and `remap_to` selection. See the class definition and functions + # for detailed information + coords_remapper = SpatialCoordsRemapper( + source=source, destination_grid=destination_grid, remap_to=remap_to + ) + output_coords = coords_remapper.construct_output_coords() ds_remapped = uxarray.core.dataset.UxDataset( data_vars=remapped_vars, uxgrid=destination_grid, - coords=destination_coords, + coords=output_coords, ) return ds_remapped