# Source code for africanus.experimental.rime.fused.dask
import numpy as np
from africanus.util.requirements import requires_optional
from africanus.experimental.rime.fused.core import RimeFactory, consolidate_args
try:
import dask.array as da
except ImportError as e:
opt_import_err = e
else:
opt_import_err = None
def rime_dask_wrapper(factory, names, nconcat_dims, *args):
    """Evaluate ``factory`` on one block and pad the result with unit axes.

    Parameters
    ----------
    factory : callable
        Invoked with the entries of ``args`` passed as keyword arguments,
        keyed by the matching entries of ``names``.
    names : sequence of str
        Keyword names, one per entry of ``args``.
    nconcat_dims : int
        Number of length-1 axes appended after the existing axes of the
        factory output.
    *args
        Values paired positionally with ``names``.

    Returns
    -------
    The factory output indexed so that it gains one new leading axis and
    ``nconcat_dims`` new trailing axes; existing axes are kept whole.
    """
    # Every value must have a keyword name before the factory is called
    assert len(names) == len(args)
    keyword_args = {name: value for name, value in zip(names, args)}
    result = factory(**keyword_args)

    # Build the index piece by piece:
    # (1) a new leading axis reintroduces the source dimension,
    index = [None]
    # (2) full slices keep each existing dimension untouched,
    index.extend([slice(None)] * result.ndim)
    # (3) trailing new axes stand in for the contraction dims, which
    #     are removed again by the later dask reduction
    index.extend([None] * nconcat_dims)
    return result[tuple(index)]
# [docs]
@requires_optional("dask.array", opt_import_err)
def rime(rime_spec, *args, **kw):
    """Like :func:`~africanus.experimental.rime.fused.core.rime`, but for
    a dask paradigm.

    Parameters
    ----------
    rime_spec
        Specification handed to :class:`RimeFactory`.
    *args, **kw
        Inputs merged by ``consolidate_args`` and forwarded to
        ``RimeFactory.dask_blockwise_args``.

    Returns
    -------
    The blockwise result summed over the ``source`` dimension and over
    every input dimension that is not one of ``source``, ``row``,
    ``chan`` or ``corr``.
    """
    factory = RimeFactory(rime_spec)
    consolidated = consolidate_args(args, kw)
    names, blockwise_args = factory.dask_blockwise_args(**consolidated)

    base_dims = ("source", "row", "chan", "corr")

    # Every second entry of the blockwise arguments is an index tuple
    # (or None); any dimension that is not an output dimension must be
    # contracted away.
    index_specs = blockwise_args[1::2]
    contracted = {dim for spec in index_specs if spec is not None for dim in spec}
    contracted -= set(base_dims)
    out_dims = base_dims + tuple(contracted)
    n_contracted = len(contracted)

    # Source and concatenation dimension are reduced to 1 element
    adjust_chunks = dict.fromkeys(("source", *contracted), 1)
    new_axes = {"corr": len(factory.rime_spec.corrs)}

    # An explicit meta is required: without it dask calls
    # rime_dask_wrapper with dummy arguments to infer the output dtype.
    # That incurs memory allocations within numba, as well as
    # exceptions, leading to memory leaks as described
    # in https://github.com/numba/numba/issues/3263
    meta_shape = (0,) * len(out_dims)
    meta = np.empty(meta_shape, dtype=np.complex128)

    # Leading arguments consumed by rime_dask_wrapper; each is paired
    # with None so that dask passes it through as a plain value.
    wrapper_args = (factory, None, names, None, n_contracted, None)

    out = da.blockwise(
        rime_dask_wrapper,
        out_dims,
        *wrapper_args,
        *blockwise_args,
        concatenate=False,
        adjust_chunks=adjust_chunks,
        new_axes=new_axes,
        meta=meta,
    )

    # Contract over source (axis 0) and the trailing concatenation dims
    reduce_axes = (0, *range(len(base_dims), len(out_dims)))
    return out.sum(axis=reduce_axes)