Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 127 additions & 23 deletions src/e3sm_quickview/plugins/eam_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,53 @@
except Exception as e:
print(e)
import math
import os
from concurrent.futures import ThreadPoolExecutor

from paraview import print_error
from vtkmodules.util import numpy_support, vtkConstants
from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase

# Number of threads for the projection fan-out. pyproj releases the GIL
# inside Transformer.transform, so chunking the input across threads
# scales nearly linearly (7.4x on 8 threads in our bench). Default is
# max(1, cpu_count - 1) to leave one core for the UI/IO thread; override
# via QV_PROJECTION_THREADS for HPC machines or to pin down for testing.
def _default_projection_threads():
    """Resolve the thread count for the projection fan-out.

    QV_PROJECTION_THREADS overrides the default when it parses as an
    integer (clamped to at least 1). Otherwise fall back to cpu_count - 1,
    leaving one core free for the UI/IO thread.
    """
    override = os.environ.get("QV_PROJECTION_THREADS")
    requested = None
    if override:
        try:
            requested = int(override)
        except ValueError:
            # Unparseable override: ignore it and use the cpu-count default.
            requested = None
    if requested is not None:
        return max(1, requested)
    return max(1, (os.cpu_count() or 2) - 1)


_PROJECTION_THREADS = _default_projection_threads()
# Below this point count the thread-pool overhead outweighs the speedup.
_PROJECTION_THREADING_MIN = 1_000_000


def _threaded_transform(xformer, x, y):
    """Apply ``xformer.transform`` over x, y by chunking across threads.

    Small inputs (or a single-thread configuration) go straight through
    ``xformer.transform``. Large inputs are split into contiguous spans,
    transformed concurrently, and concatenated back in input order.
    """
    total = len(x)
    if total < _PROJECTION_THREADING_MIN or _PROJECTION_THREADS <= 1:
        return xformer.transform(x, y)

    workers = _PROJECTION_THREADS
    step = total // workers
    # The last span absorbs the division remainder so every element is covered.
    spans = [
        (w * step, total if w == workers - 1 else (w + 1) * step)
        for w in range(workers)
    ]

    def transform_span(span):
        lo, hi = span
        return xformer.transform(x[lo:hi], y[lo:hi])

    with ThreadPoolExecutor(max_workers=workers) as pool:
        parts = list(pool.map(transform_span, spans))

    # Executor.map yields results in submission order, so plain
    # concatenation reassembles the coordinate arrays correctly.
    return (
        np.concatenate([p[0] for p in parts]),
        np.concatenate([p[1] for p in parts]),
    )

try:
import warnings

Expand Down Expand Up @@ -62,19 +104,61 @@ def ProcessPoint(point, radius):
return [x, y, z]


# Slice plans keyed on the PedigreeIds array identity. Pedigree permutations
# from vtkTableBasedClipDataSet are long-run-monotonic (typically runs of
# thousands of +1-stepped indices), so we can replace fancy indexing with a
# list of slice copies and reduce the per-tick cost substantially.
_pedigree_slice_plan_cache = {}
# Cap the cache so long sessions with many distinct pedigree arrays do not
# grow it without bound; a plan is cheap to rebuild on a miss.
_PEDIGREE_SLICE_PLAN_CACHE_MAX = 8


def _get_pedigree_slice_plan(pedigree_vtk):
    """Return (starts, ends, pid_np) for the pedigree permutation.

    The plan represents pedigree as a sequence of runs where each run i maps
    output[starts[i]:ends[i]] ← input[pid_np[starts[i]]:pid_np[starts[i]]+len].
    Cached by (id, MTime) of the pedigree VTK array — vtk_to_numpy returns
    a fresh ndarray each call, so keying on ndarray identity would miss.

    Each cache entry keeps a strong reference to ``pedigree_vtk``: without
    it, the array could be garbage collected and a *different* array could
    later be allocated at the same ``id()`` with an equal MTime, silently
    returning a stale plan. Entries beyond the cap are evicted oldest-first
    (dict insertion order).
    """
    key = (id(pedigree_vtk), pedigree_vtk.GetMTime())
    entry = _pedigree_slice_plan_cache.get(key)
    if entry is not None:
        starts, ends, pid_np, _ = entry
        return starts, ends, pid_np

    pid_np = numpy_support.vtk_to_numpy(pedigree_vtk)
    # Runs are maximal stretches where consecutive ids step by exactly +1;
    # every position where the step differs starts a new run.
    diff = np.diff(pid_np.astype(np.int64, copy=False))
    breaks = np.flatnonzero(diff != 1)
    starts = np.empty(len(breaks) + 1, dtype=np.int64)
    starts[0] = 0
    starts[1:] = breaks + 1
    ends = np.empty_like(starts)
    ends[:-1] = starts[1:]
    ends[-1] = len(pid_np)

    # Hold pedigree_vtk in the entry so its id() stays valid for the
    # lifetime of the cached plan.
    _pedigree_slice_plan_cache[key] = (starts, ends, pid_np, pedigree_vtk)
    while len(_pedigree_slice_plan_cache) > _PEDIGREE_SLICE_PLAN_CACHE_MAX:
        _pedigree_slice_plan_cache.pop(next(iter(_pedigree_slice_plan_cache)))
    return starts, ends, pid_np


def add_cell_arrays(inData, outData, cached_output):
"""
Adds arrays not modified in inData to outData.
New arrays (or arrays modified) values are
set using the PedigreeIds because the number of values
in the new array (just read from the file) is different
than the number of values in the arrays already processed through he
pipeline.
New arrays (or arrays modified) values are set using the PedigreeIds
because the number of values in the new array (just read from the file)
is different than the number of values in the arrays already processed
through the pipeline.

The indexed copy is done in-place into a pre-allocated output buffer
using a cached slice plan over the pedigree permutation — roughly 2x
faster than fancy numpy indexing for the clip-induced permutations we
see here.
"""
pedigreeIds = cached_output.cell_data["PedigreeIds"]
if pedigreeIds is None:
print_error("Error: no PedigreeIds array")
return

pedigree_vtk = cached_output.GetCellData().GetArray("PedigreeIds")
starts, ends, pid_np = _get_pedigree_slice_plan(pedigree_vtk)

cached_cell_data = cached_output.GetCellData()
in_cell_data = inData.GetCellData()
outData.ShallowCopy(cached_output)
Expand All @@ -85,19 +169,24 @@ def add_cell_arrays(inData, outData, cached_output):
in_array = in_cell_data.GetArray(i)
cached_array = cached_cell_data.GetArray(in_array.GetName())
if cached_array and cached_array.GetMTime() >= in_array.GetMTime():
# this scalar has been seen before
# simply add a reference in the outData
# This scalar has been seen before — reuse cached copy.
out_cell_data.AddArray(cached_array)
else:
# this scalar is new
# we have to fill in the additional cells resulted from the clip
out_array = in_array.NewInstance()
array0 = cached_cell_data.GetArray(0)
out_array.SetNumberOfComponents(array0.GetNumberOfComponents())
out_array.SetNumberOfTuples(array0.GetNumberOfTuples())
n_comp = array0.GetNumberOfComponents()
n_tuples = array0.GetNumberOfTuples()
out_array = in_array.NewInstance()
out_array.SetNumberOfComponents(n_comp)
out_array.SetNumberOfTuples(n_tuples)
out_array.SetName(in_array.GetName())
out_cell_data.AddArray(out_array)
outData.cell_data[out_array.GetName()] = inData.cell_data[i][pedigreeIds]

in_np = numpy_support.vtk_to_numpy(in_array)
out_np = numpy_support.vtk_to_numpy(out_array)
for s, e in zip(starts, ends):
src_off = int(pid_np[s])
out_np[s:e] = in_np[src_off:src_off + (e - s)]
out_array.Modified()


@smproxy.filter()
Expand Down Expand Up @@ -286,17 +375,26 @@ def __init__(self):
self.project = 0
self.translate = False
self.cached_points = None
# Cache keyed on input-points identity + projection params. Immune to
# spurious upstream Modified() on the shared points.
self._cached_input_points = None
self._cached_key = None

def _invalidate_cache(self):
    """Forget the cached projected points along with the identity key
    and the input-points reference they were computed from."""
    self._cached_key = None
    self._cached_input_points = None
    self.cached_points = None

def SetTranslation(self, translate):
    """Enable/disable longitude translation.

    On an actual change, drop the projected-points cache (the stale
    ``self.cached_points = None`` line is redundant — ``_invalidate_cache``
    already clears it along with the identity key) and mark the filter
    modified so the pipeline re-executes.
    """
    if self.translate != translate:
        self.translate = translate
        self._invalidate_cache()
        self.Modified()

def SetProjection(self, project):
    """Set the projection index (coerced to int).

    On an actual change, drop the projected-points cache (the stale
    ``self.cached_points = None`` line is redundant — ``_invalidate_cache``
    already clears it along with the identity key) and mark the filter
    modified so the pipeline re-executes.
    """
    if self.project != int(project):
        self.project = int(project)
        self._invalidate_cache()
        self.Modified()

def RequestData(self, request, inInfo, outInfo):
Expand All @@ -310,9 +408,9 @@ def RequestData(self, request, inInfo, outInfo):
else:
outData.ShallowCopy(inData)

if self.cached_points and self.cached_points.GetMTime() >= max(
inData.GetPoints().GetMTime(), self.GetMTime()
):
in_points = inData.GetPoints()
cache_key = (id(in_points), self.project, self.translate)
if self.cached_points is not None and self._cached_key == cache_key:
outData.SetPoints(self.cached_points)
else:
# we modify the points, so copy them
Expand All @@ -337,7 +435,7 @@ def RequestData(self, request, inInfo, outInfo):
return 1

xformer = Transformer.from_proj(latlon, proj, always_xy=True)
res = xformer.transform(x, y)
res = _threaded_transform(xformer, x, y)
except Exception as e:
print(f"Projection error: {e}")
# If projection fails, return without modifying coordinates
Expand All @@ -351,6 +449,8 @@ def RequestData(self, request, inInfo, outInfo):
# the previous cached_points, if any, is available for
# garbage collection after this assignment
self.cached_points = out_points_vtk
self._cached_input_points = in_points # hold ref so id() stays valid
self._cached_key = cache_key

return 1

Expand Down Expand Up @@ -472,6 +572,7 @@ def __init__(self):
self.trim_lat = [0, 0]
self.cached_cell_centers = None
self._cached_output = None
self._last_was_trimmed = False

def SetTrimLongitude(self, left, right):
if left < 0 or left > 360 or right < 0 or right > 360 or left > (360 - right):
Expand All @@ -498,10 +599,12 @@ def RequestData(self, request, inInfo, outInfo):
outData = self.GetOutputData(outInfo, 0)
if self.trim_lon == [0, 0] and self.trim_lat == [0, 0]:
outData.ShallowCopy(inData)
# if the filter execution follows an another execution that trims the
# number of points, the downstream filter could think that
# the trimmed points are still valid which results in a crash
outData.GetPoints().Modified()
# Only invalidate the shared points when transitioning *out* of a
# trimmed state — the original code did it unconditionally, which
# defeated EAMProject's cache on every pipeline update.
if self._last_was_trimmed:
outData.GetPoints().Modified()
self._last_was_trimmed = False
return 1

if self.cached_cell_centers and self.cached_cell_centers.GetMTime() >= max(
Expand Down Expand Up @@ -574,6 +677,7 @@ def RequestData(self, request, inInfo, outInfo):

self._cached_output = outData.NewInstance()
self._cached_output.ShallowCopy(outData)
self._last_was_trimmed = True
return 1


Expand Down
Loading