Source code for africanus.experimental.rime.fused.dask

import numpy as np

from africanus.util.requirements import requires_optional
from africanus.experimental.rime.fused.core import (
    RimeFactory, consolidate_args)

try:
    import dask.array as da
except ImportError as e:
    opt_import_err = e
else:
    opt_import_err = None


def rime_dask_wrapper(factory, names, nconcat_dims, *args):
    # Call the rime factory
    assert len(names) == len(args)
    out = factory(**dict(zip(names, args)))
    # (1) Reintroduce source dimension,
    # (2) slice the existing dimes
    # (3) expand by the contraction dims which will
    #     be removed in the later dask reduction
    return out[(None,) + (slice(None),)*out.ndim + (None,)*nconcat_dims]


[docs]@requires_optional("dask.array", opt_import_err) def rime(rime_spec, *args, **kw): """Like :func:`~africanus.experimental.rime.fused.core.rime`, but for a dask paradigm""" factory = RimeFactory(rime_spec) names, args = factory.dask_blockwise_args(**consolidate_args(args, kw)) dims = ("source", "row", "chan", "corr") contract_dims = set(d for ds in args[1::2] if ds is not None for d in ds) contract_dims -= set(dims) out_dims = dims + tuple(contract_dims) # Source and concatenation dimension are reduced to 1 element adjust_chunks = {"source": 1, **{d: 1 for d in contract_dims}} # This is needed otherwise, dask will call rime_dask_wrapper # with dummy arugments to infer the output dtype. # This incurs memory allocations within numba, as well as # exceptions, leading to memory leaks as described # in https://github.com/numba/numba/issues/3263 meta = np.empty((0,)*len(out_dims), dtype=np.complex128) # Construct the wrapper call from given arguments out = da.blockwise(rime_dask_wrapper, out_dims, factory, None, names, None, len(contract_dims), None, *args, concatenate=False, adjust_chunks=adjust_chunks, meta=meta) # Contract over source and concatenation dims axes = (0,) + tuple(range(len(dims), len(dims) + len(contract_dims))) return out.sum(axis=axes)