import os
from pathlib import Path
import dask
import numpy as np
import rasterio as rio
import rioxarray as rxr
import xarray as xr
from dask.array.core import normalize_chunks as dask_chunks
from raster_tools.dtypes import F32, F64, I64, U8, is_bool, is_float, is_int
from raster_tools.masking import create_null_mask
from raster_tools.utils import (
is_strictly_decreasing,
is_strictly_increasing,
validate_path,
)
[docs]class RasterIOError(BaseException):
pass
[docs]class RasterDataError(BaseException):
pass
def _get_extension(path):
return os.path.splitext(path)[-1].lower()
def _get_chunking_info_from_file(src_file):
with rio.open(src_file) as src:
tile_shape = (1, *src.block_shapes[0])
shape = (src.count, *src.shape)
dtype = np.dtype(src.dtypes[0])
return tile_shape, shape, dtype
def _get_chunks(data=None, src_file=None):
chunks = (1, "auto", "auto")
if data is None:
if src_file is None:
return chunks
tile_shape, shape, dtype = _get_chunking_info_from_file(src_file)
return dask_chunks(
chunks, shape, dtype=dtype, previous_chunks=tile_shape
)
else:
shape = data.shape
dtype = data.dtype
tile_shape = None
if dask.is_dask_collection(data):
tile_shape = data.chunks
elif src_file is not None:
_, tile_shape, _ = _get_chunking_info_from_file(src_file)
return dask_chunks(
chunks, shape, dtype=dtype, previous_chunks=tile_shape
)
[docs]def chunk(xrs, src_file=None):
if isinstance(xrs, xr.Dataset):
chunks = _get_chunks(xrs.raster, src_file)
return xrs.chunk({d: c for d, c in zip(xrs.raster.dims, chunks)})
else:
return xrs.chunk(_get_chunks(xrs, src_file))
TIFF_EXTS = frozenset((".tif", ".tiff"))
NC_EXTS = frozenset((".cdf", ".nc", ".nc4"))
HDF_EXTS = frozenset((".hdf", ".h4", ".hdf4", ".he2", ".h5", ".hdf5", ".he5"))
GRIB_EXTS = frozenset((".grib", ".grib2", ".grb", ".grb2", ".gb", ".gb2"))
BATCH_EXTS = frozenset((".bch",))
# File extenstions that can't be read in yet
READ_NOT_IMPLEMENTED_EXTS = NC_EXTS | HDF_EXTS | GRIB_EXTS
# File extenstions that can't be written out yet
WRITE_NOT_IMPLEMENTED_EXTS = NC_EXTS | HDF_EXTS | GRIB_EXTS
IO_UNDERSTOOD_TYPES = (str, Path)
[docs]def is_batch_file(path):
return _get_extension(path) in BATCH_EXTS
[docs]def normalize_xarray_data(xrs):
if len(xrs.shape) > 3 or len(xrs.shape) < 2:
raise ValueError(
"Invalid shape. xarray.DataArray objects must have 2D or 3D "
"shapes."
)
if len(xrs.shape) == 2:
# Add band dim
xrs = xrs.expand_dims({"band": [1]})
dims = xrs.dims
if "lon" in dims:
xrs = xrs.rename({"lon": "x"})
dims = xrs.dims
if "lat" in dims:
xrs = xrs.rename({"lat": "y"})
dims = xrs.dims
if not dims == ("band", "y", "x"):
# No easy way to figure out how best to transpose based on dim names so
# just assume the order is valid and rename.
xrs = xrs.rename(
{
d: new_d
for d, new_d in zip(dims, ("band", "y", "x"))
if d != new_d
}
)
if xrs.band.values[0] != 1:
xrs["band"] = np.arange(1, len(xrs.band) + 1)
if any(dim not in xrs.coords for dim in xrs.dims):
raise ValueError(
"Invalid coordinates on xarray.DataArray object:\n{xrs!r}"
)
# Make sure that x and y are always increasing. xarray will auto align
# rasters but when a raster is converted to a numpy or dask array, the
# data may not be aligned. This ensures that rasters converted to
# non-georeferenecd formats will be oriented the same.
if is_strictly_decreasing(xrs.x):
xrs = xrs.isel(x=slice(None, None, -1))
if is_strictly_increasing(xrs.y):
xrs = xrs.isel(y=slice(None, None, -1))
tf = xrs.rio.transform(True)
xrs = xrs.rio.write_transform(tf)
return xrs
ESRI_DEFAULT_F32_NV = np.finfo(F32).min
[docs]def normalize_null_value(nv, dtype):
# Make sure that ESRI's default F32 null value is properly
# registered as F32
if dtype == F32 and nv is not None and np.isclose(nv, ESRI_DEFAULT_F32_NV):
nv = F32.type(nv)
# Some rasters have (u)int dtype and a null value that is a whole number
# but it gets read in as a float. This can cause a lot of accidental type
# promotions down the pipeline. Check for this case and correct it.
if is_int(dtype) and is_float(nv) and float(nv).is_integer():
nv = int(nv)
return nv
[docs]def open_raster_from_path(path):
if type(path) in IO_UNDERSTOOD_TYPES:
path = str(path)
path = os.path.abspath(path)
else:
raise RasterIOError(
f"Could not resolve input to a raster path: '{path}'"
)
validate_path(path)
ext = _get_extension(path)
xrs = None
# Try to let gdal open anything but NC, HDF, GRIB files
if not ext or ext not in READ_NOT_IMPLEMENTED_EXTS:
try:
xrs = rxr.open_rasterio(path, chunks=_get_chunks())
except rio.errors.RasterioIOError as e:
raise RasterIOError(str(e))
elif ext in READ_NOT_IMPLEMENTED_EXTS:
raise NotImplementedError(
"Reading of NetCDF, HDF, and GRIB files is not supported at this"
" time."
)
else:
raise RasterIOError("Unknown file type")
if isinstance(xrs, xr.Dataset):
raise RasterDataError("Too many data variables in input data")
assert isinstance(
xrs, xr.DataArray
), "Resulting data structure must be a DataArray"
if not dask.is_dask_collection(xrs):
xrs = chunk(xrs, path)
xrs = normalize_xarray_data(xrs)
nv = xrs.attrs.get("_FillValue", None)
nv = normalize_null_value(nv, xrs.dtype)
mask = create_null_mask(xrs, nv)
return xrs, mask, nv
[docs]def write_raster(xrs, path, no_data_value, **rio_gdal_kwargs):
ext = _get_extension(path)
rio_is_bool = False
if ext in TIFF_EXTS or len(ext) == 0:
if xrs.dtype == I64:
# GDAL, and thus rioxarray and rasterio, doesn't support I64 so
# cast up to float. This avoids to_raster throwing a TypeError.
xrs = xrs.astype(F64)
elif is_bool(xrs.dtype):
# GDAL doesn't support boolean dtype either so convert to uint8
# 0-1 encoding.
rio_is_bool = True
xrs = xrs.astype(U8)
if not ext or ext not in WRITE_NOT_IMPLEMENTED_EXTS:
kwargs = {"lock": True, "compute": True, **rio_gdal_kwargs}
if "blockheight" in kwargs:
value = kwargs.pop("blockheight")
kwargs["blockysize"] = value
if "blockwidth" in kwargs:
value = kwargs.pop("blockwidth")
kwargs["blockxsize"] = value
if rio_is_bool:
# Store each entry using a single bit
kwargs["nbits"] = 1
xrs.rio.to_raster(path, **kwargs)
else:
# TODO: populate
raise NotImplementedError()