Source code for motac.spatial.lookup

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from motac.spatial.crs import LonLatToXY
from motac.spatial.types import Grid


[docs] @dataclass(frozen=True, slots=True) class GridCellLookup: """Fast lon/lat -> cell_id lookup for regular grids. Notes ----- - Assumes the grid was generated by :func:`motac.spatial.grid_builder.build_regular_grid`. - Cell ids follow the ravel order used by ``np.meshgrid(xs, ys)`` and ``ravel()``: x varies fastest (row-major), so ``cell_id = iy * nx + ix``. - Returns ``-1`` for points outside the grid extent. """ tf: LonLatToXY x0_edge: float y0_edge: float nx: int ny: int cell_size_m: float
[docs] @classmethod def from_grid(cls, grid: Grid) -> GridCellLookup: n_cells = int(np.asarray(grid.lon).shape[0]) if n_cells < 2: raise ValueError("grid must contain at least 2 cells") lon0 = float(np.mean(grid.lon)) lat0 = float(np.mean(grid.lat)) tf = LonLatToXY.for_lonlat(lon0, lat0) x, y = tf.to_xy.transform(grid.lon, grid.lat) x = np.asarray(x, dtype=float) y = np.asarray(y, dtype=float) if x.shape != grid.lon.shape or y.shape != grid.lat.shape: raise ValueError("grid.lon/grid.lat must be 1D arrays of equal length") cs = float(grid.cell_size_m) if not np.isfinite(cs) or cs <= 0: raise ValueError("grid.cell_size_m must be finite and > 0") xmin = float(np.min(x)) ymin = float(np.min(y)) xmax = float(np.max(x)) ymax = float(np.max(y)) # Estimate rectangle shape from the full extent, not just max recovered indices. # This is robust to the small translation you get by picking a different # lon/lat reference for the local projection. nx = int(np.rint((xmax - xmin) / cs)) + 1 ny = int(np.rint((ymax - ymin) / cs)) + 1 # Recover integer (ix, iy) indices for each centroid. ix = np.rint((x - xmin) / cs).astype(int) iy = np.rint((y - ymin) / cs).astype(int) if np.any(ix < 0) or np.any(iy < 0): raise ValueError("grid appears to be non-regular or contains invalid coordinates") n_expected = nx * ny if n_expected != int(grid.lon.shape[0]): raise ValueError( "grid does not look like a full regular rectangle: " f"nx*ny={n_expected} != n_cells={int(grid.lon.shape[0])}" ) # Ensure bijection between recovered ids and [0, n_cells). cid = iy * nx + ix if ( np.unique(cid).shape[0] != cid.shape[0] or cid.min() != 0 or cid.max() != cid.shape[0] - 1 ): raise ValueError("grid does not map cleanly to a regular (nx, ny) indexing") x0_edge = xmin - 0.5 * cs y0_edge = ymin - 0.5 * cs return cls( tf=tf, x0_edge=float(x0_edge), y0_edge=float(y0_edge), nx=nx, ny=ny, cell_size_m=cs, )
[docs] def lonlat_to_cell_id( self, lon: float | np.ndarray, lat: float | np.ndarray ) -> int | np.ndarray: """Map lon/lat to cell id(s). Parameters ---------- lon, lat: Scalars or numpy arrays of equal shape. Returns ------- int or np.ndarray: ``-1`` (or array with ``-1``) for points outside the grid. """ x, y = self.tf.to_xy.transform(lon, lat) x = np.asarray(x, dtype=float) y = np.asarray(y, dtype=float) if x.shape != y.shape: raise ValueError("lon and lat must have the same shape") # Boundary convention: # - left/bottom edges inclusive # - right/top edges exclusive # We compute `inside` in the continuous projected space to avoid # misclassifying points infinitesimally outside the max edges due to # floating point effects in the floor/index computation. x1_edge = self.x0_edge + self.nx * self.cell_size_m y1_edge = self.y0_edge + self.ny * self.cell_size_m inside = ( np.isfinite(x) & np.isfinite(y) & (x >= self.x0_edge) & (x < x1_edge) & (y >= self.y0_edge) & (y < y1_edge) ) ix = np.floor((x - self.x0_edge) / self.cell_size_m).astype(int) iy = np.floor((y - self.y0_edge) / self.cell_size_m).astype(int) cid = iy * self.nx + ix if cid.ndim == 0: return int(cid) if bool(inside) else -1 out = np.full(cid.shape, -1, dtype=int) out[inside] = cid[inside] return out
def lonlat_to_cell_id(
    grid: Grid, *, lon: float | np.ndarray, lat: float | np.ndarray
) -> int | np.ndarray:
    """Map lon/lat to cell id(s) on ``grid`` in a single call.

    Builds a :class:`GridCellLookup` for ``grid`` and delegates to its
    :meth:`GridCellLookup.lonlat_to_cell_id`. The lookup is rebuilt on every
    call; when querying the same grid repeatedly, build it once with
    :meth:`GridCellLookup.from_grid` and reuse it.
    """
    lookup = GridCellLookup.from_grid(grid)
    cell_ids = lookup.lonlat_to_cell_id(lon=lon, lat=lat)
    return cell_ids