from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np

from motac.spatial.crs import LonLatToXY
from motac.spatial.types import Grid
[docs]
@dataclass(frozen=True, slots=True)
class GridCellLookup:
"""Fast lon/lat -> cell_id lookup for regular grids.
Notes
-----
- Assumes the grid was generated by :func:`motac.spatial.grid_builder.build_regular_grid`.
- Cell ids follow the ravel order used by ``np.meshgrid(xs, ys)`` and ``ravel()``:
x varies fastest (row-major), so ``cell_id = iy * nx + ix``.
- Returns ``-1`` for points outside the grid extent.
"""
tf: LonLatToXY
x0_edge: float
y0_edge: float
nx: int
ny: int
cell_size_m: float
[docs]
@classmethod
def from_grid(cls, grid: Grid) -> GridCellLookup:
n_cells = int(np.asarray(grid.lon).shape[0])
if n_cells < 2:
raise ValueError("grid must contain at least 2 cells")
lon0 = float(np.mean(grid.lon))
lat0 = float(np.mean(grid.lat))
tf = LonLatToXY.for_lonlat(lon0, lat0)
x, y = tf.to_xy.transform(grid.lon, grid.lat)
x = np.asarray(x, dtype=float)
y = np.asarray(y, dtype=float)
if x.shape != grid.lon.shape or y.shape != grid.lat.shape:
raise ValueError("grid.lon/grid.lat must be 1D arrays of equal length")
cs = float(grid.cell_size_m)
if not np.isfinite(cs) or cs <= 0:
raise ValueError("grid.cell_size_m must be finite and > 0")
xmin = float(np.min(x))
ymin = float(np.min(y))
xmax = float(np.max(x))
ymax = float(np.max(y))
# Estimate rectangle shape from the full extent, not just max recovered indices.
# This is robust to the small translation you get by picking a different
# lon/lat reference for the local projection.
nx = int(np.rint((xmax - xmin) / cs)) + 1
ny = int(np.rint((ymax - ymin) / cs)) + 1
# Recover integer (ix, iy) indices for each centroid.
ix = np.rint((x - xmin) / cs).astype(int)
iy = np.rint((y - ymin) / cs).astype(int)
if np.any(ix < 0) or np.any(iy < 0):
raise ValueError("grid appears to be non-regular or contains invalid coordinates")
n_expected = nx * ny
if n_expected != int(grid.lon.shape[0]):
raise ValueError(
"grid does not look like a full regular rectangle: "
f"nx*ny={n_expected} != n_cells={int(grid.lon.shape[0])}"
)
# Ensure bijection between recovered ids and [0, n_cells).
cid = iy * nx + ix
if (
np.unique(cid).shape[0] != cid.shape[0]
or cid.min() != 0
or cid.max() != cid.shape[0] - 1
):
raise ValueError("grid does not map cleanly to a regular (nx, ny) indexing")
x0_edge = xmin - 0.5 * cs
y0_edge = ymin - 0.5 * cs
return cls(
tf=tf,
x0_edge=float(x0_edge),
y0_edge=float(y0_edge),
nx=nx,
ny=ny,
cell_size_m=cs,
)
[docs]
def lonlat_to_cell_id(
self, lon: float | np.ndarray, lat: float | np.ndarray
) -> int | np.ndarray:
"""Map lon/lat to cell id(s).
Parameters
----------
lon, lat:
Scalars or numpy arrays of equal shape.
Returns
-------
int or np.ndarray:
``-1`` (or array with ``-1``) for points outside the grid.
"""
x, y = self.tf.to_xy.transform(lon, lat)
x = np.asarray(x, dtype=float)
y = np.asarray(y, dtype=float)
if x.shape != y.shape:
raise ValueError("lon and lat must have the same shape")
# Boundary convention:
# - left/bottom edges inclusive
# - right/top edges exclusive
# We compute `inside` in the continuous projected space to avoid
# misclassifying points infinitesimally outside the max edges due to
# floating point effects in the floor/index computation.
x1_edge = self.x0_edge + self.nx * self.cell_size_m
y1_edge = self.y0_edge + self.ny * self.cell_size_m
inside = (
np.isfinite(x)
& np.isfinite(y)
& (x >= self.x0_edge)
& (x < x1_edge)
& (y >= self.y0_edge)
& (y < y1_edge)
)
ix = np.floor((x - self.x0_edge) / self.cell_size_m).astype(int)
iy = np.floor((y - self.y0_edge) / self.cell_size_m).astype(int)
cid = iy * self.nx + ix
if cid.ndim == 0:
return int(cid) if bool(inside) else -1
out = np.full(cid.shape, -1, dtype=int)
out[inside] = cid[inside]
return out
def lonlat_to_cell_id(
    grid: Grid, *, lon: float | np.ndarray, lat: float | np.ndarray
) -> int | np.ndarray:
    """Convenience wrapper: build a lookup from ``grid`` and map lon/lat -> cell_id.

    This rebuilds a :class:`GridCellLookup` (projection plus grid validation) on every
    call. For repeated queries against the same grid, build one with
    ``GridCellLookup.from_grid(grid)`` and reuse it.

    Parameters
    ----------
    grid:
        Regular grid accepted by ``GridCellLookup.from_grid``.
    lon, lat:
        Scalars or numpy arrays of equal shape (keyword-only).

    Returns
    -------
    int or np.ndarray:
        Cell id(s), with ``-1`` for points outside the grid.

    Raises
    ------
    ValueError
        Propagated from ``GridCellLookup.from_grid`` or ``lonlat_to_cell_id`` for
        invalid grids or mismatched input shapes.
    """
    lookup = GridCellLookup.from_grid(grid)
    return lookup.lonlat_to_cell_id(lon=lon, lat=lat)