Source code for raster_tools.raster

import numbers
import warnings

import dask
import dask.array as da
import dask.dataframe as dd
import geopandas as gpd
import numpy as np
import xarray as xr
from affine import Affine
from numba import jit
from shapely.geometry import Point

from raster_tools.dask_utils import dask_nanmax, dask_nanmin
from raster_tools.dtypes import (
    DTYPE_INPUT_TO_DTYPE,
    is_bool,
    is_float,
    is_int,
    is_scalar,
    is_str,
    promote_dtype_to_float,
    should_promote_to_fit,
)
from raster_tools.masking import (
    create_null_mask,
    get_default_null_value,
    reconcile_nullvalue_with_dtype,
)
from raster_tools.utils import can_broadcast, make_raster_ds, merge_masks

from .io import (
    IO_UNDERSTOOD_TYPES,
    RasterDataError,
    RasterIOError,
    chunk,
    is_batch_file,
    normalize_xarray_data,
    open_raster_from_path,
    write_raster,
)


[docs]class RasterDeviceMismatchError(BaseException): pass
[docs]class RasterDeviceError(BaseException): pass
[docs]class RasterNoDataError(BaseException): pass
_REDUCTION_FUNCS = ( "all", "any", ) _NAN_REDUCTION_FUNCS = ( "max", "mean", "min", "prod", "std", "sum", "var", ) _REDUCTION_DOCSTRING = """\ Reduce the raster to a single dask value by applying `{}` across all bands. All arguments are ignored. """ def _inject_reductions(cls): pass funcs = [(name, getattr(np, name)) for name in _REDUCTION_FUNCS] funcs += [ (name, getattr(np, "nan" + name)) for name in _NAN_REDUCTION_FUNCS ] for name, f in funcs: func = cls._build_reduce_method(f) func.__name__ = name func.__doc__ = _REDUCTION_DOCSTRING.format(name) setattr(cls, name, func) _MIN_MAX_FUNC_MAP = { np.max: dask_nanmax, np.nanmax: dask_nanmax, np.min: dask_nanmin, np.nanmin: dask_nanmin, } class _ReductionsMixin: """ This mixin class adds reduction methods like `all`, `sum`, etc to the Raster class. Having these methods also allows numpy reduction functions to use them for dynamic dispatch. So `np.sum(raster)` will call the class' `sum` method. All reductions return dask results. """ __slots__ = () def __init_subclass__(cls, **kwargs): _inject_reductions(cls) super().__init_subclass__(**kwargs) @classmethod def _build_reduce_method(cls, func): if func in _MIN_MAX_FUNC_MAP: func = _MIN_MAX_FUNC_MAP[func] def method(self, *args, **kwargs): # Take args and kwargs to stay compatible with numpy data = self.data values = data[~self._ds.mask.data] return func(values) return method def _normalize_ufunc_other(other, this): if other is None or isinstance(other, numbers.Number): return other if isinstance(other, xr.Dataset): return Raster(other) elif isinstance(other, (list, tuple, range, np.ndarray, da.Array)): # allow only if broadcastable if isinstance(other, da.Array) and np.isnan(other.size): raise ValueError("Dask arrays must have known chunks") other = np.atleast_1d(other) if not can_broadcast(this.shape, other.shape): msg = ( "Raster could not broadcast together with array of shape" f" {other.shape}." ) if other.squeeze().ndim == 1 and other.size == this.shape[0]: msg += " Did you mean to use Raster.bandwise?" raise ValueError(msg) return other elif isinstance(other, xr.DataArray): return Raster(other) else: assert isinstance(other, Raster) return other def _apply_ufunc(ufunc, this, left, right=None, kwargs=None, out=None): args = [left] if right is not None: args.append(right) other = left if left is not this else right types = [getattr(a, "dtype", type(a)) for a in args] masked = any(getattr(r, "_masked", False) for r in args) ufunc_args = [getattr(a, "xdata", a) for a in args] kwargs = kwargs or {} out_crs = None if this.crs is not None: out_crs = this.crs elif isinstance(other, Raster) and other.crs is not None: out_crs = other.crs ufname = ufunc.__name__ if ufname.startswith("bitwise") and any(is_float(t) for t in types): raise TypeError( "Bitwise operations are not compatible with float dtypes. You may" " want to convert one or more inputs to a boolean or integer dtype" " (e.g. 'raster > 0' or 'raster.astype(int)')." ) if ufname.endswith("shift") and any(is_float(t) for t in types): raise TypeError( "right_shift and left_shift operations are not compatible with" " float dtypes. You may want to convert one or more inputs to an" " integer dtype (e.g. 'raster.astype(int)')." ) xr_out = ufunc(*ufunc_args, **kwargs) multi_out = ufunc.nout > 1 if not masked: mask = create_null_mask(xr_out, None) if not multi_out: xmask = xr.DataArray(mask, dims=xr_out.dims, coords=xr_out.coords) xr_out = xr_out.rio.write_nodata(None) ds_out = make_raster_ds(xr_out, xmask) else: xmask = xr.DataArray( mask, dims=xr_out[0].dims, coords=xr_out[0].coords ) xr_out = [x.rio.write_nodata(None) for x in xr_out] ds_out = [make_raster_ds(x, xmask) for x in xr_out] else: xmask = merge_masks([r.xmask for r in args if isinstance(r, Raster)]) if not multi_out: nv = get_default_null_value(xr_out.dtype) xr_out = xr.where(xmask, nv, xr_out).rio.write_nodata(nv) ds_out = make_raster_ds(xr_out, xmask) else: nvs = [get_default_null_value(x.dtype) for x in xr_out] xr_out = [ xr.where(xmask, nv, x).rio.write_nodata(nv) for x, nv in zip(xr_out, nvs) ] ds_out = [make_raster_ds(x, xmask) for x in xr_out] if out_crs is not None: if multi_out: ds_out = [x.rio.write_crs(out_crs) for x in ds_out] else: ds_out = ds_out.rio.write_crs(out_crs) if out is not None: # "Inplace" out._ds = ds_out return out if not multi_out: return Raster(ds_out, _fast_path=True) rs_outs = tuple(Raster(x, _fast_path=True) for x in ds_out) return rs_outs _UNARY_UFUNCS = frozenset( (np.absolute, np.invert, np.logical_not, np.negative, np.positive) ) _UNSUPPORED_UFUNCS = frozenset((np.isnat, np.matmul)) class _RasterBase(np.lib.mixins.NDArrayOperatorsMixin, _ReductionsMixin): """This class implements methods for handling numpy ufuncs.""" __slots__ = () _HANDLED_TYPES = ( numbers.Number, xr.DataArray, np.ndarray, da.Array, list, tuple, range, ) # Higher than xarray objects __array_priority__ = 70 def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): out = kwargs.pop("out", ()) for x in inputs + out: if not isinstance(x, self._HANDLED_TYPES + (_RasterBase,)): return NotImplemented if ufunc in _UNSUPPORED_UFUNCS: raise TypeError( f"Raster objects are not supported for ufunc: '{ufunc}'." ) if ufunc.signature is not None: raise NotImplementedError("Raster does not support gufuncs") if method != "__call__": raise NotImplementedError( f"{method} for ufunc {ufunc} is not implemented on Raster " "objects." ) if len(inputs) > ufunc.nin: raise TypeError( "Too many inputs for ufunc:" f" inputs={len(inputs)}, ufunc={ufunc!r}" ) if len(out): if len(out) > 1: raise NotImplementedError( "The 'out' keyword is only supported for inplace " "operations, not operations with multiple outputs." ) (out,) = out if out is not self: raise NotImplementedError( "'out' must be the raster being operated on." ) else: out = None left = inputs[0] right = inputs[1] if ufunc.nin == 2 else None other = left if left is not self else right other = _normalize_ufunc_other(other, self) if self is left: right = other else: left = other return _apply_ufunc(ufunc, self, left, right=right, out=out) def __array__(self, dtype=None): return self._ds.raster.__array__(dtype) def _normalize_bandwise_other(other, target_shape): other = np.atleast_1d(other) if isinstance(other, da.Array) and np.isnan(other.size): raise ValueError("Dask arrays must have known chunks") if other.size == 1: return other.ravel() if can_broadcast(other.shape, target_shape): raise ValueError( "Operands can be broadcast together. A bandwise operation is not" " needed here." ) if other.ndim == 1 and other.size == target_shape[0]: # Map a list/array of values that has same length as # of bands to # each band return other.reshape((-1, 1, 1)) raise ValueError( "Received array with incompatible shape in bandwise operation: " f"{other.shape}" )
[docs]class BandwiseOperationAdapter(np.lib.mixins.NDArrayOperatorsMixin): __slots__ = ("_raster",) _HANDLED_TYPES = ( numbers.Number, np.ndarray, da.Array, list, tuple, range, ) # Higher than xarray objects __array_priority__ = 70 def __init__(self, raster): self._raster = raster def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): out = kwargs.pop("out", None) for x in inputs: if not isinstance( x, self._HANDLED_TYPES + ( Raster, BandwiseOperationAdapter, ), ): return NotImplemented if len(inputs) < 2: raise ValueError( "Bandwise operations are binary and require 2 inputs. Only" " one given." ) if len(inputs) > 2: raise ValueError( "Bandwise operations are binary and require 2 inputs." f" {len(inputs)} given." ) if ufunc in _UNSUPPORED_UFUNCS: raise TypeError( "The given ufunc is not supported for bandwise operations:" f" {ufunc!r}." ) if ufunc.signature is not None: raise NotImplementedError("Raster does not support gufuncs") if method != "__call__": raise NotImplementedError( f"{method} for ufunc {ufunc} is not implemented for bandwise " " operations." ) if len(inputs) > ufunc.nin: raise TypeError( "Too many inputs for ufunc:" f" inputs={len(inputs)}, ufunc={ufunc!r}" ) if out is not None: raise NotImplementedError( "The 'out' keyword is not supported for bandwise operations" ) left, right = inputs other = left if left is not self else right other = _normalize_bandwise_other(other, self._raster.shape) if self is left: left = self._raster right = other else: left = other right = self._raster return ufunc(left, right, **kwargs)
def _array_to_xarray(ar): if len(ar.shape) > 3 or len(ar.shape) < 2: raise ValueError(f"Invalid raster shape for numpy array: {ar.shape}") if len(ar.shape) == 2: # Add band dim ar = ar[None] if not isinstance(ar, da.Array): has_nan = is_float(ar.dtype) and np.isnan(ar).any() else: has_nan = is_float(ar.dtype) # The band dimension needs to start at 1 to match raster conventions coords = [np.arange(1, ar.shape[0] + 1)] # Add 0.5 offset to move coords to center of cell # y coords.append(np.arange(ar.shape[1])[::-1] + 0.5) coords.append(np.arange(ar.shape[2]) + 0.5) xrs = xr.DataArray(ar, dims=["band", "y", "x"], coords=coords) if has_nan: xrs.attrs["_FillValue"] = np.nan else: # No way to know null value xrs.attrs["_FillValue"] = None return xrs def _try_to_get_null_value_xarray(xrs): if ( "_FillValue" in xrs.attrs and xrs.attrs["_FillValue"] is None and xrs.rio.nodata is None and xrs.rio.encoded_nodata is None and not dask.is_dask_collection(xrs) ): return None null = xrs.attrs.get("_FillValue", None) if null is not None: return null nv1 = xrs.rio.nodata nv2 = xrs.rio.encoded_nodata # All are None, return nan for float and None otherwise if all(nv is None for nv in [nv1, nv2]): if is_float(xrs.dtype): if dask.is_dask_collection(xrs) or np.isnan(xrs.data).any(): return np.nan else: return None else: return None # All are not None, return nv1 if all(nv is not None for nv in [nv1, nv2]): return nv1 # One is None and the other is not, return the valid null value return nv1 if nv1 is not None else nv2 def _dataarry_to_raster_ds(xin): nv = _try_to_get_null_value_xarray(xin) xin = normalize_xarray_data(xin) if not dask.is_dask_collection(xin) or xin.data.chunksize[0] != 1: xin = chunk(xin) if nv is None: mask = create_null_mask(xin, nv) mask = xr.DataArray( create_null_mask(xin, nv), dims=xin.dims, coords=xin.coords ) if xin.rio.crs is not None: mask = mask.rio.write_crs(xin.rio.crs) elif np.isnan(nv): mask = np.isnan(xin) else: mask = xin == nv orig_nv = nv nv = reconcile_nullvalue_with_dtype(nv, xin.dtype) crs = xin.rio.crs if nv is not None and (np.isnan(orig_nv) or orig_nv != nv): xin = xr.where(mask, nv, xin) if crs is not None: xin = xin.rio.write_crs(crs) xin = xin.rio.write_nodata(nv) return make_raster_ds(xin, mask) def _dataset_to_raster_ds(xin): if ( dask.is_dask_collection(xin) and set(xin.data_vars) == set(["raster", "mask"]) and dask.is_dask_collection(xin.raster) and dask.is_dask_collection(xin.mask) and is_bool(xin.mask.dtype) and xin.raster.dims == ("band", "y", "x") and xin.mask.dims == ("band", "y", "x") and xin.raster.chunks == xin.mask.chunks and xin.raster.rio.crs == xin.mask.rio.crs ): return xin nvars = len(xin.data_vars) if nvars < 2 or nvars > 2 or set(xin.data_vars) != set(["raster", "mask"]): raise ValueError( "Dataset input must have only 'raster' and 'mask' variables" ) raster = xin.raster mask = xin.mask if not is_bool(mask.dtype): raise TypeError("mask must have boolean dtype") raster = normalize_xarray_data(raster) if not dask.is_dask_collection(raster): raster = chunk(raster) mask = normalize_xarray_data(mask) if mask.shape != raster.shape: raise ValueError("raster and mask dimensions do not match") if not dask.is_dask_collection(mask): mask = mask.chunk(chunks=raster.chunks) nv = _try_to_get_null_value_xarray(raster) if nv is None: nv = get_default_null_value(raster.dtype) else: nv = reconcile_nullvalue_with_dtype(nv, raster.dtype) if nv is not None: raster.data = xr.where(mask, nv, raster.data) raster.rio.write_nodata(nv, inplace=True) crs = raster.rio.crs if crs != mask.rio.crs: mask.rio.write_crs(crs, inplace=True) return make_raster_ds(raster, mask) def _xarray_to_raster_ds(xin): if isinstance(xin, xr.DataArray): return _dataarry_to_raster_ds(xin) return _dataset_to_raster_ds(xin)
[docs]def get_raster_ds(raster): # Copy to take ownership if isinstance(raster, Raster): ds = raster._ds.copy(deep=True) elif isinstance(raster, (xr.DataArray, xr.Dataset)): ds = _xarray_to_raster_ds(raster.copy()) elif isinstance(raster, (np.ndarray, da.Array)): ds = _xarray_to_raster_ds(_array_to_xarray(raster.copy())) elif type(raster) in IO_UNDERSTOOD_TYPES: if is_batch_file(raster): # Import here to avoid circular import errors from raster_tools.batch import parse_batch_script ds = parse_batch_script(raster).final_raster._ds else: rs, mask, nv = open_raster_from_path(raster) xmask = xr.DataArray(mask, dims=rs.dims, coords=rs.coords) ds = make_raster_ds(rs.rio.write_nodata(nv), xmask) ds = _xarray_to_raster_ds(ds) else: raise TypeError(f"Could not resolve input to a raster: {raster!r}") return ds
[docs]class Raster(_RasterBase): """Abstraction of georeferenced raster data with lazy function evaluation. Raster is a wrapper around xarray Datasets and DataArrays. It takes advantage of xarray's dask integration to allow lazy loading and evaluation. It allows a pipeline of operations on underlying raster sources to be built in a lazy fashion and then evaluated effiently. Most mathematical operations have been overloaded so operations such as ``z = x - y`` and ``r = x**2`` are possible. All operations on a Raster return a new Raster. Parameters ---------- raster : str, Raster, xarray.DataArray, xarray.Dataset, numpy.ndarray The raster source to use for this Raster. If `raster` is a string, it is treated like a path. If `raster` is a Raster, a copy is made and its raster source is used. If `raster` is an xarray DataArray, Dataset, or numpy array, it is used as the source. Dataset objects must contain 'raster' and 'mask' variables. The dimensions of these variables are assumed to be "band", "y", "x", in that order. The 'mask' variable must have boolean dtype. """ __slots__ = ("_ds",)
[docs] def __init__(self, raster, _fast_path=False): if _fast_path: self._ds = raster else: self._ds = get_raster_ds(raster)
def __repr__(self): # TODO: implement return repr(self._ds.raster) @property def null_value(self): """The raster's null value used to fill missing or invalid entries.""" return self._ds.raster.rio.nodata @property def _masked(self): return self.null_value is not None @property def xdata(self): """The underlying :class:`xarray.DataArray` (read only)""" return self._ds.raster @property def data(self): """The underlying dask array of data""" return self._ds.raster.data @property def values(self): """ The raw internal raster as a numpy array. .. note:: This triggers computation and loads the raster data into memory. """ return self._ds.raster.values @property def mask(self): """The null value mask as a dask array.""" return self._ds.mask.data @property def xmask(self): """The null value mask as an xarray DataArray.""" return self._ds.mask @property def x(self): """The x (horizontal) coordinate values.""" return self._ds.x.data @property def y(self): """The y (vertical) coordinate values.""" return self._ds.y.data @property def dtype(self): """The dtype of the data.""" return self._ds.raster.dtype @property def shape(self): """The shape of the underlying raster. Dim 0 is the band dimension. The shape will always be of the form ``(B, Y, X)`` where ``B`` is the band dimension, ``Y`` is the y dimension, and ``X`` is the x dimension. """ return self._ds.raster.shape @property def nbands(self): """The number of bands.""" return self._ds.band.size @property def crs(self): """The raster's CRS. This is a :obj:`rasterio.crs.CRS` object.""" return self._ds.rio.crs @property def affine(self): """The affine transformation for the raster data. This is an :obj:`affine.Affine` object. """ return self._ds.rio.transform(True) @property def resolution(self): """The x and y cell sizes as a tuple. Values are always positive.""" return self._ds.rio.resolution(True) @property def bounds(self): """Bounds tuple (minx, miny, maxx, maxy)""" r, c = self.shape[1:] minx, maxy = self.xy(0, 0, offset="ul") maxx, miny = self.xy(r - 1, c - 1, "lr") return (minx, miny, maxx, maxy) @property def bandwise(self): """Returns an adapter for band-wise operations on this raster The returned adapter allows lists/arrays of values to be mapped to the raster's bands for binary operations. """ return BandwiseOperationAdapter(self) @property def _rs(self): warnings.warn( "'Raster._rs' is deprecated. It will soon be removed. Use xdata.", DeprecationWarning, stacklevel=2, ) return self._ds.raster @property def _data(self): warnings.warn( "'Raster._data' is deprecated. It will soon be removed. Use data.", DeprecationWarning, stacklevel=2, ) return self._ds.raster.data @property def _values(self): warnings.warn( "'Raster._values' is deprecated. It will soon be removed. Use " "values.", DeprecationWarning, stacklevel=2, ) return self._ds.raster.values @property def _mask(self): warnings.warn( "'Raster._mask' is deprecated. It will soon be removed. Use mask.", DeprecationWarning, stacklevel=2, ) return self._ds.mask.data @property def _null_value(self): warnings.warn( "'Raster._null_value' is deprecated. It will soon be removed. Use " "null_value.", DeprecationWarning, stacklevel=2, ) return self.null_value @property def xrs(self): warnings.warn( "'Raster.xrs' is deprecated. It will soon be removed. Use xdata.", DeprecationWarning, stacklevel=2, ) return self._ds.raster @property def pxrs(self): warnings.warn( "'Raster.pxrs' is deprecated. It will soon be removed.", DeprecationWarning, stacklevel=2, ) return get_raster(self, null_to_nan=True)._ds.raster
[docs] def chunk(self, chunks): """Rechunk the underlying raster Parameters ---------- chunks : tuple of int The chunk sizes for each dimension ("band", "y", "x", in that order). Must be a 3-tuple. Returns ------- Raster The rechunked raster """ assert len(chunks) == 3 return Raster( self._ds.chunk( {d: cs for d, cs in zip(["band", "y", "x"], chunks)} ), _fast_path=True, )
[docs] def to_dataset(self): """Returns the underlying `xarray.Dataset`.""" return self._ds
[docs] def plot(self, *args, **kwargs): """ Plot the raster. Args and kwargs are passed to xarray's plot function. """ return get_raster(self, null_to_nan=True).xdata.plot(*args, **kwargs)
[docs] def get_chunked_coords(self): """Get lazy coordinate arrays, in x-y order, chunked to match data.""" xc = da.from_array(self.x, chunks=self.data.chunks[2]).reshape( (1, 1, -1) ) yc = da.from_array(self.y, chunks=self.data.chunks[1]).reshape( (1, -1, 1) ) return xc, yc
[docs] def close(self): """Close the underlying source""" self._ds.close()
[docs] def save(self, path, no_data_value=None, **gdal_kwargs): """Compute the final raster and save it to the provided location. Parameters ---------- path : str The target location to save the raster to. no_data_value : scalar, optional A new null value to use when saving. **gdal_kwargs : kwargs, optional Additional keyword arguments to to pass to rasterio and GDAL when writing the raster data. Returns ------- Raster A raster pointing to the saved location. """ # TODO: warn of overwrite rs = self if no_data_value is not None: rs = self.set_null_value(no_data_value) if rs._masked and is_bool(rs.dtype): # Cast to uint and burn in mask rs = rs.astype("uint8").set_null_value( get_default_null_value("uint8") ) xrs = rs.xdata write_raster( xrs, path, rs.null_value if no_data_value is None else no_data_value, **gdal_kwargs, ) return Raster(path)
[docs] def eval(self): """Compute any applied operations and return the result as new Raster. Note that the unerlying sources will be loaded into memory for the computations and the result will be fixed in memory. The original Raster will be unaltered. """ ds = chunk(self._ds.compute()) # A new raster is returned to mirror the xarray and dask APIs return Raster(ds)
[docs] def copy(self): """Returns a copy of this Raster.""" return Raster(self)
[docs] def astype(self, dtype, warn_about_null_change=True): """Return a copy of the Raster cast to the specified type. Parameters ---------- dtype : str, type, numpy.dtype The new type to cast the raster to. The null value will also be cast to this type. If the null value cannot be safely cast to the new type, a default null value for the given `dtype` is used. This will produce a warning unless `warn_about_null_change` is set to ``False``. warn_about_null_change : bool, optional Can be used to silence warnings. The default is to always warn about null value changes. Returns ------- Raster The new `dtype` raster. """ if dtype not in DTYPE_INPUT_TO_DTYPE: raise ValueError(f"Unsupported type: '{dtype}'") dtype = DTYPE_INPUT_TO_DTYPE[dtype] if dtype == self.dtype: return self.copy() xrs = self._ds.raster nv = self.null_value mask = self._ds.mask xrs = xrs.astype(dtype) if self._masked: nv = reconcile_nullvalue_with_dtype( nv, dtype, warn_about_null_change ) if nv != self.null_value: xrs = xr.where(mask, nv, xrs).rio.write_nodata(nv) ds = make_raster_ds(xrs, mask) if self.crs is not None: ds = ds.rio.write_crs(self.crs) return Raster(ds, _fast_path=True)
[docs] def get_bands(self, bands): """Retrieve the specified bands as a new Raster. Indexing starts at 1. Parameters ---------- bands : int or sequence of ints The band or bands to be retrieved. A single band Raster is returned if `bands` is an int and a multiband Raster is returned if it is a sequence. The band numbers may be out of order and contain repeated elements. if `bands` is a sequence, the multiband raster's bands will be in the order provided. Returns ------- Raster The resulting raster composed of the specified bands in the order given. """ if is_int(bands): bands = [bands] if not all(is_int(b) for b in bands): raise TypeError("All band numbers must be integers") bands = list(bands) if len(bands) == 0: raise ValueError("No bands provided") if any(b < 1 or b > self.nbands for b in bands): raise IndexError( f"One or more band numbers were out of bounds: {bands}" ) if len(bands) == 1 and self.nbands == 1: return self.copy() # We have to do sel() followed by concat() in order to get size 1 # chunking along the band dim dss = [self._ds.sel(band=b) for b in bands] ds = xr.concat(dss, dim="band") # Reset band numbering ds["band"] = np.arange(len(bands)) + 1 return Raster(ds, _fast_path=True)
[docs] def set_crs(self, crs): """Set the CRS for the underlying data. Parameters ---------- crs : object The desired CRS. This can be anything accepted by :obj:`rasterio.crs.CRS.from_user_input` (i.e. `4326`, `"epsg:4326"`, etc). Returns ------- Raster A new raster with the specified CRS. """ return Raster(self._ds.rio.write_crs(crs), _fast_path=True)
[docs] def set_null_value(self, value): """Sets or replaces the null value for the raster. If there was previously no null value for the raster, one is set. If there was already a null value, then it is replaced. If the raster has an integer dtype and `value` is a float, the dtype will be promoted to a float dtype. If `value` is None, the null value is cleared. The raster data is not changed in this case. Parameters ---------- value : scalar or None The new null value for the raster. If None, the resulting raster will have no null value, i.e. the null value is cleared. Returns ------- Raster The new resulting Raster. """ if value is not None and not (is_scalar(value) or is_bool(value)): raise TypeError(f"Value must be a scalar or None: {value}") xrs = self._ds.raster.copy() # Cast up to float if needed if should_promote_to_fit(self.dtype, value): xrs = xrs.astype(promote_dtype_to_float(self.dtype)) if value is None: # Burn in current mask values and then clear null value # TODO: burning should not be needed as the values should already # be set xrs = self.burn_mask().xdata.rio.write_nodata(None) mask = xr.zeros_like(xrs, dtype=bool) return Raster(make_raster_ds(xrs, mask), _fast_path=True) # Update mask mask = self._ds.mask temp_mask = np.isnan(xrs) if np.isnan(value) else xrs == value if self._masked: mask = mask | temp_mask else: mask = temp_mask xrs = xrs.rio.write_nodata(value) return Raster(make_raster_ds(xrs, mask), _fast_path=True).burn_mask()
[docs] def burn_mask(self): """Fill null-masked cells with null value. Use this as a way to return the underlying data to a known state. If dtype is boolean, the null cells are set to ``True`` instead of promoting to fit the null value. """ if not self._masked: return self nv = self.null_value if is_bool(self.dtype): nv = get_default_null_value(self.dtype) # call write_nodata because xr.where drops the nodata info xrs = xr.where(self._ds.mask, nv, self._ds.raster).rio.write_nodata(nv) if self.crs is not None: xrs = xrs.rio.write_crs(self.crs) return Raster(make_raster_ds(xrs, self._ds.mask), _fast_path=True)
[docs] def to_null_mask(self): """ Returns a boolean Raster with True at null values and False otherwise. Returns ------- Raster The resulting mask Raster. It is True where this raster contains null values and False everywhere else. """ return Raster( make_raster_ds( self._ds.mask.rio.write_nodata(None), xr.zeros_like(self._ds.mask), ), _fast_path=True, )
[docs] def replace_null(self, value): """Replaces null values with `value`. Parameters ---------- value : scalar The new value to replace null values with. Returns ------- Raster The new resulting Raster. If `value` is a float and the raster dtype is int, the raster type will be promoted to float. """ if not is_scalar(value): raise TypeError("value must be a scalar") return self.set_null_value(value)
[docs] def where(self, condition, other): """Filter elements from this raster according to `condition`. Parameters ---------- condition : str or Raster A boolean or int raster that indicates where elements in this raster should be preserved and where `other` should be used. If the condition is an int raster, it is coerced to bool using `condition > 0`. ``True`` cells pull values from this raster and ``False`` cells pull from `other`. *str* is treated as a path to a raster. other : scalar, str or Raster A raster or value to use in locations where `condition` is ``False``. *str* is treated as a path to a raster. Returns ------- Raster The resulting filtered Raster. """ from raster_tools.general import where return where(condition, self, other)
[docs] def remap_range(self, mapping, inclusivity="left"): """Remaps values based on a mapping or list of mappings. Mappings are applied all at once with earlier mappings taking precedence. Parameters ---------- mapping : 3-tuple of scalars or list of 3-tuples of scalars A tuple or list of tuples containing ``(min, max, new_value)`` scalars. The mappiing(s) map values between the min and max to the ``new_value``. If `mapping` is a list and there are mappings that conflict or overlap, earlier mappings take precedence. `inclusivity` determines which sides of the range are inclusive and exclusive. inclusivity : str, optional Determines whether to be inclusive or exclusive on either end of the range. Default is `'left'`. 'left' [min, max) Left (min) side is inclusive and right (max) side is exclusive. 'right' (min, max] Left (min) side is exclusive and right (max) side is inclusive. 'both' [min, max] Both sides are inclusive. 'none' (min, max) Both sides are exclusive. Returns ------- Raster The resulting Raster. See Also -------- raster_tools.general.remap_range """ # local import to avoid circular import from raster_tools.general import remap_range return remap_range(self, mapping, inclusivity=inclusivity)
[docs] def reclassify(self, remapping, unmapped_to_null=False): """Reclassify raster values based on a mapping. The raster must have an integer type. Parameters ---------- remapping : str, dict Can be either a ``dict`` or a path string. If a ``dict`` is provided, the keys will be reclassified to the corresponding values. If a path string, it is treated as an ASCII remap file where each line looks like ``a:b`` and ``a`` and ``b`` are integers. All remap values (both from and to) must be integers. unmapped_to_null : bool, optional If ``True``, values not included in the mapping are instead mapped to the null value. Default is ``False``. Returns ------- Raster The remapped raster. """ from raster_tools.general import reclassify return reclassify(self, remapping, unmapped_to_null=unmapped_to_null)
[docs] def round(self, decimals=0): """Evenly round to the given number of decimals Parameters ---------- decimals : int, optional The number of decimal places to round to. If negative, value specifies the number of positions to the left of the decimal point. Default is 0. Returns ------- Raster Rounded raster. """ rast = self._ds.raster if self._masked: rast = xr.where( self._ds.mask, self.null_value, rast.round(decimals=decimals) ).rio.write_nodata(self.null_value) else: rast = rast.round(decimals=decimals) ds = make_raster_ds(rast, self._ds.mask) if self.crs is not None: ds = ds.rio.write_crs(self.crs) return Raster(ds, _fast_path=True)
[docs] def xy(self, row, col, offset="center"): """Return the `(x, y)` coordinates of the pixel at `(row, col)`. Parameters ---------- row : int, float, array-like row value(s) in the raster's CRS col : int, float, array-like row value(s) in the raster's CRS offset : str, optional Determines if the returned coordinates are for the center or a corner. Default is `'center'` 'center' The pixel center. 'ul' The upper left corner. 'ur' The upper right corner. 'll' The lower left corner. 'lr' The lower right corner. Returns ------- tuple (x value(s), y value(s)). If the inputs where array-like, the output values will be arrays. """ return rowcol_to_xy(row, col, self.affine, offset)
[docs] def index(self, x, y): """Return the `(row, col)` index of the pixel at `(x, y)`. Parameters ---------- x : float, array-like x value(s) in the raster's CRS y : float, array-like y value(s) in the raster's CRS Returns ------- tuple (row value(s), col value(s)). If the inputs where array-like, the output values will be arrays. """ return xy_to_rowcol(x, y, self.affine)
[docs] def to_vector(self): """ Convert the raster into a vector where each non-null cell is a point. The resulting points are located at the center of each grid cell. Returns ------- result The result is a dask_geopandas GeoDataFrame with the following columns: value, row, col, geometry. The value is the cell value at the row and column in the raster. """ xrs = self._ds.raster mask = self._ds.mask.data data_delayed = xrs.data.to_delayed() chunks_shape = data_delayed.shape data_delayed = data_delayed.ravel() mask_delayed = mask.to_delayed().ravel() xc = da.from_array(xrs.x.data, chunks=xrs.chunks[2]) xc_delayed = xc.to_delayed() yc = da.from_array(xrs.y.data, chunks=xrs.chunks[1]) yc_delayed = yc.to_delayed() # See issue for inspiration: https://github.com/dask/dask/issues/7589 # Group chunk data with corresponding coordinate data chunks = [] for k, (d, m) in enumerate(zip(data_delayed, mask_delayed)): # band index, chunk y index, chunk x index b, i, j = np.unravel_index(k, chunks_shape) xck = xc_delayed[j] yck = yc_delayed[i] # Pass affine as tuple because of how dask handles namedtuple # subclasses. # ref: https://github.com/rasterio/affine/issues/79 chunks.append( (d, m, xck, yck, b + 1, self.crs, tuple(self.affine)) ) meta = gpd.GeoDataFrame( { "value": [self.dtype.type(1)], "band": [1], "row": [0], "col": [0], "geometry": [Point(xrs.x.data[0], xrs.y.data[0])], }, crs=self.crs, ) results = [ dd.from_delayed(_vectorize(*chunk), meta=meta) for chunk in chunks ] return dd.concat(results)
_offset_name_to_rc_offset = { "center": (0.5, 0.5), "ul": (0, 0), "ur": (0, 1), "ll": (1, 0), "lr": (1, 1), }
[docs]def rowcol_to_xy(row, col, affine, offset): """ Convert (row, col) index values to (x, y) coords using the transformation. """ roffset, coffset = _offset_name_to_rc_offset[offset] T = Affine.identity().translation(coffset, roffset) return affine * T * (col, row)
[docs]def xy_to_rowcol(x, y, affine): """ Convert (x, y) coords to (row, col) index values using the transformation. """ col, row = (~affine) * (x, y) row = np.floor(row).astype(int) col = np.floor(col).astype(int) return row, col
@jit(nopython=True, nogil=True) def _extract_points(mask, xc, yc): n = mask.size - np.sum(mask) rx = np.empty(n, dtype=xc.dtype) ry = np.empty(n, dtype=yc.dtype) if n == 0: return rx, ry k = 0 for i in range(mask.shape[1]): for j in range(mask.shape[2]): if mask[0, i, j]: continue rx[k] = xc[j] ry[k] = yc[i] k += 1 return (rx, ry) @jit(nopython=True, nogil=True) def _extract_values(data, mask): n = mask.size - np.sum(mask) results = np.empty(n, dtype=data.dtype) if n == 0: return results k = 0 for i in range(data.shape[1]): for j in range(data.shape[2]): if mask[0, i, j]: continue results[k] = data[0, i, j] k += 1 return results @dask.delayed def _vectorize(data, mask, xc, yc, band, crs, affine_tuple): affine = Affine(*affine_tuple[:6]) xpoints, ypoints = _extract_points(mask, xc, yc) if len(xpoints): values = _extract_values(data, mask) points = gpd.GeoSeries.from_xy(xpoints, ypoints, crs=crs) rows, cols = xy_to_rowcol(xpoints, ypoints, affine) bands = [band] * len(values) else: values = [] points = gpd.GeoSeries([], crs=crs) bands = [] rows = [] cols = [] df = gpd.GeoDataFrame( { "value": values, "band": bands, "row": rows, "col": cols, "geometry": points, }, crs=crs, ) if len(xpoints): return df # Add astype() as a workaround for # https://github.com/geopandas/dask-geopandas/issues/190 return df.astype( {"value": data.dtype, "band": int, "row": int, "col": int} )
[docs]def get_raster(src, strict=True, null_to_nan=False): rs = None if isinstance(src, Raster): rs = src elif is_str(src): rs = Raster(src) elif strict: raise TypeError( f"Input must be a Raster or path string. Got {repr(type(src))}" ) else: try: rs = Raster(src) except (ValueError, TypeError, RasterDataError, RasterIOError): raise ValueError(f"Could not convert input to Raster: {repr(src)}") if null_to_nan and rs._masked: rs = rs.copy() data = rs.data new_dtype = promote_dtype_to_float(data.dtype) if new_dtype != data.dtype: data = data.astype(new_dtype) rs.xdata.data = da.where(rs.mask, np.nan, data) return rs