Source code for raster_tools.general

"""
   Description: general module used to perform common spatial analyses
   on Raster objects

   * `ESRI Generalization Tools <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/an-overview-of-the-generalization-tools.htm>`_
   * `ESRI Local Tools <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/an-overview-of-the-local-tools.htm>`_

"""  # noqa: E501
import os
import re
from collections.abc import Iterable, Sequence
from functools import partial

import dask.array as da
import numba as nb
import numpy as np
import xarray as xr
from dask_image import ndmeasure as ndm
from scipy.ndimage import (
    binary_dilation,
    binary_erosion,
    grey_dilation,
    grey_erosion,
)

from raster_tools.creation import empty_like, zeros_like
from raster_tools.dtypes import (
    BOOL,
    F16,
    F32,
    F64,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    get_common_dtype,
    is_bool,
    is_int,
    is_scalar,
    is_str,
)
from raster_tools.focal import _get_offsets
from raster_tools.masking import (
    create_null_mask,
    get_default_null_value,
    reconcile_nullvalue_with_dtype,
)
from raster_tools.raster import Raster, get_raster
from raster_tools.stat_common import (
    nan_unique_count_jit,
    nanargmax_jit,
    nanargmin_jit,
    nanasm_jit,
    nanentropy_jit,
    nanmode_jit,
)
from raster_tools.utils import make_raster_ds

# Names exported by ``from raster_tools.general import *``.
# NOTE(review): ``where`` is defined later in this module with a public name
# but is not listed here -- confirm whether that omission is intentional.
__all__ = [
    "aggregate",
    "band_concat",
    "dilate",
    "erode",
    "local_stats",
    "predict_model",
    "reclassify",
    "regions",
    "remap_range",
]

# TODO: mosaic


def _create_labels(xarr, wd, uarr=None):
    """Label the connected patches of every zone value with a unique id.

    Parameters
    ----------
    xarr : dask array
        The array to label. Cells equal to 0 are never labeled.
    wd : array
        Connectivity structure handed to ``dask_image.ndmeasure.label``.
    uarr : array, optional
        The zone values to label. When ``None`` the unique values of `xarr`
        are computed eagerly and used instead.

    Returns
    -------
    dask array
        Array like `xarr` holding 0 for unlabeled cells and a positive,
        globally unique label for every connected patch.

    """
    zone_values = da.unique(xarr).compute() if uarr is None else uarr
    # 0 marks cells to skip, so it can never be a zone of its own.
    zone_values = zone_values[zone_values != 0]

    labels = da.zeros_like(xarr)
    # Running offset that keeps labels from different zones distinct.
    label_offset = 0
    for zone in zone_values:
        zone_labels, zone_count = ndm.label((xarr == zone), structure=wd)
        shifted = da.where(zone_labels > 0, zone_labels + label_offset, 0)
        labels += shifted.astype(labels.dtype)
        label_offset += zone_count
    return labels


def regions(raster, neighbors=4, unique_values=None):
    """Calculate the unique regions (patches) within each raster band.

    The approach is based on ESRI's region group calculation.

    Parameters
    ----------
    raster : Raster or path str
        The raster to perform the calculation on. All unique non-zero values
        are used to define unique regions.
    neighbors : int, optional
        The neighborhood connectivity value. Valid values are 4 and 8. With
        4, a rook pattern is used (horizontal and vertical neighbors only).
        With 8, all 8 surrounding neighbors are used. Default is 4.
    unique_values : array or list, optional
        Values that represent zones from which regions will be made. Cells
        whose value is not listed are left unlabeled (0). If ``None``, each
        unique value in the raster is treated as a zone and the unique values
        are computed up front.

    Returns
    -------
    Raster
        The resulting raster of unique region values. The bands have the same
        shape as the original Raster. The null value mask from the original
        raster is applied to the result.

    Raises
    ------
    TypeError
        If `neighbors` is not an int or `unique_values` has an unsupported
        type.
    ValueError
        If `neighbors` is neither 4 nor 8.

    References
    ----------
    * `ESRI: Region Group <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/region-group.htm>`_

    """  # noqa: E501
    raster = get_raster(raster)
    rs_out = zeros_like(raster, dtype=U64)

    if not is_int(neighbors):
        raise TypeError(
            f"neighbors argument must be an int. Got {type(neighbors)}"
        )
    if neighbors == 4:
        # Rook connectivity
        structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
    elif neighbors == 8:
        # Queen connectivity
        structure = np.ones((3, 3), dtype=int)
    else:
        raise ValueError(
            f"Invalid value for neighbors parameter: {repr(neighbors)}"
        )

    if unique_values is not None:
        if not isinstance(unique_values, (np.ndarray, da.Array, Sequence)):
            raise TypeError("Invalid type for unique_values parameter")
        unique_values = np.asarray(unique_values)

    masked = raster._masked
    src = raster.data
    if masked:
        # Null cells become 0 so the labeling phase skips them
        src = da.where(raster._ds.mask.data, 0, src)

    labeled = rs_out.data
    for band in range(src.shape[0]):
        labeled[band] = _create_labels(src[band], structure, unique_values)

    if masked:
        null_value = reconcile_nullvalue_with_dtype(
            raster.null_value, rs_out.dtype
        )
        labeled = da.where(raster._ds.mask.data, null_value, labeled)
    rs_out._ds.raster.data = labeled
    return rs_out
@nb.jit(nopython=True, nogil=True, parallel=True)
def _coarsen_chunk(x, axis, func, out_dtype, check_nan):
    """Reduce every window of a chunk to a single value with `func`.

    `x` is indexed below with five indices, ``x[i, j, :, k, :]``, so the
    third and fifth axes hold the window cells. `axis` lists the axes that
    are reduced; the remaining three axes give the output shape.
    NOTE(review): the hard-coded indexing assumes ``axis == (2, 4)`` --
    confirm against the caller.

    When `check_nan` is true, NaN results are replaced by 0.
    """
    # Axes that survive the reduction (everything not listed in `axis`)
    dims = sorted(set(range(len(x.shape))) - set(axis))
    shape = (x.shape[dims[0]], x.shape[dims[1]], x.shape[dims[2]])
    out = np.empty(shape, out_dtype)
    for i in range(shape[0]):
        for j in nb.prange(shape[1]):
            for k in nb.prange(shape[2]):
                v = func(x[i, j, :, k, :])
                if check_nan and np.isnan(v):
                    # It doesn't matter what value is swapped with nan values
                    # since those cells will be masked out later.
                    v = 0
                out[i, j, k] = v
    return out


def _coarsen_block_map(x, axis, agg_func, out_dtype, check_nan):
    """Map `_coarsen_chunk` over the blocks of `x`, dropping `axis`.

    The output chunks are the chunks of `x` along the axes that are not
    listed in `axis`.
    """
    dims = sorted(set(range(len(x.shape))) - set(axis))
    chunks = tuple(x.chunks[d] for d in dims)
    return da.map_blocks(
        partial(
            _coarsen_chunk,
            axis=axis,
            func=agg_func,
            out_dtype=out_dtype,
            check_nan=check_nan,
        ),
        x,
        chunks=chunks,
        drop_axis=axis,
        meta=np.array((), dtype=out_dtype),
    )


def _get_coarsen_dtype(stat, window_size, input_dtype):
    """Pick the output dtype for a custom coarsen statistic.

    "unique" uses the smallest dtype able to hold `window_size`; "mode",
    "min" and "max" keep `input_dtype`; every other statistic yields F32 for
    F32 input and F64 otherwise.
    """
    if stat == "unique":
        return np.min_scalar_type(window_size)
    if stat in ("mode", "min", "max"):
        return input_dtype
    if input_dtype == F32:
        return F32
    return F64


# Statistics computed by calling the method of the same name on the object
# passed in (`aggregate` passes the result of ``DataArray.coarsen``).
_COARSEN_STYPE_TO_FUNC = {
    "max": lambda x: x.max(),
    "mean": lambda x: x.mean(),
    "median": lambda x: x.median(),
    "min": lambda x: x.min(),
    "prod": lambda x: x.prod(),
    "std": lambda x: x.std(),
    "sum": lambda x: x.sum(),
    "var": lambda x: x.var(),
}
# Statistics implemented by the jitted helpers from raster_tools.stat_common.
_COARSEN_STYPE_TO_CUSTOM_FUNC = {
    "asm": nanasm_jit,
    "entropy": nanentropy_jit,
    "mode": nanmode_jit,
    "unique": nan_unique_count_jit,
}


def _get_unique_dtype(cur_dtype):
    """Map I8/U8 to I16, U16 to I32 and U32 to I64; return others unchanged.

    NOTE(review): presumably this widening makes room for a null value that
    the narrow dtype cannot hold -- confirm.
    """
    if cur_dtype in (I8, U8):
        return I16
    if cur_dtype == U16:
        return I32
    if cur_dtype == U32:
        return I64
    return cur_dtype
def aggregate(raster, expand_cells, stype):
    """Create a Raster of aggregated cell values for a new resolution.

    The approach is based on ESRI's aggregate and majority filter functions.

    Parameters
    ----------
    raster : Raster or path str
        Input Raster object or path string
    expand_cells : 2-tuple, list, array-like
        Tuple, array, or list of the number of cells to expand in y and x
        directions. The first element corresponds to the y dimension and the
        second to the x dimension.
    stype : str
        Summarization type. Valid options are mean, std, var, max, min,
        prod, median, mode, sum, unique, entropy, asm.

    Returns
    -------
    Raster
        The resulting raster of aggregated values.

    Raises
    ------
    TypeError
        If `expand_cells` does not contain integers or `stype` is not a
        string.
    ValueError
        If `expand_cells` does not hold exactly 2 values, holds a value less
        than 1, holds only ones, or if `stype` is not a known statistic.

    References
    ----------
    * `ESRI aggregate <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/aggregate.htm>`_
    * `ESRI majority filter <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/majority-filter.htm>`_

    """  # noqa: E501
    expand_cells = np.atleast_1d(expand_cells)
    if not is_int(expand_cells.dtype):
        raise TypeError("expand_cells must contain integers")
    if expand_cells.shape != (2,):
        raise ValueError("expand_cells must contain 2 elements")
    if (expand_cells == 1).all():
        raise ValueError("expand_cells values cannont both be one")
    if not (expand_cells >= 1).all():
        raise ValueError("All expand_cells values must be >= 1")
    if not is_str(stype):
        raise TypeError("stype argument must be a string")
    stype = stype.lower()
    if (
        stype not in _COARSEN_STYPE_TO_FUNC
        and stype not in _COARSEN_STYPE_TO_CUSTOM_FUNC
    ):
        raise ValueError(f"Invalid stype argument: {repr(stype)}")

    orig_dtype = get_raster(raster).dtype
    rs = get_raster(raster, null_to_nan=True)
    xda = rs.xdata
    dim_map = {"y": expand_cells[0], "x": expand_cells[1]}
    xdac = xda.coarsen(dim=dim_map, boundary="trim")
    if stype in _COARSEN_STYPE_TO_FUNC:
        xda = _COARSEN_STYPE_TO_FUNC[stype](xdac)
    else:
        custom_stat_func = _COARSEN_STYPE_TO_CUSTOM_FUNC[stype]
        check_nan = stype == "mode"
        out_dtype = _get_coarsen_dtype(
            stype, np.prod(expand_cells), orig_dtype
        )
        # Use partial because reduce seems to be bugged and grabs kwargs that
        # it shouldn't.
        xda = xdac.reduce(
            partial(
                _coarsen_block_map,
                agg_func=custom_stat_func,
                out_dtype=out_dtype,
                check_nan=check_nan,
            )
        )

    # The mask must always describe the coarsened grid, never the input grid.
    if rs._masked:
        # A coarse cell is null only when every cell in its window is null
        xmask = rs.xmask.coarsen(dim=dim_map, boundary="trim").all()
    else:
        # Nothing is null: build an all-False mask shaped like the output
        xmask = xr.zeros_like(xda, dtype=bool)
    ds_out = make_raster_ds(xda, xmask)

    if rs._masked:
        if stype == "unique":
            ds_out["raster"] = ds_out.raster.astype(
                _get_unique_dtype(ds_out.raster.dtype)
            )
        elif stype == "mode":
            # Cast back to original dtype. Original null value will work
            ds_out["raster"] = ds_out.raster.astype(orig_dtype)
        # Replace null cells with null value according to mask
        nv = get_default_null_value(ds_out.raster.dtype)
        ds_out["raster"] = xr.where(xmask, nv, ds_out.raster).rio.write_nodata(
            nv
        )
    if rs.crs is not None:
        ds_out = ds_out.rio.write_crs(rs.crs)
    return Raster(ds_out, _fast_path=True)
def predict_model(raster, model):
    """Predict cell estimates from a model using band values as predictors.

    Predictor bands correspond to the order of the predictor variables within
    the model. Outputs are raster surfaces whose bands and cell values depend
    on the type of model. The `model` object's ``predict`` method is used to
    estimate the new raster surface.

    Parameters
    ----------
    raster : Raster or path str
        Raster of predictor variables where the bands correspond to variables
        in the model (one band for each variable in the model).
    model : object
        The model used to estimate new values. Must have a `predict` method
        that takes an `xarray.DataArray` object. The provided DataArray will
        have null cells marked with ``NaN``.

    Returns
    -------
    Raster
        The resulting raster of estimated values.

    """
    predictors = get_raster(raster, null_to_nan=True).xdata
    return Raster(model.predict(predictors))
@nb.jit(nopython=True, nogil=True)
def _local_chunk(x, func, out):
    """Apply `func` to the band vector of every cell and store it in `out`.

    `x` is indexed as ``x[:, r, c]`` and the result is written to
    ``out[0, r, c]``; `out` is filled in place and also returned.
    """
    dshape = x.shape
    rws = dshape[1]
    clms = dshape[2]
    for r in range(rws):
        for c in range(clms):
            # All values along the first axis for this cell
            farr = x[:, r, c]
            out[0, r, c] = func(farr)
    return out


def _get_local_chunk_func(func, out_dtype):
    """Return a one-argument wrapper around `_local_chunk`.

    The wrapper allocates the single-layer output buffer with `out_dtype`
    for each array it is called with.
    """

    def wrapped(x):
        return _local_chunk(
            x, func, np.empty((1, *x.shape[1:]), dtype=out_dtype)
        )

    return wrapped


# Statistics that `local_stats` evaluates with xarray's ``reduce``.
_LOCAL_STYPE_TO_NUMPY_FUNC = {
    "mean": np.nanmean,
    "std": np.nanstd,
    "var": np.nanvar,
    "max": np.nanmax,
    "min": np.nanmin,
    "prod": np.nanprod,
    "sum": np.nansum,
    "median": np.nanmedian,
}
# Use custom min/max band because numpy versions throw errors if all values are
# nan.
_LOCAL_STYPE_TO_CUSTOM_FUNC = {
    "asm": nanasm_jit,
    "entropy": nanentropy_jit,
    "minband": nanargmin_jit,
    "maxband": nanargmax_jit,
    "mode": nanmode_jit,
    "unique": nan_unique_count_jit,
}


def _local(data, stype):
    """Compute the custom statistic `stype` along the first axis of `data`.

    `stype` must be a key of ``_LOCAL_STYPE_TO_CUSTOM_FUNC``. The result has
    length 1 along the first axis and is rechunked to the original chunking
    of the remaining axes.
    """
    bnds = data.shape[0]
    orig_chunks = data.chunks
    # Rechunk so band dim is contiguous and chunk sizes are reasonable
    data = da.rechunk(data, chunks=(bnds, "auto", "auto"))

    ffun = _LOCAL_STYPE_TO_CUSTOM_FUNC[stype]
    if stype == "unique":
        # The count can be at most the number of bands
        out_dtype = np.min_scalar_type(data.shape[0])
    elif stype == "mode":
        out_dtype = data.dtype
    elif stype in ("minband", "maxband"):
        # min/max band inner funcs will return < 0 if all values are nan which
        # will cause overflow for the unsigned type returned by
        # min_scalar_type. This is fine since we can mask them out later.
        out_dtype = np.min_scalar_type(data.shape[0] - 1)
    elif data.dtype == F32:
        out_dtype = F32
    else:
        out_dtype = F64

    ffun = _get_local_chunk_func(ffun, out_dtype)
    # Each output block collapses the first axis to a single layer
    out_chunks = ((1,), *data.chunks[1:])
    data_out = da.map_blocks(
        ffun,
        data,
        chunks=out_chunks,
        dtype=out_dtype,
        meta=np.array((), dtype=out_dtype),
    )
    # Rechunk again to expand/contract chunks to a reasonable size and split
    # bands apart
    return da.rechunk(data_out, chunks=(1, *orig_chunks[1:]))
def local_stats(raster, stype):
    """Create a Raster of summarized values across bands.

    The approach is based on ESRI's local function.

    Parameters
    ----------
    raster : Raster or path str
        Input Raster object or path string
    stype : str
        Summarization type. Valid options are mean, std, var, max, min,
        maxband, minband, prod, sum, mode, median, unique, entropy, asm

    Returns
    -------
    Raster
        The resulting single-band raster of values aggregated along the band
        dimension.

    Raises
    ------
    TypeError
        If `stype` is not a string.
    ValueError
        If `stype` is not a known statistic.

    References
    ----------
    * `ESRI local <https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/an-overview-of-the-local-tools.htm>`_

    """  # noqa: E501
    if not is_str(stype):
        raise TypeError("stype argument must be a string")
    stype = stype.lower()
    if (
        stype not in _LOCAL_STYPE_TO_NUMPY_FUNC
        and stype not in _LOCAL_STYPE_TO_CUSTOM_FUNC
    ):
        raise ValueError(f"Invalid stype aregument: {repr(stype)}")

    orig_dtype = get_raster(raster).dtype
    rs = get_raster(raster, null_to_nan=True)
    xda = rs.xdata.copy()

    if stype in _LOCAL_STYPE_TO_NUMPY_FUNC:
        xndata = xda.reduce(
            _LOCAL_STYPE_TO_NUMPY_FUNC[stype], dim="band", keepdims=True
        )
    else:
        data = xda.data
        xndata = empty_like(rs.get_bands(1)).xdata
        xndata.data = _local(data, stype)

    if rs._masked:
        # An output cell is null only when it is null in every band
        xmask = rs.xmask.reduce(
            np.all,
            dim="band",
            keepdims=True,
        )
    else:
        # Nothing is null. Build the mask from the reduced result so that it
        # has the output's single band rather than one layer per input band.
        xmask = xr.zeros_like(xndata, dtype=bool)
    ds_out = make_raster_ds(xndata, xmask)

    if rs._masked:
        if stype in ("unique", "minband", "maxband"):
            ds_out["raster"] = ds_out.raster.astype(
                _get_unique_dtype(ds_out.raster.dtype)
            )
        elif stype == "mode":
            # Cast back to original dtype. Original null value will work
            ds_out["raster"] = ds_out.raster.astype(orig_dtype)
        nv = get_default_null_value(ds_out.raster.dtype)
        ds_out["raster"] = ds_out.raster.rio.write_nodata(nv)
        ds_out.raster.data = da.where(ds_out.mask.data, nv, ds_out.raster.data)
    if rs.crs is not None:
        ds_out = ds_out.rio.write_crs(rs.crs)
    return Raster(ds_out, _fast_path=True)
def _morph_op_chunk(x, footprint, cval, morph_op, binary=False):
    """Erode or dilate one chunk and return it with a leading length-1 axis.

    Arrays with more than 2 dimensions are reduced to their first layer
    before filtering. Grey-scale filtering pads with the constant `cval`;
    the binary variants ignore `cval`.
    """
    plane = x[0] if x.ndim > 2 else x
    dilating = morph_op == "dilation"
    if binary:
        operation = binary_dilation if dilating else binary_erosion
        filtered = operation(plane, structure=footprint)
    else:
        operation = grey_dilation if dilating else grey_erosion
        filtered = operation(
            plane, footprint=footprint, mode="constant", cval=cval
        )
    return filtered[None]


def _get_footprint(size):
    """Validate `size` and return an all-True boolean window of that shape.

    `size` is either a single int (square window) or an iterable of exactly
    2 ints. All values must be positive and at least one must exceed 1.
    """
    if isinstance(size, Iterable):
        size = tuple(size)
        if len(size) != 2:
            raise ValueError("size sequence must have lenght 2.")
        if any(not is_int(s) for s in size):
            raise TypeError("size sequence must only contain ints.")
    elif is_int(size):
        size = (size, size)
    else:
        raise TypeError("size input must be an int or sequence of ints.")

    if any(not s > 0 for s in size):
        raise ValueError("size values must be greater than 0.")
    if all(s == 1 for s in size):
        raise ValueError("At least one size value must be greater than 1.")
    return np.ones(size) > 0


def _get_fill(dtype, op):
    """Return the dtype extreme that can never win the filter for `op`.

    Erosion gets the dtype's maximum and any other op (dilation) gets its
    minimum.
    """
    limits = np.iinfo(dtype) if is_int(dtype) else np.finfo(dtype)
    return limits.max if op == "erosion" else limits.min


def _erosion_or_dilation_filter(rs, footprint, op):
    """Apply the morphological operation `op` to the data and mask of `rs`.

    Null cells are replaced by a neutral fill value before filtering. For a
    masked raster, the valid-cell mask goes through the binary version of the
    same operation and the resulting null cells are set to the null value.
    """
    fill = _get_fill(rs.dtype, op)
    is_masked = rs._masked
    values = rs.data
    if is_masked:
        values = da.where(rs.mask, fill, values)

    row_pad, col_pad = _get_offsets(footprint.shape)
    # Take max because map_overlap does not support asymmetrical overlaps when
    # a boundary value is given
    depth = {0: 0, 1: max(row_pad), 2: max(col_pad)}
    grey_filter = partial(
        _morph_op_chunk,
        footprint=footprint,
        cval=fill,
        morph_op=op,
    )
    values = da.map_overlap(
        grey_filter,
        values,
        depth=depth,
        boundary=fill,
        dtype=rs.dtype,
        meta=np.array((), dtype=rs.dtype),
    )

    mask = rs.mask
    if not is_masked:
        mask = mask.copy()
    else:
        binary_filter = partial(
            _morph_op_chunk,
            footprint=footprint,
            cval=fill,
            morph_op=op,
            binary=True,
        )
        # Filter the valid cells (inverted mask), then invert back
        valid = da.map_overlap(
            binary_filter,
            ~mask,
            depth=depth,
            boundary=False,
            dtype=BOOL,
            meta=np.array((), dtype=BOOL),
        )
        mask = ~valid
        values = da.where(mask, rs.null_value, values)

    xrs_out = xr.zeros_like(rs.xdata).rio.write_nodata(rs.null_value)
    xrs_out.data = values
    xmask = xr.DataArray(mask, coords=xrs_out.coords, dims=xrs_out.dims)
    ds_out = make_raster_ds(xrs_out, xmask)
    if rs.crs is not None:
        ds_out = ds_out.rio.write_crs(rs.crs)
    return Raster(ds_out, _fast_path=True)
def dilate(raster, size):
    """Perform dilation on a raster

    Dilation increases the thickness of raster features. Features with higher
    values will cover features with lower values. At each cell, the maximum
    value within a window, defined by `size`, is stored in the output
    location. This is very similar to the max focal filter except that raster
    features are dilated (expanded) into null regions.

    Dilation is performed on each band separately.

    Parameters
    ----------
    raster : Raster or path str
        The raster to dilate
    size : int or 2-tuple of ints
        The shape of the window to use when dilating. A single int is
        interpreted as the size of a square window. A tuple of 2 ints is used
        as the dimensions of a rectangular window. At least one value must be
        greater than 1. Values cannot be less than 1.

    Returns
    -------
    Raster
        The resulting raster with dilated features. This raster will have the
        same shape and meta data as the original

    See also
    --------
    erode, raster_tools.focal.focal

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Dilation_%28morphology%29
    .. [2] https://en.wikipedia.org/wiki/Mathematical_morphology

    """
    source = get_raster(raster)
    window = _get_footprint(size)
    return _erosion_or_dilation_filter(source, window, "dilation")
def erode(raster, size):
    """Perform erosion on a raster

    Erosion decreases the thickness of raster features. Features with lower
    values will cover features with higher values. At each cell, the minimum
    value within a window, defined by `size`, is stored in the output
    location. This is very similar to the min focal filter except that raster
    features are eroded (contracted) as null regions expand into them.

    Erosion is performed on each band separately.

    Parameters
    ----------
    raster : Raster or path str
        The raster to erode
    size : int or 2-tuple of ints
        The shape of the window to use when eroding. A single int is
        interpreted as the size of a square window. A tuple of 2 ints is used
        as the dimensions of a rectangular window. At least one value must be
        greater than 1. Values cannot be less than 1.

    Returns
    -------
    Raster
        The resulting raster with eroded features. This raster will have the
        same shape and meta data as the original

    See also
    --------
    dilate, raster_tools.focal.focal

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Erosion_%28morphology%29
    .. [2] https://en.wikipedia.org/wiki/Mathematical_morphology

    """
    source = get_raster(raster)
    window = _get_footprint(size)
    return _erosion_or_dilation_filter(source, window, "erosion")
def band_concat(rasters):
    """Join a sequence of rasters along the band dimension.

    Parameters
    ----------
    rasters : sequence of Rasters and/or paths
        The rasters to concatenate. These can be a mix of Rasters and paths.
        All rasters must have the same shape in the x and y dimensions.

    Returns
    -------
    Raster
        The resulting concatenated Raster. When only one raster is given, it
        is returned as is.

    Raises
    ------
    ValueError
        If `rasters` is empty.

    """
    rasters = [get_raster(item) for item in rasters]
    if not rasters:
        raise ValueError("No rasters provided")
    if len(rasters) == 1:
        return rasters[0]

    # TODO: make join a user option or set join="exact"?
    datasets = [r._ds for r in rasters]
    ds = xr.concat(datasets, dim="band", join="inner")
    # Renumber the bands 1..N across the combined result
    total_bands = np.sum([r.nbands for r in rasters])
    ds["band"] = np.arange(total_bands) + 1

    if any(r._masked for r in rasters):
        nv = get_default_null_value(ds.raster.dtype)
        ds["raster"] = xr.where(ds.mask, nv, ds.raster).rio.write_nodata(nv)

    # Fall back to the CRS of the last raster that has one
    known_crs = [r.crs for r in rasters if r.crs is not None]
    crs = known_crs[-1] if known_crs else None
    if ds.rio.crs is None and crs is not None:
        ds = ds.rio.write_crs(crs)
    return Raster(ds, _fast_path=True)
@nb.jit(nopython=True, nogil=True)
def _remap_values(x, mask, mappings, inclusivity):
    """Remap the unmasked cells of a chunk according to range mappings.

    Parameters
    ----------
    x : 3D array
        The chunk to remap, indexed as (band, row, column).
    mask : 3D bool array
        Cells that are ``True`` are left untouched.
    mappings : 2D array
        One ``(min, max, new_value)`` row per mapping. The first matching
        row wins.
    inclusivity : int
        0: [min, max), 1: (min, max], 2: [min, max], 3: (min, max).

    Returns
    -------
    array
        A copy of `x` with every unmasked cell that falls inside a range
        replaced by that range's new value.

    """
    # Start from a copy so that unmatched and masked cells keep their value
    outx = x.copy()
    bands, rows, columns = x.shape
    rngs = mappings.shape[0]
    for bnd in range(bands):
        for rw in range(rows):
            for cl in range(columns):
                if mask[bnd, rw, cl]:
                    continue
                # Compare the actual cell value. Truncating it to an int
                # first would place fractional values in the wrong range.
                vl = x[bnd, rw, cl]
                remap = False
                for imap in range(rngs):
                    left, right, new = mappings[imap]
                    if inclusivity == 0:
                        remap = left <= vl < right
                    elif inclusivity == 1:
                        remap = left < vl <= right
                    elif inclusivity == 2:
                        remap = left <= vl <= right
                    elif inclusivity == 3:
                        remap = left < vl < right
                    if remap:
                        outx[bnd, rw, cl] = new
                        break
    return outx


def _normalize_mappings(mappings):
    """Validate range mappings and return them as a list of 3-item lists.

    Parameters
    ----------
    mappings : list or tuple
        Either a single ``(min, max, new_value)`` triple of scalars or a
        list/tuple of such triples.

    Returns
    -------
    list of list
        The validated mappings.

    Raises
    ------
    TypeError
        If `mappings` is not a list/tuple, an entry is not iterable, or a
        value is not a scalar.
    ValueError
        If no mappings are given, an entry does not have 3 values, a min or
        max is NaN, or min is not strictly less than max.

    """
    if not isinstance(mappings, (list, tuple)):
        raise TypeError(
            "Mappings must be either single 3-tuple or list of 3-tuples of "
            "scalars"
        )
    if not len(mappings):
        raise ValueError("No mappings provided")
    if is_scalar(mappings[0]):
        # A single bare triple was given: wrap it
        mappings = [mappings]
    try:
        mappings = [list(m) for m in mappings]
    except TypeError as err:
        raise TypeError(
            "Mappings must be either single 3-tuple or list of 3-tuples of "
            "scalars"
        ) from err
    for m in mappings:
        if len(m) != 3:
            raise ValueError(
                "Mappings must be either single 3-tuple or list of 3-tuples of"
                " scalars"
            )
        if not all(is_scalar(mi) for mi in m):
            raise TypeError("Mappings values must be scalars")
        if any(np.isnan(mi) for mi in m[:2]):
            raise ValueError("Mapping min and max values cannot be NaN")
        if m[0] >= m[1]:
            raise ValueError(
                "Mapping min value must be strictly less than max value:"
                f" {m[0]}, {m[1]}"
            )
    return mappings
def remap_range(raster, mapping, inclusivity="left"):
    """Remap values based on a mapping or list of mappings.

    Mappings are applied all at once with earlier mappings taking
    precedence.

    Parameters
    ----------
    raster : Raster or str
        Path string or Raster to perform remap on.
    mapping : 3-tuple of scalars or list of 3-tuples of scalars
        A tuple or list of tuples containing ``(min, max, new_value)``
        scalars. The mapping(s) map values between the min and max to the
        ``new_value``. If `mapping` is a list and there are mappings that
        conflict or overlap, earlier mappings take precedence. `inclusivity`
        determines which sides of the range are inclusive and exclusive.
    inclusivity : str, optional
        Determines whether to be inclusive or exclusive on either end of the
        range. Default is `'left'`.

        'left' [min, max)
            Left (min) side is inclusive and right (max) side is exclusive.
        'right' (min, max]
            Left (min) side is exclusive and right (max) side is inclusive.
        'both' [min, max]
            Both sides are inclusive.
        'none' (min, max)
            Both sides are exclusive.

    Returns
    -------
    Raster
        The resulting Raster.

    Raises
    ------
    TypeError
        If `inclusivity` is not a string.
    ValueError
        If `inclusivity` is not one of the valid options.

    """
    raster = get_raster(raster)
    mappings = _normalize_mappings(mapping)
    if not is_str(inclusivity):
        raise TypeError(
            f"inclusivity must be a str. Got type: {type(inclusivity)}"
        )
    # Integer codes understood by the jitted remap kernel
    inclusivity_codes = {"left": 0, "right": 1, "both": 2, "none": 3}
    if inclusivity not in inclusivity_codes:
        raise ValueError(f"Invalid inclusivity value. Got: {inclusivity!r}")

    new_values_dtype = get_common_dtype([m[-1] for m in mappings])
    out_dtype = np.promote_types(raster.dtype, new_values_dtype)
    # numba doesn't understand f16 so use f32 and then downcast
    f16_workaround = out_dtype == F16
    work_dtype = F32 if f16_workaround else out_dtype
    mappings = np.atleast_2d(mappings)

    outrs = raster.copy()
    if outrs.dtype != work_dtype:
        outrs = outrs.astype(work_dtype)

    data = outrs.data
    remapper = partial(
        _remap_values,
        mappings=mappings,
        inclusivity=inclusivity_codes[inclusivity],
    )
    outrs.xdata.data = data.map_blocks(
        remapper,
        raster.mask,
        dtype=data.dtype,
        meta=np.array((), dtype=data.dtype),
    )
    return outrs.astype(F16) if f16_workaround else outrs
def where(condition, true_rast, false_rast):
    """
    Return elements chosen from `true_rast` or `false_rast` depending on
    `condition`.

    Parameters
    ----------
    condition : str or Raster
        A boolean or int raster that indicates where elements in the result
        should be selected from. If the condition is an int raster, it is
        coerced to bool using `condition > 0`. ``True`` cells pull values from
        `true_rast` and ``False`` cells pull from `false_rast`. *str* is
        treated as a path to a raster.
    true_rast : scalar, Raster, str
        Raster or scalar to pull from if the corresponding location in
        `condition` is ``True``.
    false_rast : scalar, Raster, str
        Raster or scalar to pull from if the corresponding location in
        `condition` is ``False``.

    Returns
    -------
    Raster
        The resulting Raster.

    Raises
    ------
    TypeError
        If `condition` is not a boolean or integer raster, or if `true_rast`
        or `false_rast` cannot be interpreted as a scalar or raster.

    """
    condition = get_raster(condition)
    if not is_bool(condition.dtype) and not is_int(condition.dtype):
        raise TypeError(
            "Condition argument must be a boolean or integer raster"
        )

    args = []
    for r, name in [(true_rast, "true_rast"), (false_rast, "false_rast")]:
        if not is_scalar(r):
            try:
                r = get_raster(r)
            except TypeError as err:
                raise TypeError(
                    f"Could not understand {name} argument. Got: {r!r}"
                ) from err
        args.append(r)
    true_rast, false_rast = args

    # The first input that carries a CRS determines the output CRS
    out_crs = None
    for r in [condition, true_rast, false_rast]:
        crs = getattr(r, "crs", None)
        if crs is not None:
            out_crs = crs
            break

    xtrue, xfalse = [r.xdata if isinstance(r, Raster) else r for r in args]
    masked = any(r._masked if isinstance(r, Raster) else False for r in args)
    scalar_and_nan = all(is_scalar(r) for r in args) and np.isnan(args).any()
    masked |= scalar_and_nan

    xcondition = condition.xdata
    if is_int(condition.dtype):
        # if condition.dtype is not bool then must be an int raster so
        # assume that condition is raster of 0 and 1 values.
        # condition > 0 will grab all 1/True values.
        xcondition = xcondition > 0

    out_xrs = xr.where(xcondition, xtrue, xfalse)
    if masked and not scalar_and_nan:
        # Select each cell's mask from the same input its value came from.
        # Scalars get an all-False mask shaped like the condition.
        xtrue_mask, xfalse_mask = [
            r.xmask
            if isinstance(r, Raster)
            else xr.DataArray(
                create_null_mask(condition.xdata, None),
                dims=condition.xdata.dims,
                coords=condition.xdata.coords,
            )
            for r in args
        ]
        xmask = xr.where(xcondition, xtrue_mask, xfalse_mask)
    elif scalar_and_nan:
        xmask = np.isnan(out_xrs)
    else:
        xmask = xr.DataArray(
            create_null_mask(condition.xdata, None),
            dims=condition.xdata.dims,
            coords=condition.xdata.coords,
        )

    if masked or scalar_and_nan:
        nv = get_default_null_value(out_xrs.dtype)
        out_xrs = xr.where(xmask, nv, out_xrs).rio.write_nodata(nv)
    out_ds = make_raster_ds(out_xrs, xmask)
    if out_crs is not None:
        out_ds = out_ds.rio.write_crs(out_crs)
    return Raster(out_ds, _fast_path=True)


@nb.jit(nopython=True, nogil=True)
def _reclassify_chunk(x, mask, mapping_array, unmapped_to_null, null):
    """Reclassify one chunk using the (from, to) rows of `mapping_array`.

    Masked cells receive `null`. Unmasked cells found in the mapping receive
    the mapped value; the others receive `null` when `unmapped_to_null` is
    true and keep their value otherwise.
    """
    mapping = dict()
    for i in range(mapping_array.shape[0]):
        mapping[mapping_array[i, 0]] = mapping_array[i, 1]

    out = np.empty_like(x)
    n_bands, n_rows, n_cols = x.shape
    for b in range(n_bands):
        for i in range(n_rows):
            for j in range(n_cols):
                if mask[b, i, j]:
                    out[b, i, j] = null
                    # Move on to the next cell. Leaving the loop here would
                    # leave the rest of the row uninitialized.
                    continue
                v = x[b, i, j]
                if v in mapping:
                    out[b, i, j] = mapping[v]
                elif unmapped_to_null:
                    out[b, i, j] = null
                else:
                    out[b, i, j] = v
    return out


class RemapFileParseError(Exception):
    """Raised when a line of an ASCII remap file is not of the form a:b."""


# One optionally signed integer on each side of a colon, e.g. "3:-1"
_REMAPPING_LINE_PATTERN = re.compile(r"^(?P<from>[+-]?\d+):(?P<to>[+-]?\d+)$")


def _parse_ascii_remap_file(path):
    """Parse an ASCII remap file into a dict of int to int.

    Every line, after stripping surrounding whitespace, must look like
    ``a:b`` where ``a`` and ``b`` are integers.

    Raises
    ------
    RemapFileParseError
        If a line does not match the expected format.
    ValueError
        If the same "from" value appears more than once.

    """
    mapping = {}
    with open(path) as fd:
        for line in fd:
            line = line.strip()
            m = _REMAPPING_LINE_PATTERN.match(line)
            if m is None:
                raise RemapFileParseError(f"Invalid remap line: {line!r}")
            k = int(m.group("from"))
            v = int(m.group("to"))
            if k in mapping:
                raise ValueError(f"Found duplicate mapping: '{k}:{v}'.")
            mapping[k] = v
    return mapping


def _get_remapping(mapping):
    """Return a validated remapping dict from a dict or a remap file path.

    Raises
    ------
    IOError
        If `mapping` is a string that is not an existing path.
    TypeError
        If `mapping` is neither a str nor a dict, or if any key or value is
        not an integer.

    """
    if is_str(mapping):
        if os.path.exists(mapping):
            mapping = _parse_ascii_remap_file(mapping)
        else:
            raise IOError(f"No such file: {mapping!r}")
    elif not isinstance(mapping, dict):
        raise TypeError(
            f"Remapping must be a str or dict. Got: {type(mapping)!r}"
        )
    if not all(is_int(k) for k in mapping.keys()) or not all(
        is_int(v) for v in mapping.values()
    ):
        raise TypeError("Remapping values must all be integer types")
    return mapping
def reclassify(raster, remapping, unmapped_to_null=False):
    """Reclassify the input raster values based on a mapping.

    This function only works with integer type rasters.

    Parameters
    ----------
    raster : str, Raster
        The input raster to reclassify. Can be a path string or Raster object.
        The raster dtype must be integer.
    remapping : str, dict
        Can be either a ``dict`` or a path string. If a ``dict`` is provided,
        the keys will be reclassified to the corresponding values. If a path
        string, it is treated as an ASCII remap file where each line looks
        like ``a:b`` and ``a`` and ``b`` are integers. All remap values (both
        from and to) must be integers.
    unmapped_to_null : bool, optional
        If ``True``, values not included in the mapping are instead mapped to
        the null value. Default is ``False``.

    Returns
    -------
    Raster
        The remapped raster. A masked input keeps its null value. An unmasked
        input only gains a null value when `unmapped_to_null` is ``True``.

    Raises
    ------
    TypeError
        If the raster is not an integer type.

    """
    raster = get_raster(raster)
    remapping = _get_remapping(remapping)
    if not is_int(raster.dtype):
        raise TypeError(
            f"Input raster must be an integer type. Got {raster.dtype!r}"
        )

    out_dtype = raster.dtype
    to_values = list(remapping.values())
    # Explicit range check: np.can_cast rejects Python ints in NumPy >= 2.
    limits = np.iinfo(out_dtype)
    if any(not (limits.min <= v <= limits.max) for v in to_values):
        # Widen so that both the original range and the new values fit
        extra_min = raster.dtype.type(limits.min)
        extra_max = raster.dtype.type(limits.max)
        out_dtype = get_common_dtype(to_values + [extra_min, extra_max, -1])

    if raster._masked:
        # Null cells stay null, so keep the raster's own null value
        nv = raster.null_value
    elif unmapped_to_null:
        nv = get_default_null_value(out_dtype)
    else:
        # Nothing can become null, so the result needs no null value
        nv = None
    # The jitted kernel needs a number. It is never written when nv is None.
    kernel_null = 0 if nv is None else nv

    mapping = np.array(list(remapping.items()))
    data = da.map_blocks(
        _reclassify_chunk,
        raster.data.copy().astype(out_dtype),
        raster.mask,
        mapping_array=mapping,
        unmapped_to_null=unmapped_to_null,
        null=kernel_null,
        dtype=out_dtype,
        meta=np.array((), dtype=out_dtype),
    )
    xdata = xr.DataArray(
        data, coords=raster.xdata.coords, dims=raster.xdata.dims
    )
    if nv is not None:
        xdata = xdata.rio.write_nodata(nv)
    if raster.crs is not None:
        xdata = xdata.rio.write_crs(raster.crs)
    return Raster(xdata)