# Source code for africanus.rime.dask_predict

# -*- coding: utf-8 -*-

from functools import reduce
from itertools import product
from operator import mul

try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

import numpy as np

from africanus.util.requirements import requires_optional

from africanus.rime.predict import (PREDICT_DOCS, predict_checks,
                                    predict_vis as np_predict_vis)
from africanus.rime.wsclean_predict import (
                                WSCLEAN_PREDICT_DOCS,
                                wsclean_predict_impl as wsclean_predict_body)
from africanus.model.wsclean.spec_model import spectra as wsclean_spectra


try:
    import dask.array as da
    from dask.base import tokenize
    import dask.blockwise as db
    from dask.utils import funcname
    from dask.highlevelgraph import HighLevelGraph
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None


def _ind_map(arg, ind, out_ind, dim_map, dim_blocks):
    # Yield name as first tuple element
    yield arg

    for j in ind:
        try:
            dim_idx = dim_map[j]
        except KeyError:
            # The blockid is not in the output key.
            # Assume (and check for a single blockid)
            try:
                db = dim_blocks[j]
            except KeyError:
                raise ValueError("%s not in block mapping" % j)
            else:
                if db != 1:
                    raise ValueError("Dimension %s must be a single block" % j)

                yield 0
        else:
            # Extract blockid for this index from the output key
            yield out_ind[dim_idx]


class LinearReduction(Mapping):
    """Task-graph layer that reduces ``func`` serially along one dimension.

    For every output block, the blocks along the reduction ``axis`` are
    chained: the task for axis block ``i`` receives the result of the task
    for axis block ``i - 1`` in the positional argument at ``feed_index``
    (``None`` for axis block ``0``). Intermediate results are keyed under a
    separate ``"<func>-intermediate-<token>"`` name, including the axis
    block id. Only the results of the last axis block are keyed under
    :attr:`name`, indexed by the ``output_indices`` block ids.

    Parameters
    ----------
    func : callable
        Function applied to every block.
    output_indices : iterable of str
        Dimension names of the output. Must not contain ``axis``.
    indices : iterable of (name, indices) pairs
        Positional arguments of ``func``. A pair whose ``indices`` is
        ``None`` is a literal, embedded in each task unchanged.
    numblocks : dict
        Maps each array name appearing in ``indices`` to its number of
        blocks per dimension.
    feed_index : int, optional
        Position of the argument that receives the previous result.
    axis : str
        Dimension name reduced over.

    Raises
    ------
    ValueError
        If ``axis`` is not set or appears in ``output_indices``.
    """

    def __init__(
        self,
        func,
        output_indices,
        indices,
        numblocks,
        feed_index=0,
        axis=None,
    ):
        self.func = func
        self.output_indices = tuple(output_indices)
        # Normalise index sequences to tuples; ``None`` marks a literal
        self.indices = tuple((name, tuple(ind) if ind is not None else ind)
                             for name, ind in indices)
        self.numblocks = numblocks

        if axis is None:
            raise ValueError("axis not set")

        if axis in self.output_indices:
            raise ValueError("axis in output_indices")

        self.feed_index = feed_index
        self.axis = axis

        token = tokenize(self.func,
                         self.output_indices,
                         self.indices,
                         self.numblocks,
                         self.feed_index,
                         self.axis)

        self.func_name = funcname(self.func)
        self.name = "-".join((self.func_name, token))

    @property
    def _dict(self):
        """The materialised task graph, built on first access and cached."""
        if hasattr(self, "_cached_dict"):
            return self._cached_dict
        else:
            # Reduction axis
            ax = self.axis
            feed_index = self.feed_index

            # Number of blocks for each dimension, derived from the input
            dim_blocks = db.broadcast_dimensions(self.indices, self.numblocks)
            last_block = dim_blocks[ax] - 1

            # The reduction axis always occupies the first key position
            out_dims = (ax,) + self.output_indices
            dim_map = {k: i for i, k in enumerate(out_dims)}

            dsk = {}
            int_name = "-".join((self.func_name,
                                 "intermediate",
                                 tokenize(self.name)))

            # Iterate over the output keys creating associated task
            for out_ind in product(*[range(dim_blocks[d]) for d in out_dims]):
                task = [self.func]

                for i, (arg, ind) in enumerate(self.indices):
                    if i == feed_index:
                        # First reduction block, feed in None
                        if out_ind[0] == 0:
                            task.append(None)

                        # Otherwise feed in the result of the last operation
                        else:
                            task.append((int_name,) +
                                        # Index last reduction block
                                        # always in first axis
                                        (out_ind[0] - 1,) +
                                        out_ind[1:])

                    elif ind is None:
                        # Literal arg, embed
                        task.append(arg)
                    else:
                        # Derive input key from output key indices
                        task.append(tuple(_ind_map(arg, ind, out_ind,
                                                   dim_map, dim_blocks)))

                # Final block
                if out_ind[0] == last_block:
                    dsk[(self.name,) + out_ind[1:]] = tuple(task)
                # Intermediate block
                else:
                    dsk[(int_name,) + out_ind] = tuple(task)

            self._cached_dict = dsk

        return self._cached_dict

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        # ``_dict`` holds one task per (reduction-axis block, output block)
        # combination: intermediate tasks for all but the last axis block
        # and final tasks for the last. The reduction axis must therefore be
        # counted too. Counting only the output blocks would disagree with
        # ``__iter__`` whenever the axis has several blocks.
        dim_blocks = self._dim_numblocks()
        dims = (self.axis,) + self.output_indices
        return reduce(mul, (dim_blocks.get(d, 1) for d in dims), 1)

    def _dim_numblocks(self):
        """Map each indexed dimension to its maximum number of blocks."""
        blocks = {}
        indices = {k: v for k, v in self.indices if v is not None}
        for name, nblocks in self.numblocks.items():
            for dim, n in zip(indices[name], nblocks):
                blocks[dim] = max(blocks.get(dim, 0), n)

        return blocks

    def _out_numblocks(self):
        """Number of blocks for each output dimension."""
        return {k: v for k, v in self._dim_numblocks().items()
                if k in self.output_indices}


def linear_reduction(time_index, antenna1, antenna2,
                     dde1_jones, source_coh, dde2_jones,
                     predict_check_tup, out_dtype):
    """Sum source coherencies with a serial chain of ``np_predict_vis`` calls.

    Builds a :class:`LinearReduction` layer over the ``source`` dimension.
    The visibilities accumulated up to source block ``i - 1`` are fed into
    the ``base_vis`` argument (position 7) of the call for source block
    ``i``.

    Returns
    -------
    :class:`dask.array.Array`
        Summed visibilities with ``("row", "chan", corr...)`` dimensions,
        row-chunked like ``time_index``.

    Raises
    ------
    ValueError
        If neither direction-dependent effects nor source coherencies
        are present.
    """
    (have_ddes1, have_coh, have_ddes2, _, _, _) = predict_check_tup

    if have_ddes1 and have_ddes2:
        corr_shape = dde1_jones.shape[4:]
    elif have_coh:
        corr_shape = source_coh.shape[3:]
    else:
        raise ValueError("need ddes or source coherencies")

    cdims = tuple("corr-%d" % i for i in range(len(corr_shape)))
    out_dims = ("row", "chan") + cdims
    jones_dims = ("source", "row", "ant", "chan") + cdims

    # Positional arguments of np_predict_vis paired with their dimensions.
    # The trailing die1_jones, base_vis and die2_jones slots are literal
    # Nones; the base_vis slot (index 7) is used as the feed_index below.
    args = [
        (time_index, ("row",)),
        (antenna1, ("row",)),
        (antenna2, ("row",)),
        (dde1_jones, jones_dims),
        (source_coh, ("source", "row", "chan") + cdims),
        (dde2_jones, jones_dims),
    ]
    args.extend([(None, None)] * 3)

    # Dask arrays are referenced by graph name; absent inputs become literals
    name_args = []
    for array, dims in args:
        if array is None:
            name_args.append((None, None))
        elif isinstance(array, da.Array):
            name_args.append((array.name, dims))
        else:
            name_args.append((array, dims))

    present = [array for array, _ in args if array is not None]
    numblocks = {array.name: array.numblocks for array in present}

    layer = LinearReduction(np_predict_vis, out_dims, name_args,
                            numblocks=numblocks,
                            feed_index=7,
                            axis='source')

    graph = HighLevelGraph.from_collections(layer.name, layer, present)

    # Later arguments win when several share a dimension name
    chunk_map = {}
    for array, dims in args:
        if array is None or dims is None:
            continue
        for axis_num, dim in enumerate(dims):
            chunk_map[dim] = array.chunks[axis_num]

    # Output rows follow the row chunking of time_index rather than the
    # (differently sized) time chunking of the jones terms
    chunk_map['row'] = time_index.chunks[0]

    chunks = tuple(chunk_map[d] for d in out_dims)
    return da.Array(graph, layer.name, chunks, dtype=out_dtype)


def _predict_coh_wrapper(time_index, antenna1, antenna2,
                         dde1_jones, source_coh, dde2_jones,
                         base_vis,
                         reduce_single_source=False):
    """Compute the visibilities of one source block for ``blockwise``.

    The ``ant`` dimension of the DDE terms is contracted by the caller, so
    ``dde1_jones``/``dde2_jones`` arrive as lists holding a single block
    (or ``None``). When ``reduce_single_source`` is set, the ``source``
    dimension is contracted as well and one more list level is unwrapped.

    Returns the visibilities with a leading unit ``source`` axis, unless
    ``reduce_single_source`` is set.
    """
    def _first(blocks):
        # Unwrap the single block of a contracted dimension
        return blocks[0] if blocks else None

    if reduce_single_source:
        # All these arrays contract over a single 'source' chunk
        dde1_jones, source_coh, dde2_jones = (_first(dde1_jones),
                                              _first(source_coh),
                                              _first(dde2_jones))

    vis = np_predict_vis(time_index, antenna1, antenna2,
                         _first(dde1_jones),
                         source_coh,
                         _first(dde2_jones),
                         None,
                         base_vis,
                         None)

    return vis if reduce_single_source else vis[None, ...]


def _predict_dies_wrapper(time_index, antenna1, antenna2,
                          die1_jones, base_vis, die2_jones):
    """Apply direction-independent effects to one block for ``blockwise``.

    The ``ant`` dimension of the DIE terms is contracted by the caller, so
    each jones argument arrives as a list holding a single block (or
    ``None`` when absent).
    """
    left_jones = die1_jones[0] if die1_jones else None
    right_jones = die2_jones[0] if die2_jones else None

    return np_predict_vis(time_index, antenna1, antenna2,
                          None, None, None,
                          left_jones, base_vis, right_jones)


def parallel_reduction(time_index, antenna1, antenna2,
                       dde1_jones, source_coh, dde2_jones,
                       predict_check_tup, out_dtype):
    """Sum source coherencies with a standard dask tree reduction.

    Each source block is predicted independently by
    ``_predict_coh_wrapper``. The per-block results are then summed over
    the leading source axis with :meth:`dask.array.Array.sum`.

    Raises
    ------
    ValueError
        If neither direction-dependent effects nor source coherencies
        are present.
    """
    (have_ddes1, have_coh, have_ddes2, _, _, _) = predict_check_tup

    if have_ddes1 and have_ddes2:
        corr_shape = dde1_jones.shape[4:]
    elif have_coh:
        corr_shape = source_coh.shape[3:]
    else:
        raise ValueError("need ddes or source coherencies")

    cdims = tuple("corr-%d" % i for i in range(len(corr_shape)))
    jones_dims = ("src", "row", "ant", "chan") + cdims
    coh_dims = ("src", "row", "chan") + cdims

    def _dims_or_none(array, dims):
        # Absent optional inputs are passed to blockwise as literal None
        return None if array is None else dims

    per_source = da.blockwise(
        _predict_coh_wrapper, coh_dims,
        time_index, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        dde1_jones, _dims_or_none(dde1_jones, jones_dims),
        source_coh, _dims_or_none(source_coh, coh_dims),
        dde2_jones, _dims_or_none(dde2_jones, jones_dims),
        # No base visibilities are added at this stage
        None, None,
        # The jones "row" axis is really time-chunked: same number of
        # chunks as time_index, but different sizes, so don't unify them
        align_arrays=False,
        # The output rows take time_index's row chunking
        adjust_chunks={'row': time_index.chunks[0]},
        meta=np.empty((0,) * len(coh_dims), dtype=out_dtype),
        dtype=out_dtype)

    return per_source.sum(axis=0)


def apply_dies(time_index, antenna1, antenna2,
               die1_jones, base_vis, die2_jones,
               predict_check_tup, out_dtype):
    """Apply any Direction-Independent Effects to the base visibilities.

    Returns a :class:`dask.array.Array` with ``("row", "chan", corr...)``
    dimensions, row-chunked like ``time_index``.

    Raises
    ------
    ValueError
        If neither both DIE terms nor base visibilities are present.
    """
    (_, _, _, have_dies1, have_bvis, have_dies2) = predict_check_tup

    # Derive the correlation dimension names from whichever input is
    # present; this also rejects calls with no usable inputs.
    if have_dies1 and have_dies2:
        corr_shape = die1_jones.shape[3:]
    elif have_bvis:
        corr_shape = base_vis.shape[2:]
    else:
        raise ValueError("Missing both antenna and baseline jones terms")

    cdims = tuple("corr-%d" % i for i in range(len(corr_shape)))

    # "row" and "time" are intimately related: a contiguous run of rows
    # maps onto a contiguous run of timesteps. Their number of chunks must
    # therefore match even though the chunk sizes may differ. The DIE terms'
    # time axis is labelled "row", and align_arrays=False stops blockwise
    # from unifying the differently sized chunks.
    jones_dims = ("row", "ant", "chan") + cdims
    vis_dims = ("row", "chan") + cdims

    jones1_ind = None if die1_jones is None else jones_dims
    bvis_ind = None if base_vis is None else vis_dims
    jones2_ind = None if die2_jones is None else jones_dims

    return da.blockwise(
        _predict_dies_wrapper, vis_dims,
        time_index, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        die1_jones, jones1_ind,
        base_vis, bvis_ind,
        die2_jones, jones2_ind,
        align_arrays=False,
        # The output rows take time_index's row chunking
        adjust_chunks={'row': time_index.chunks[0]},
        meta=np.empty((0,) * len(vis_dims), dtype=out_dtype),
        dtype=out_dtype)


@requires_optional('dask.array', opt_import_error)
def predict_vis(time_index, antenna1, antenna2,
                dde1_jones=None, source_coh=None, dde2_jones=None,
                die1_jones=None, base_vis=None, die2_jones=None,
                streams=None):
    """Predict visibilities from dask arrays (full documentation is
    attached to ``predict_vis.__doc__`` at module import)."""

    predict_check_tup = predict_checks(time_index, antenna1, antenna2,
                                       dde1_jones, source_coh, dde2_jones,
                                       die1_jones, base_vis, die2_jones)

    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    have_ddes = have_ddes1 and have_ddes2

    if have_ddes:
        # Any row may reference any antenna, so 'ant' must be one chunk
        if dde1_jones.shape[2] != dde1_jones.chunks[2][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if dde2_jones.shape[2] != dde2_jones.chunks[2][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if dde1_jones.chunks != dde2_jones.chunks:
            raise ValueError("dde1_jones.chunks != dde2_jones.chunks")

        # DDE time axis (1) must have as many chunks as the row axis
        if len(dde1_jones.chunks[1]) != len(time_index.chunks[0]):
            raise ValueError("Number of row chunks (%s) does not equal "
                             "number of time chunks (%s)." %
                             (time_index.chunks[0], dde1_jones.chunks[1]))

    have_dies = have_dies1 and have_dies2

    if have_dies:
        if die1_jones.shape[1] != die1_jones.chunks[1][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if die2_jones.shape[1] != die2_jones.chunks[1][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if die1_jones.chunks != die2_jones.chunks:
            raise ValueError("die1_jones.chunks != die2_jones.chunks")

        # DIE time axis is axis 0; report its chunks (not the antenna
        # chunks on axis 1) alongside the row chunks
        if len(die1_jones.chunks[0]) != len(time_index.chunks[0]):
            raise ValueError("Number of row chunks (%s) does not equal "
                             "number of time chunks (%s)." %
                             (time_index.chunks[0], die1_jones.chunks[0]))

    # Infer the output dtype
    dtype_arrays = [dde1_jones, source_coh, dde2_jones, die1_jones, die2_jones]
    out_dtype = np.result_type(*(np.dtype(a.dtype.name)
                                 for a in dtype_arrays
                                 if a is not None))

    # Apply direction dependent effects
    if have_coh or have_ddes:
        # We create separate graphs for computing coherencies and applying
        # the gains because coherencies are chunked over source which
        # must be summed and added to the (possibly present) base visibilities
        if streams is True:
            sum_coherencies = linear_reduction(time_index,
                                               antenna1,
                                               antenna2,
                                               dde1_jones,
                                               source_coh,
                                               dde2_jones,
                                               predict_check_tup,
                                               out_dtype)
        else:
            sum_coherencies = parallel_reduction(time_index,
                                                 antenna1,
                                                 antenna2,
                                                 dde1_jones,
                                                 source_coh,
                                                 dde2_jones,
                                                 predict_check_tup,
                                                 out_dtype)
    else:
        assert have_dies or have_bvis
        sum_coherencies = None

    # No more effects to apply, return at this point
    if not have_dies and not have_bvis:
        return sum_coherencies

    # Add coherencies to the base visibilities
    if sum_coherencies is not None:
        if not have_bvis:
            # Set base_vis = summed coherencies
            base_vis = sum_coherencies
            predict_check_tup = (have_ddes1, have_coh, have_ddes2,
                                 have_dies1, True, have_dies2)
        else:
            # Rebind rather than ``+=`` so the caller's array is never
            # modified in place
            base_vis = base_vis + sum_coherencies

    # Apply direction independent effects
    return apply_dies(time_index, antenna1, antenna2,
                      die1_jones, base_vis, die2_jones,
                      predict_check_tup, out_dtype)
def wsclean_spectrum_wrapper(flux, coeffs, log_poly, ref_freq, frequency):
    """Evaluate WSClean spectra for one dask block.

    ``coeffs`` arrives as a list of blocks over its contracted ``comp``
    dimension; only the first block is used.
    """
    coeff_block = coeffs[0]
    return wsclean_spectra(flux, coeff_block, log_poly, ref_freq, frequency)


def wsclean_body_wrapper(uvw, lm, source_type, gauss_shape,
                         frequency, spectrum, dtype_):
    """Predict WSClean visibilities for one dask block.

    ``uvw``, ``lm`` and ``gauss_shape`` arrive as lists of blocks over
    their contracted trailing dimension; only the first block of each is
    used. A unit-length leading source axis is prepended to the result.
    """
    uvw_block = uvw[0]
    lm_block = lm[0]
    gauss_block = gauss_shape[0]

    vis = wsclean_predict_body(uvw_block, lm_block, source_type,
                               gauss_block, frequency, spectrum, dtype_)
    return vis[None, :]
@requires_optional('dask.array', opt_import_error)
def wsclean_predict(uvw, lm, source_type, flux, coeffs, log_poly,
                    ref_freq, gauss_shape, frequency):
    """Predict WSClean model visibilities from dask arrays (full
    documentation is attached to ``wsclean_predict.__doc__`` at module
    import)."""

    def _single_chunk(array, axis):
        # The block wrappers contract this dimension and only read its
        # first block, so it must be a single chunk. Otherwise the remaining
        # chunks would be silently ignored.
        if isinstance(array, da.Array) and array.numblocks[axis] != 1:
            return array.rechunk({axis: -1})

        return array

    coeffs = _single_chunk(coeffs, 1)       # ("source", "comp")
    uvw = _single_chunk(uvw, 1)             # ("row", "uvw")
    lm = _single_chunk(lm, 1)               # ("source", "lm")
    gauss_shape = _single_chunk(gauss_shape, 1)  # ("source", "gauss")

    spectrum_dtype = np.result_type(*(a.dtype for a in (flux,
                                                        coeffs,
                                                        log_poly,
                                                        ref_freq,
                                                        frequency)))

    spectrum = da.blockwise(wsclean_spectrum_wrapper, ("source", "chan"),
                            flux, ("source",),
                            coeffs, ("source", "comp"),
                            log_poly, ("source",),
                            ref_freq, ("source",),
                            frequency, ("chan",),
                            dtype=spectrum_dtype)

    out_dtype = np.result_type(uvw.dtype, lm.dtype, frequency.dtype,
                               spectrum.dtype, np.complex64)

    vis = da.blockwise(wsclean_body_wrapper, ("source", "row", "chan", "corr"),
                       uvw, ("row", "uvw"),
                       lm, ("source", "lm"),
                       source_type, ("source",),
                       gauss_shape, ("source", "gauss"),
                       frequency, ("chan",),
                       spectrum, ("source", "chan"),
                       out_dtype, None,
                       # Each block yields a unit-length source axis
                       adjust_chunks={"source": 1},
                       new_axes={"corr": 1},
                       dtype=out_dtype)

    # Sum the per-source-block contributions
    return vis.sum(axis=0)
EXTRA_DASK_ARGS = """
streams : {False, True}
    If ``True`` the coherencies are serially summed in a linear chain.
    If ``False`` (or ``None``, the default), dask uses a tree style
    reduction algorithm.
"""

EXTRA_DASK_NOTES = """
* The ``ant`` dimension should only contain a single chunk equal
  to the number of antenna. Since each ``row`` can contain
  any antenna, random access must be preserved along this dimension.
* The chunks in the ``row`` and ``time`` dimension **must** align.
  This subtle point **must be understood otherwise
  invalid results will be produced** by the chunking scheme.
  In the example below
  we have four unique time indices :code:`[0,1,2,3]`, and
  four unique antenna :code:`[0,1,2,3]` indexing :code:`10` rows.

  .. code-block:: python

      #  Row indices into the time/antenna indexed arrays
      time_idx = np.asarray([0,0,1,1,2,2,2,2,3,3])
      ant1 = np.asarray([0,0,0,0,1,1,1,2,2,3])
      ant2 = np.asarray([0,1,2,3,1,2,3,2,3,3])


  A reasonable chunking scheme for the
  ``row`` and ``time`` dimension would be :code:`(4,4,2)`
  and :code:`(2,1,1)` respectively.
  Another way of explaining this is that the first
  four rows contain two unique timesteps, the second four
  rows contain one unique timestep and the last two rows
  contain one unique timestep.

  Some rules of thumb:

  1. The number of chunks in ``row`` and ``time`` must match
     although the individual chunk sizes need not.
  2. Unique timesteps should not be split across row chunks.
  3. For a Measurement Set whose rows are ordered on the
     ``TIME`` column, the following is a good way of obtaining
     the row chunking strategy:

     .. code-block:: python

        import numpy as np
        import pyrap.tables as pt

        ms = pt.table("data.ms")
        times = ms.getcol("TIME")
        unique_times, chunks = np.unique(times, return_counts=True)

  4. Use :func:`~africanus.util.shapes.aggregate_chunks`
     to aggregate multiple ``row`` and ``time``
     chunks into chunks large enough such that functions operating
     on the resulting data can drop the GIL and spend time
     processing the data. Expanding the previous example:

     .. code-block:: python

        # Aggregate row
        utimes = unique_times.size
        # Single chunk for each unique time
        time_chunks = (1,)*utimes
        # Aggregate row chunks into chunks <= 10000
        aggregate_chunks((chunks, time_chunks), (10000, utimes))
"""

# NOTE(review): AttributeError is presumably raised when the docstring
# templates are unavailable (e.g. under ``python -OO``) — TODO confirm.
# Both assignments are guarded the same way.
try:
    predict_vis.__doc__ = PREDICT_DOCS.substitute(
                                array_type=":class:`dask.array.Array`",
                                get_time_index=":code:`time.map_blocks("
                                               "lambda a: np.unique(a, "
                                               "return_inverse=True)[1])`",
                                extra_args=EXTRA_DASK_ARGS,
                                extra_notes=EXTRA_DASK_NOTES)
except AttributeError:
    pass

try:
    wsclean_predict.__doc__ = WSCLEAN_PREDICT_DOCS.substitute(
                                    array_type=":class:`dask.array.Array`")
except AttributeError:
    pass