Source code for raster_tools.zonal

from collections.abc import Iterable, Iterator, Sequence
from functools import partial

import dask
import dask.array as da
import dask.dataframe as dd
import numba as nb
import numpy as np
import pandas as pd
from dask_image import ndmeasure

from raster_tools.dask_utils import dask_nanmax, dask_nanmin
from raster_tools.dtypes import F64, I64, is_int, is_str
from raster_tools.raster import Raster, get_raster, xy_to_rowcol
from raster_tools.vector import Vector, get_vector

__all__ = ["ZONAL_STAT_FUNCS", "zonal_stats"]


def _nan_count(x):
    return da.count_nonzero(~np.isnan(x))


def _nan_median(x):
    x = da.asarray(x)
    return da.nanmedian(x, axis=0)


def _nan_unique(x):
    return _nan_count(da.unique(x))


def _flatten_gen(x):
    """
    A generator that recursively yields numpy arrays from arbitrarily nested
    lists of arrays.
    """
    for xi in x:
        if isinstance(x, Iterable) and not isinstance(xi, np.ndarray):
            yield from _flatten_gen(xi)
        else:
            yield xi


def _flatten(x):
    """Flatten nested lists of arrays."""
    if isinstance(x, np.ndarray):
        return [x]
    return list(_flatten_gen(x))


def _recursive_map(func, *seqs):
    """Apply a function to items in nested sequences."""
    if isinstance(seqs[0], (list, Iterator)):
        return [_recursive_map(func, *items) for items in zip(*seqs)]
    return func(*seqs)


def _unique_with_counts_chunk(x, computing_meta=False, axis=(), **kwargs):
    """Reduce a dask chunk to a dict of unique values and counts.

    This is the leaf operation in the reduction tree.
    """
    if computing_meta:
        return x
    x_non_nan = x[~np.isnan(x)]
    values, counts = np.unique(x_non_nan, return_counts=True)
    while values.ndim < len(axis):
        values = np.expand_dims(values, axis=0)
        counts = np.expand_dims(counts, axis=0)
    return {"values": values, "counts": counts}


def _ravel_key(item, key):
    return item[key].ravel()


_ravel_values = partial(_ravel_key, key="values")
_ravel_counts = partial(_ravel_key, key="counts")


def _split_concat(pairs, split_func):
    # Split out a key from lists of dicts, ravel them, and concat all together
    split = _recursive_map(split_func, pairs)
    return np.concatenate(_flatten(split))


def _unique_with_counts_combine(
    pairs, computing_meta=False, axis=(), **kwargs
):
    """Merge/combine branches of the unique-with-counts reduction tree.

    This includes results from multiple _unique_with_counts_chunk calls and
    from prior _unique_with_counts_combine calls.
    """
    values = (
        _recursive_map(_ravel_values, pairs) if not computing_meta else pairs
    )
    values = np.concatenate(_flatten(values))
    if computing_meta:
        return np.array([[[0]]], dtype=pairs.dtype)

    counts = _split_concat(pairs, _ravel_counts)
    res = {v: 0 for v in values}
    for v, c in zip(values, counts):
        res[v] += c
    values = np.array(list(res.keys()))
    counts = np.array(list(res.values()))
    while values.ndim < len(axis):
        values = np.expand_dims(values, axis=0)
        counts = np.expand_dims(counts, axis=0)
    return {"values": values, "counts": counts}


def _mode_agg(pairs, computing_meta=False, axis=(), **kwargs):
    """Perform the final aggregation to a single mode value."""
    values = (
        _split_concat(pairs, _ravel_values) if not computing_meta else pairs
    )
    if computing_meta:
        return pairs.dtype.type(0)
    if len(values) == 0:
        # See note below about wrapping in np.array()
        return np.array(np.nan)
    counts = _split_concat(pairs, _ravel_counts)
    res = {v: 0 for v in values}
    for v, c in zip(values, counts):
        res[v] += c
    values = res.keys()
    counts = res.values()
    sorted_pairs = sorted(zip(counts, values), reverse=True)
    # Find the minimum mode when there is a tie. This is the same behavior as
    # scipy.
    i = -1
    c = sorted_pairs[0][0]
    for pair in sorted_pairs:
        if pair[0] == c:
            i += 1
        else:
            break
    # NOTE: wrapping the value in an array is a hack to prevent dask from
    # mishandling the return value as an array with dims, leading to index
    # errors. I can't pierce the veil of black magic that is causing the
    # mishandling so this is the best fix I can come up with.
    return np.array(sorted_pairs[i][1])


def _nan_mode(x):
    """
    Compute the statistical mode of an array using a dask reduction operation.
    """
    return da.reduction(
        x,
        chunk=_unique_with_counts_chunk,
        combine=_unique_with_counts_combine,
        aggregate=_mode_agg,
        # F64 to allow for potential empty input array. In that case a NaN is
        # returned.
        dtype=F64,
        # Turn off concatenation to prevent dask from trying to concat the
        # dicts of variable length values and counts. Dask tries to concat
        # along the wrong axis, which causes errors.
        concatenate=False,
    )


@nb.jit(nopython=True, nogil=True)
def _entropy(values, counts):
    if len(values) == 0:
        return np.nan
    res = {v: 0 for v in values}
    for v, c in zip(values, counts):
        res[v] += c
    counts = res.values()
    entropy = 0.0
    frac = 1 / len(res)
    for cnt in counts:
        p = cnt * frac
        entropy -= p * np.log(p)
    return entropy


@nb.jit(nopython=True, nogil=True)
def _asm(values, counts):
    if len(values) == 0:
        return np.nan
    res = {v: 0 for v in values}
    for v, c in zip(values, counts):
        res[v] += c
    counts = res.values()
    asm = 0.0
    frac = 1 / len(res)
    for cnt in counts:
        p = cnt * frac
        asm += p * p
    return asm


def _entropy_asm_agg(
    pairs, compute_entropy, computing_meta=False, axis=(), **kwargs
):
    """Perform the final aggregation to a single entropy or ASM value."""
    if computing_meta:
        return 0
    values = _split_concat(pairs, _ravel_values)
    if len(values) == 0:
        return np.array([])
    counts = _split_concat(pairs, _ravel_counts)
    # NOTE: wrapping the value in an array is a hack to prevent dask from
    # mishandling the return value as an array with dims, leading to index
    # errors. I can't pierce the veil of black magic that is causing the
    # mishandling so this is the best fix I can come up with.
    if compute_entropy:
        return np.array(_entropy(values, counts))
    return np.array(_asm(values, counts))


def _nan_entropy(x):
    """Compute the entropy of an array using a dask reduction operation."""
    return da.reduction(
        x,
        # mode chunk and combine funcs can be reused here
        chunk=_unique_with_counts_chunk,
        combine=_unique_with_counts_combine,
        aggregate=partial(_entropy_asm_agg, compute_entropy=True),
        dtype=F64,
        # Turn off concatenation to prevent dask from trying to concat the
        # dicts of variable length values and counts. Dask tries to concat
        # along the wrong axis, which causes errors.
        concatenate=False,
    )


def _nan_asm(x):
    """Compute the ASM of an array using a dask reduction operation.

    Angular second moment.
    """
    return da.reduction(
        x,
        # mode chunk and combine funcs can be reused here
        chunk=_unique_with_counts_chunk,
        combine=_unique_with_counts_combine,
        aggregate=partial(_entropy_asm_agg, compute_entropy=False),
        dtype=F64,
        # Turn off concatenation to prevent dask from trying to concat the
        # dicts of variable length values and counts. Dask tries to concat
        # along the wrong axis, which causes errors.
        concatenate=False,
    )


_ZONAL_STAT_FUNCS = {
    "asm": _nan_asm,
    "count": _nan_count,
    "entropy": _nan_entropy,
    "max": dask_nanmax,
    "mean": da.nanmean,
    "median": _nan_median,
    "min": dask_nanmin,
    "mode": _nan_mode,
    "std": da.nanstd,
    "sum": da.nansum,
    "unique": _nan_unique,
    "var": da.nanvar,
}
# The set of valid zonal function names/keys
ZONAL_STAT_FUNCS = frozenset(_ZONAL_STAT_FUNCS)


def _build_zonal_stats_data(data_raster, feat_raster, feat_labels, stats):
    nbands = data_raster.shape[0]
    feat_data = feat_raster.data
    # data will end up looking like:
    # {
    #   # band number
    #   1: {
    #     # Stat results
    #     "mean": [X, X, X], <- dask array
    #     "std": [X, X, X],
    #     ...
    #   },
    #   2: {
    #     # Stat results
    #     "mean": [X, X, X],
    #     "std": [X, X, X],
    #     ...
    #   },
    #   ...
    data = {}
    raster_data = get_raster(data_raster, null_to_nan=True).data
    for ibnd in range(nbands):
        ibnd += 1
        data[ibnd] = {}
        # Use range to keep band dimension intact
        band_data = raster_data[ibnd - 1 : ibnd]
        for f in stats:
            result_delayed = dask.delayed(ndmeasure.labeled_comprehension)(
                band_data,
                feat_data,
                feat_labels,
                _ZONAL_STAT_FUNCS[f],
                F64,
                np.nan,
            )
            data[ibnd][f] = da.from_delayed(
                result_delayed,
                feat_labels.shape,
                dtype=F64,
                meta=np.array([], dtype=F64),
            )
    return data


def _create_dask_range_index(start, stop):
    # dask.dataframe only allows dask.dataframe.index objects but doesn't have
    # a way to create them. this is a hack to create one using from_pandas.
    dummy = pd.DataFrame(
        {"tmp": np.zeros(stop - start, dtype="u1")},
        index=pd.RangeIndex(start, stop),
    )
    return dd.from_pandas(dummy, 1).index


def _build_zonal_stats_dataframe(zonal_data, nparts=None):
    bands = list(zonal_data)
    snames = list(zonal_data[bands[0]])
    n = zonal_data[bands[0]][snames[0]].size
    if nparts is None:
        # Get the number of partitions that dask thinks is reasonable. The data
        # arrays have chunks of size 1 so we need to rechunk later and then
        # repartition everything else in the dataframe to match.
        nparts = zonal_data[bands[0]][snames[0]].rechunk().npartitions

    df = None
    for bnd in bands:
        df_part = None
        band_data = zonal_data[bnd]
        band = da.full(n, bnd, dtype=I64)
        # We need to create an index because the concat operation later will
        # blindly paste in each dataframe's index. If an explicit index is not
        # set, the default is a range index from 0 to n. Thus the final
        # resulting dataframe would have identical indexes chained end-to-end:
        # [0, 1, ..., n-1, 0, 1, ..., n-1, 0, 1..., n-1]. By setting an index
        # we get [0, 1, ..., n, n+1, ..., n + n, ...].
        ind_start = n * (bnd - 1)
        ind_end = ind_start + n
        index = _create_dask_range_index(ind_start, ind_end)
        df_part = band.to_dask_dataframe("band", index=index).to_frame()
        # Repartition to match the data
        df_part = df_part.repartition(npartitions=nparts)
        index = index.repartition(npartitions=nparts)
        for name in snames:
            df_part[name] = (
                band_data[name].rechunk().to_dask_dataframe(name, index=index)
            )
        if df is None:
            df = df_part
        else:
            # Use interleave_partitions to keep partition and division info
            df = dd.concat([df, df_part], interleave_partitions=True)
    return df


[docs]def zonal_stats(features, data_raster, stats, raster_feature_values=None): """Apply stat functions to a raster based on a set of features. Parameters ---------- features : str, Vector, Raster A `Vector` or path string pointing to a vector file or a categorical Raster. The vector features are used like cookie cutters to pull data from the `data_raster` bands. If `features` is a Raster, it must be an int dtype and have only one band. data_raster : Raster, str A `Raster` or path string pointing to a raster file. The data raster to pull data from and apply the stat functions to. stats : str, list of str A single string or list of strings corresponding to stat funcstions. These functions will be applied to the raster data for each of the features in `features`. Valid string values: 'asm' Angular second moment. Applies -sum(P(g)**2) where P(g) gives the probability of g within the neighborhood. 'count' Count valid cells. 'entropy' Calculates the entropy. Applies -sum(P(g) * log(P(g))). See 'asm' above. 'max' Find the maximum value. 'mean' Calculate the mean. 'median' Calculate the median value. 'min' Find the minimum value. 'mode' Compute the statistical mode of the data. In the case of a tie, the lowest value is returned. 'std' Calculate the standard deviation. 'sum' Calculate the sum. 'unique' Count unique values. 'var' Calculate the variance. raster_feature_values : sequence of ints, optional Unique values to be used when the `features` argument is a Raster. If `features` is a Raster and this is not provided the unique values in the raster will be calculated. Returns ------- dask.dataframe.DataFrame A delayed dask DataFrame. The columns are the values in `stats` plus a column indicating the band the calculation was carried out on. Each row is the set of statistical calculations carried out on data pulled from `data_raster` based on the corresponding feature in `features`. NaN values indicate where a feature was outside of the raster or all data under the feature was null. """ if is_str(features) or isinstance(features, Vector): features = get_vector(features) elif isinstance(features, Raster): if not is_int(features.dtype): raise TypeError("Feature raster must be an integer type.") if features.shape[0] > 1: raise ValueError("Feature raster must have only 1 band.") else: raise TypeError( "Could not understand features arg. Must be Vector, str or Raster" ) data_raster = get_raster(data_raster) if is_str(stats): stats = [stats] elif isinstance(stats, Sequence): stats = list(stats) if not stats: raise ValueError("No stat functions provide") else: raise ValueError(f"Could not understand stats arg: {repr(stats)}") for stat in stats: if stat not in ZONAL_STAT_FUNCS: raise ValueError(f"Invalid stats function: {repr(stat)}") if isinstance(features, Raster): if features.crs != data_raster.crs: raise ValueError("Feature raster CRS must match data raster") if features.shape != data_raster.shape: raise ValueError("Feature raster shape must match data raster") feature_labels = None features_raster = None if isinstance(features, Vector): feature_labels = np.arange(1, len(features) + 1) features_raster = features.to_raster(data_raster) else: if raster_feature_values is None: (raster_feature_values,) = dask.compute(np.unique(features.data)) else: raster_feature_values = np.atleast_1d(raster_feature_values) raster_feature_values = raster_feature_values[ raster_feature_values > 0 ] feature_labels = raster_feature_values features_raster = features data = _build_zonal_stats_data( data_raster, features_raster, feature_labels, stats ) df = _build_zonal_stats_dataframe(data) return df
def _xy_to_rowcol_wrapper(x, y, affine): return np.stack(xy_to_rowcol(x, y, affine), axis=0) def _extract_points(data, r, c, valid_mask): r, c, valid_mask = dask.compute(r, c, valid_mask) extracted = np.full((len(valid_mask),), np.nan, dtype=F64) extracted[valid_mask] = data[r[valid_mask], c[valid_mask]] return extracted def _build_zonal_stats_data_from_points(data, mask, x, y, affine): r, c = da.blockwise( _xy_to_rowcol_wrapper, "zi", x, "i", y, "i", affine=affine, new_axes={"z": 2}, dtype=np.int64, ) _, rn, cn = data.shape r, c = dask.compute(r, c) valid_mask = (r >= 0) & (r < rn) & (c >= 0) & (c < cn) out = { i + 1: {"extracted": da.full(len(valid_mask), np.nan, dtype=F64)} for i in range(data.shape[0]) } for i in range(data.shape[0]): extracted = da.full(len(valid_mask), np.nan, dtype=F64) extracted[valid_mask] = data.vindex[i, r[valid_mask], c[valid_mask]] # Mask out missing points within the valid zones exmask = da.zeros(len(valid_mask), dtype=bool) exmask[valid_mask] = mask.vindex[i, r[valid_mask], c[valid_mask]] extracted[exmask] = np.nan out[i + 1]["extracted"] = extracted return out def point_extraction(points, raster, skip_validation=False): """Extract the raster cell values using point features This finds the grid cells that the points fall into and extracts the value at each point. The input feature will be partially computed to make sure that all of the geometries are points, unless `skip_validation` is set to `True`. Note ---- This is experimental. Parameters ---------- points : str, Vector The points to use for extracting data. raster : str, Raster The raster to pull data from. skip_validation : bool, optional If `True`, the input `points` is not validated to make sure that all features are points. This prevents partially computing the data. Default is `False`. Returns ------- dask.dataframe.DataFrame The columns are "band" and "extracted". These are the band data was pulled from and the extracted value, respectively. NaN values in the "extracted" column are where there was missing data in the raster or the point was outside the raster's domain. """ # TODO: properly test points = get_vector(points) if ( not skip_validation and not (points.geometry.geom_type == "Point").all().compute() ): raise TypeError("All geometries must be points.") data_raster = get_raster(raster, null_to_nan=True) points = points.to_crs(data_raster.crs) x = points.geometry.x.to_dask_array(True) y = points.geometry.y.to_dask_array(True) data = _build_zonal_stats_data_from_points( data_raster.data, data_raster.mask, x, y, data_raster.affine ) n = len(data[1]["extracted"]) df = _build_zonal_stats_dataframe(data, nparts=1).reset_index(drop=True) df.divisions = (0, n - 1) return df