Source code for africanus.rime.dask_predict

# -*- coding: utf-8 -*-


try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

from functools import reduce
from itertools import product
from operator import mul

import numpy as np

from africanus.util.requirements import requires_optional

from africanus.rime.predict import (PREDICT_DOCS, predict_checks,
                                    predict_vis as np_predict_vis)

try:
    from dask.blockwise import blockwise
    from dask.base import tokenize
    import dask.array as da
    from dask.highlevelgraph import HighLevelGraph
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None


def _source_stream_blocks(source_blocks, streams):
    """
    Number of source blocks handled by each stream.

    This is ``source_blocks / streams`` rounded up, so that
    ``streams`` groups of this size cover every source block.
    """
    # (a - 1) // b + 1 is the same integer as (a + b - 1) // b
    return (source_blocks - 1) // streams + 1


def _extract_blocks(time_index, dde1_jones, source_coh, dde2_jones):
    """
    Derive the number of blocks in each dimension from the inputs.

    ``dde1_jones`` takes precedence over ``source_coh`` when both are
    supplied. ``dde2_jones`` is accepted but not inspected.

    Returns
    -------
    blocks : tuple
        :code:`(source, row, ant, chan, corr1, ..., corrn)`.
        The ``ant`` entry is always 1.

    Raises
    ------
    ValueError
        If both ``dde1_jones`` and ``source_coh`` are ``None``.
    """
    if dde1_jones is None and source_coh is None:
        raise ValueError("need ddes or coherencies")

    if dde1_jones is not None:
        dde_blocks = dde1_jones.numblocks
        # The row block count comes from time_index, not from the dde
        row_blocks = time_index.numblocks[0]
        return (dde_blocks[0], row_blocks, 1, dde_blocks[3]) + dde_blocks[4:]

    coh_blocks = source_coh.numblocks
    # Splice a single 'ant' block in between 'row' and 'chan'
    return coh_blocks[:2] + (1, coh_blocks[2]) + coh_blocks[3:]


def _extract_chunks(time_index, dde1_jones, source_coh, dde2_jones):
    """
    Derive the chunks of each dimension from the inputs.

    ``dde1_jones`` takes precedence over ``source_coh`` when both are
    supplied. ``dde2_jones`` is accepted but not inspected.

    Returns
    -------
    chunks : tuple
        :code:`(source, row, chan, corr1, ..., corrn)`

    Raises
    ------
    ValueError
        If both ``dde1_jones`` and ``source_coh`` are ``None``.
    """
    if dde1_jones is None:
        if source_coh is None:
            raise ValueError("need ddes or coherencies")

        # Coherencies already have the desired dimension layout
        return source_coh.chunks

    dde_chunks = dde1_jones.chunks
    # Row chunks come from time_index; the dde 'time' (1) and
    # 'ant' (2) entries are skipped.
    return ((dde_chunks[0], time_index.chunks[0], dde_chunks[3]) +
            dde_chunks[4:])


class CoherencyStreamReduction(Mapping):
    """
    Lazily expanded graph dictionary.

    Only the input names, block counts and stream count are stored at
    construction; the task dictionary is built on first access and
    cached on the instance. The intent is to keep the object small
    when it is pickled for sending to the dask scheduler.

    See :class:`dask.blockwise.Blockwise` for further insight.

    The generated graph serially sums coherencies in
    ``streams`` parallel streams.
    """

    def __init__(self, time_index, antenna1, antenna2,
                 dde1_jones, source_coh, dde2_jones,
                 out_name, streams):
        def _name_of(array):
            # Only the graph key name of each input is retained
            return None if array is None else array.name

        self.time_index_name = _name_of(time_index)
        self.ant1_name = _name_of(antenna1)
        self.ant2_name = _name_of(antenna2)
        self.dde1_name = _name_of(dde1_jones)
        self.coh_name = _name_of(source_coh)
        self.dde2_name = _name_of(dde2_jones)

        self.out_name = out_name

        self.blocks = _extract_blocks(time_index, dde1_jones,
                                      source_coh, dde2_jones)
        self.streams = streams

    @property
    def _dict(self):
        # Expand the graph on first access and memoise it
        try:
            return self._cached_dict
        except AttributeError:
            self._cached_dict = self._create_dict()
            return self._cached_dict

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        # One task per (source, row, chan, corr...) block;
        # the 'ant' entry is ignored.
        (n_src, n_row, _, n_chan), n_corrs = self.blocks[:4], self.blocks[4:]
        return reduce(mul, n_corrs, n_src * n_row * n_chan)

    def _create_dict(self):
        """ Build the task dictionary for the streamed reduction """
        (n_src, n_row, n_ant,
         n_chan), n_corrs = self.blocks[:4], self.blocks[4:]

        assert n_ant == 1
        ant_id = 0

        # Number of consecutive source blocks summed by one stream
        per_stream = _source_stream_blocks(n_src, self.streams)

        # Bind attributes to locals for use in the loops below
        out_name = self.out_name
        ti_name = self.time_index_name
        a1_name = self.ant1_name
        a2_name = self.ant2_name
        dde1_name = self.dde1_name
        coh_name = self.coh_name
        dde2_name = self.dde2_name

        # Block ranges of the dimensions that are *not* reduced:
        # row, channel and each correlation dimension
        kept_ranges = [range(n_row), range(n_chan)]
        kept_ranges.extend(range(n) for n in n_corrs)

        graph = {}

        for flat_id, block_id in enumerate(product(*kept_ranges)):
            row_id = block_id[0]
            chan_id = block_id[1]
            corr_ids = block_id[2:]

            # Each stream covers a contiguous run of source blocks.
            # Within a stream, the output of one task is handed to
            # the next as base visibilities (previous_key).
            for first in range(0, n_src, per_stream):
                stop = min(first + per_stream, n_src)
                previous_key = None

                for src_id in range(first, stop):
                    if dde1_name:
                        dde1_key = (dde1_name, src_id, row_id,
                                    ant_id, chan_id) + corr_ids
                    else:
                        dde1_key = None

                    if coh_name:
                        coh_key = (coh_name, src_id,
                                   row_id, chan_id) + corr_ids
                    else:
                        coh_key = None

                    if dde2_name:
                        dde2_key = (dde2_name, src_id, row_id,
                                    ant_id, chan_id) + corr_ids
                    else:
                        dde2_key = None

                    current_key = (out_name, src_id, flat_id)

                    # Dask task calling predict_vis; no die terms here
                    graph[current_key] = (np_predict_vis,
                                          (ti_name, row_id),
                                          (a1_name, row_id),
                                          (a2_name, row_id),
                                          dde1_key, coh_key, dde2_key,
                                          None, previous_key, None)

                    previous_key = current_key

        return graph


class CoherencyFinalReduction(Mapping):
    """
    tl;dr this is a dictionary that is expanded in place when
    first accessed. Saves memory when pickled for sending
    to the dask scheduler.

    See :class:`dask.blockwise.Blockwise` for further insight.

    Produces graph reducing results of ``stream`` parallel streams in
    CoherencyStreamReduction.

    Parameters
    ----------
    out_name : str
        Name used as the first element of every key produced
        by this layer.
    coherency_stream_reduction : CoherencyStreamReduction
        The layer whose per-stream results are summed. Its
        ``out_name``, ``blocks`` and ``streams`` attributes are read.
    """

    def __init__(self, out_name, coherency_stream_reduction):
        self.in_name = coherency_stream_reduction.out_name
        self.blocks = coherency_stream_reduction.blocks
        self.streams = coherency_stream_reduction.streams
        self.out_name = out_name

    @property
    def _dict(self):
        # Expand the graph on first access and cache it on the instance
        if hasattr(self, "_cached_dict"):
            return self._cached_dict
        else:
            self._cached_dict = self._create_dict()
            return self._cached_dict

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        # _create_dict emits exactly one task per (row, chan, corr...)
        # block. The source dimension is summed away and therefore must
        # not contribute, otherwise len() disagrees with iteration.
        (_, row, _, chan), corrs = self.blocks[:4], self.blocks[4:]
        return reduce(mul, (row, chan) + corrs, 1)

    def _create_dict(self):
        """ Build the task dictionary summing the last key of each stream """
        (source, row, _, chan), corrs = self.blocks[:4], self.blocks[4:]

        # Iterator of block id's for row, channel and correlation blocks
        # We don't reduce over these dimensions
        block_ids = enumerate(product(range(row), range(chan),
                                      *[range(cb) for cb in corrs]))

        source_block_chunks = _source_stream_blocks(source, self.streams)

        layers = {}

        # This looping structure should match that used in
        # CoherencyStreamReduction._create_dict so that the keys
        # referenced here are the final key of each stream there.
        for flat_bid, bid in block_ids:
            rb, fb = bid[0:2]
            cb = bid[2:]

            # Key of the last task in each stream: the stream starting
            # at sb_start ends at min(sb_start + chunk, source) - 1
            last_stream_keys = [
                (self.in_name,
                 min(sb_start + source_block_chunks, source) - 1,
                 flat_bid)
                for sb_start in range(0, source, source_block_chunks)]

            layers[(self.out_name, rb, fb) + cb] = (sum, last_stream_keys)

        return layers


def _predict_coh_wrapper(time_index, antenna1, antenna2,
                         dde1_jones, source_coh, dde2_jones,
                         die1_jones, base_vis, die2_jones):
    """
    Per-block function used by ``fan_reduction``.

    The dde and die arguments are indexed with ``[0]`` before being
    handed to ``predict_vis``, while ``source_coh`` and ``base_vis``
    are forwarded untouched. Falsy arguments become ``None``.

    NOTE(review): the single-element indexing presumably unwraps the
    list that blockwise supplies for the 'ant' dimension, which does
    not appear in the output index -- confirm against dask.blockwise.

    Returns
    -------
    The ``predict_vis`` result with a new leading axis of length 1,
    standing in for the source dimension.
    """
    dde1 = dde1_jones[0] if dde1_jones else None
    dde2 = dde2_jones[0] if dde2_jones else None
    die1 = die1_jones[0] if die1_jones else None
    die2 = die2_jones[0] if die2_jones else None

    vis = np_predict_vis(time_index, antenna1, antenna2,
                         dde1, source_coh, dde2,
                         die1, base_vis, die2)

    # Introduce an extra dimension (source dim reduced to 1)
    return vis[None, ...]


def _predict_dies_wrapper(time_index, antenna1, antenna2,
                          dde1_jones, source_coh, dde2_jones,
                          die1_jones, base_vis, die2_jones):
    """
    Per-block function used by ``apply_dies``.

    Each optional argument is unwrapped before being handed to
    ``predict_vis``: dde terms with ``[0][0]``, ``source_coh`` and
    die terms with ``[0]``. Falsy arguments become ``None`` and
    ``base_vis`` is forwarded untouched.

    NOTE(review): the unwrapping presumably strips the (nested) lists
    blockwise supplies for the 'source' and 'ant' dimensions, which
    do not appear in the output index -- confirm against
    dask.blockwise.
    """
    dde1 = dde1_jones[0][0] if dde1_jones else None
    coh = source_coh[0] if source_coh else None
    dde2 = dde2_jones[0][0] if dde2_jones else None
    die1 = die1_jones[0] if die1_jones else None
    die2 = die2_jones[0] if die2_jones else None

    return np_predict_vis(time_index, antenna1, antenna2,
                          dde1, coh, dde2,
                          die1, base_vis, die2)


def stream_reduction(time_index, antenna1, antenna2,
                     dde1_jones, source_coh, dde2_jones,
                     predict_check_tup, out_dtype, streams):
    """
    Sum source coherencies + ddes over the source dimension using
    ``streams`` parallel streams.

    Two graph layers are chained: a :class:`CoherencyStreamReduction`,
    in which each ``predict_vis`` task receives the previous task of
    its stream as ``base_vis``, followed by a
    :class:`CoherencyFinalReduction` summing the last result of
    every stream.

    ``predict_check_tup`` is accepted for signature parity with
    ``fan_reduction`` but is not used here.

    Returns
    -------
    :class:`dask.array.Array`
        Array of dtype ``out_dtype`` chunked as
        :code:`(row, chan, corr1, ..., corrn)`.
    """
    # Unique name for the streamed layer
    stream_token = tokenize(time_index, antenna1, antenna2,
                            dde1_jones, source_coh, dde2_jones,
                            streams)
    stream_name = 'stream-coherency-reduction-' + stream_token

    # Block counts: source, plus all the non-reduced dimensions
    blocks = _extract_blocks(time_index, dde1_jones, source_coh, dde2_jones)
    (src_blocks, row_blocks, _,
     chan_blocks), corr_blocks = blocks[:4], blocks[4:]
    nblocks = reduce(mul, corr_blocks, row_blocks * chan_blocks)

    # Compressed mapping for the streamed layer
    stream_layer = CoherencyStreamReduction(time_index, antenna1, antenna2,
                                            dde1_jones, source_coh,
                                            dde2_jones,
                                            stream_name, streams)

    deps = [time_index, antenna1, antenna2]
    deps.extend(a for a in (dde1_jones, source_coh, dde2_jones)
                if a is not None)

    stream_graph = HighLevelGraph.from_collections(stream_name,
                                                   stream_layer, deps)

    # This array should never be computed directly: its reported
    # chunks and dtype don't describe the real data. It only exists
    # to make chaining the two HighLevelGraphs easier.
    placeholder_chunks = ((1,) * src_blocks, (1,) * nblocks)
    streamed = da.Array(stream_graph, stream_name,
                        placeholder_chunks, dtype=np.int8)

    final_name = "coherency-reduction-" + tokenize(streamed)
    final_layer = CoherencyFinalReduction(final_name, stream_layer)
    final_graph = HighLevelGraph.from_collections(final_name, final_layer,
                                                  [streamed])

    # Drop the leading source chunks; that dimension has been summed
    out_chunks = _extract_chunks(time_index, dde1_jones,
                                 source_coh, dde2_jones)[1:]
    return da.Array(final_graph, final_name, out_chunks, dtype=out_dtype)


def fan_reduction(time_index, antenna1, antenna2,
                  dde1_jones, source_coh, dde2_jones,
                  predict_check_tup, out_dtype):
    """
    Does a standard dask tree reduction over source coherencies.

    A blockwise layer calling ``_predict_coh_wrapper`` yields one
    length-1 source slab per source chunk; these are then summed
    over axis 0.

    Parameters
    ----------
    predict_check_tup : tuple
        Six flags unpacked as :code:`(have_ddes1, have_coh, have_ddes2,
        have_dies1, have_bvis, have_dies2)`. Only the dde and
        coherency flags are consulted.

    Raises
    ------
    ValueError
        If neither both ddes nor source coherencies are flagged present.
    """
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    have_ddes = have_ddes1 and have_ddes2

    # Count the correlation dimensions
    if have_ddes:
        ncorr = len(dde1_jones.shape[4:])
    elif have_coh:
        ncorr = len(source_coh.shape[3:])
    else:
        raise ValueError("need ddes or source coherencies")

    cdims = tuple(["corr-%d" % i for i in range(ncorr)])

    ajones_dims = ("src", "row", "ant", "chan") + cdims
    coh_dims = ("src", "row", "chan") + cdims

    # Always-present inputs: blockwise (name, index) pairs,
    # numblocks entries and HighLevelGraph dependencies
    required = (time_index, antenna1, antenna2)
    bw_args = []
    numblocks = {}
    deps = []

    for array in required:
        bw_args += [array.name, ("row",)]
        numblocks[array.name] = array.numblocks
        deps.append(array)

    # Optional inputs in positional order: dde1_jones, source_coh,
    # dde2_jones. Each entry is (present, array, index, position of the
    # 'chan' dimension). Absent inputs become (None, None) pairs.
    # other_chunks / src_chunks end up taken from the last present one.
    optional = ((have_ddes, dde1_jones, ajones_dims, 3),
                (have_coh, source_coh, coh_dims, 2),
                (have_ddes, dde2_jones, ajones_dims, 3))

    for present, array, dims, chan_axis in optional:
        if not present:
            bw_args += [None, None]
            continue

        bw_args += [array.name, dims]
        numblocks[array.name] = array.numblocks
        deps.append(array)
        other_chunks = array.chunks[chan_axis:]
        src_chunks = array.chunks[0]

    # die1_jones, base_vis and die2_jones absent for this part of the graph
    bw_args += [None] * 6

    assert len(bw_args) // 2 == 9, len(bw_args) // 2

    token = da.core.tokenize(time_index, antenna1, antenna2,
                             dde1_jones, source_coh, dde2_jones)
    name = "-".join(("predict-vis-sum-coh", token))
    layer = blockwise(_predict_coh_wrapper,
                      name, coh_dims,
                      *bw_args, numblocks=numblocks)

    graph = HighLevelGraph.from_collections(name, layer, deps)

    # One length-1 source chunk per input source chunk, rows chunked
    # like time_index, remaining dimensions like the optional inputs
    chunks = ((1,) * len(src_chunks), time_index.chunks[0]) + other_chunks

    sum_coherencies = da.Array(graph, name, chunks, dtype=out_dtype)

    # Reduce source axis
    return sum_coherencies.sum(axis=0)


def apply_dies(time_index, antenna1, antenna2,
               die1_jones, base_vis, die2_jones,
               predict_check_tup, out_dtype):
    """
    Apply any Direction-Independent Effects and Base Visibilities.

    Builds a blockwise layer calling ``_predict_dies_wrapper`` with
    the dde and coherency arguments absent.

    Parameters
    ----------
    predict_check_tup : tuple
        Six flags unpacked as :code:`(have_ddes1, have_coh, have_ddes2,
        have_dies1, have_bvis, have_dies2)`. Only the die and
        base visibility flags are consulted.

    Raises
    ------
    ValueError
        If neither both dies nor base visibilities are flagged present.
    """
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    have_dies = have_dies1 and have_dies2

    # Count the correlation dimensions.
    # This also has the effect of checking that we have all valid inputs
    if have_dies:
        ncorr = len(die1_jones.shape[3:])
    elif have_bvis:
        ncorr = len(base_vis.shape[2:])
    else:
        raise ValueError("Missing both antenna and baseline jones terms")

    cdims = tuple(["corr-%d" % i for i in range(ncorr)])

    # In the case of predict_vis, the "row" and "time" dimensions
    # are intimately related -- a contiguous series of rows
    # are related to a contiguous series of timesteps.
    # This means that the number of chunks of these
    # two dimensions must match even though the chunk sizes may not.
    # blockwise insists on matching chunk sizes.
    # For this reason, we use the lower level blockwise and
    # substitute "row" for "time" in arrays such as dde1_jones
    # and die1_jones.
    gjones_dims = ("row", "ant", "chan") + cdims
    vis_dims = ("row", "chan") + cdims

    # Always-present inputs: blockwise (name, index) pairs,
    # numblocks entries and HighLevelGraph dependencies
    required = (time_index, antenna1, antenna2)
    bw_args = []
    numblocks = {}
    deps = []

    for array in required:
        bw_args += [array.name, ("row",)]
        numblocks[array.name] = array.numblocks
        deps.append(array)

    # dde1_jones, source_coh and dde2_jones not present
    # these are already applied into sum_coherencies
    bw_args += [None] * 6

    # Optional inputs in positional order: die1_jones, base_vis,
    # die2_jones. Each entry is (present, array, index, position of the
    # 'chan' dimension). Absent inputs become (None, None) pairs.
    # other_chunks ends up taken from the last present one.
    optional = ((have_dies, die1_jones, gjones_dims, 2),
                (have_bvis, base_vis, vis_dims, 1),
                (have_dies, die2_jones, gjones_dims, 2))

    for present, array, dims, chan_axis in optional:
        if not present:
            bw_args += [None, None]
            continue

        bw_args += [array.name, dims]
        numblocks[array.name] = array.numblocks
        deps.append(array)
        other_chunks = array.chunks[chan_axis:]

    assert len(bw_args) // 2 == 9

    token = da.core.tokenize(time_index, antenna1, antenna2,
                             die1_jones, base_vis, die2_jones)
    name = '-'.join(("predict-vis-apply-dies", token))
    layer = blockwise(_predict_dies_wrapper,
                      name, vis_dims,
                      *bw_args, numblocks=numblocks)

    graph = HighLevelGraph.from_collections(name, layer, deps)

    # Rows are chunked like time_index
    out_chunks = (time_index.chunks[0],) + other_chunks

    return da.Array(graph, name, out_chunks, dtype=out_dtype)


@requires_optional('dask.array', opt_import_error)
def predict_vis(time_index, antenna1, antenna2,
                dde1_jones=None, source_coh=None, dde2_jones=None,
                die1_jones=None, base_vis=None, die2_jones=None,
                streams=None):
    # NOTE: the user-facing docstring of this function is assigned
    # from the PREDICT_DOCS template at the bottom of this module.
    #
    # Outline:
    # 1. Validate the chunking of the dde and die terms.
    # 2. Sum coherencies (+ ddes) over the source dimension, either in
    #    ``streams`` serial streams or with a dask tree reduction.
    # 3. Fold the summed coherencies into base_vis and apply the dies.

    predict_check_tup = predict_checks(time_index, antenna1, antenna2,
                                       dde1_jones, source_coh, dde2_jones,
                                       die1_jones, base_vis, die2_jones)

    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_check_tup

    have_ddes = have_ddes1 and have_ddes2

    if have_ddes:
        # The antenna dimension (axis 2) must be a single chunk
        if dde1_jones.shape[2] != dde1_jones.chunks[2][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if dde2_jones.shape[2] != dde2_jones.chunks[2][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if dde1_jones.chunks != dde2_jones.chunks:
            raise ValueError("dde1_jones.chunks != dde2_jones.chunks")

        # Time is axis 1 of the dde terms
        if len(dde1_jones.chunks[1]) != len(time_index.chunks[0]):
            raise ValueError("Number of row chunks (%s) does not equal "
                             "number of time chunks (%s)." %
                             (time_index.chunks[0], dde1_jones.chunks[1]))

    have_dies = have_dies1 and have_dies2

    if have_dies:
        # The antenna dimension (axis 1) must be a single chunk
        if die1_jones.shape[1] != die1_jones.chunks[1][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if die2_jones.shape[1] != die2_jones.chunks[1][0]:
            raise ValueError("Subdivision of antenna dimension into "
                             "multiple chunks is not supported.")

        if die1_jones.chunks != die2_jones.chunks:
            raise ValueError("die1_jones.chunks != die2_jones.chunks")

        # Time is axis 0 of the die terms, so report chunks[0]
        # (chunks[1] would be the antenna chunks).
        if len(die1_jones.chunks[0]) != len(time_index.chunks[0]):
            raise ValueError("Number of row chunks (%s) does not equal "
                             "number of time chunks (%s)." %
                             (time_index.chunks[0], die1_jones.chunks[0]))

    # Infer the output dtype
    dtype_arrays = [dde1_jones, source_coh, dde2_jones, die1_jones, die2_jones]
    out_dtype = np.result_type(*(np.dtype(a.dtype.name)
                                 for a in dtype_arrays if a is not None))

    # Apply direction dependent effects
    if have_coh or have_ddes:
        # We create separate graphs for computing coherencies and applying
        # the gains because coherencies are chunked over source which
        # must be summed and added to the (possibly present) base visibilities
        if streams is not None:
            sum_coherencies = stream_reduction(time_index,
                                               antenna1,
                                               antenna2,
                                               dde1_jones,
                                               source_coh,
                                               dde2_jones,
                                               predict_check_tup,
                                               out_dtype,
                                               streams=streams)
        else:
            sum_coherencies = fan_reduction(time_index,
                                            antenna1,
                                            antenna2,
                                            dde1_jones,
                                            source_coh,
                                            dde2_jones,
                                            predict_check_tup,
                                            out_dtype)
    else:
        assert have_dies or have_bvis
        sum_coherencies = None

    # No more effects to apply, return at this point
    if not have_dies and not have_bvis:
        return sum_coherencies

    # Add coherencies to the base visibilities
    if sum_coherencies is not None:
        if not have_bvis:
            # Set base_vis = summed coherencies
            base_vis = sum_coherencies
            predict_check_tup = (have_ddes1, have_coh, have_ddes2,
                                 have_dies1, True, have_dies2)
        else:
            # Rebind rather than add in place so that the caller's
            # base_vis object can never be modified.
            base_vis = base_vis + sum_coherencies

    # Apply direction independent effects
    return apply_dies(time_index, antenna1, antenna2,
                      die1_jones, base_vis, die2_jones,
                      predict_check_tup, out_dtype)
# Extra parameter documentation substituted into PREDICT_DOCS
# for the dask version of predict_vis.
EXTRA_DASK_ARGS = """
streams : int, optional
    Specifies the degree of parallelism along the source dimension.
    By default, dask uses a tree style reduction algorithm which
    can require large amounts of memory.
    Specifying this parameter constrains the dask graph to serially
    sum coherencies in a specified number of streams, reducing overall
    memory usage.

    If ``None`` (the default), a standard, memory-intensive tree style
    algorithm is used.
    A value of 1 means that the source coherencies for each
    visibility chunk are serially summed, meaning that parallelism will
    only exist along the row and chan dimensions.
"""

# Extra notes substituted into PREDICT_DOCS
# for the dask version of predict_vis.
EXTRA_DASK_NOTES = """
* The ``ant`` dimension should only contain a single chunk equal
  to the number of antenna. Since each ``row`` can contain
  any antenna, random access must be preserved along this dimension.
* The chunks in the ``row`` and ``time`` dimension **must** align.
  This subtle point **must be understood otherwise
  invalid results will be produced** by the chunking scheme.
  In the example below
  we have four unique time indices :code:`[0,1,2,3]`, and
  four unique antenna :code:`[0,1,2,3]` indexing :code:`10` rows.

  .. code-block:: python

      #  Row indices into the time/antenna indexed arrays
      time_idx = np.asarray([0,0,1,1,2,2,2,2,3,3])
      ant1 = np.asarray(    [0,0,0,0,1,1,1,2,2,3])
      ant2 = np.asarray(    [0,1,2,3,1,2,3,2,3,3])

  A reasonable chunking scheme for the
  ``row`` and ``time`` dimension would be :code:`(4,4,2)`
  and :code:`(2,1,1)` respectively.
  Another way of explaining this is that the first
  four rows contain two unique timesteps, the second four
  rows contain one unique timestep and the last two rows
  contain one unique timestep.

  Some rules of thumb:

  1. The number of chunks in ``row`` and ``time`` must match
     although the individual chunk sizes need not.
  2. Unique timesteps should not be split across row chunks.
  3. For a Measurement Set whose rows are ordered on the
     ``TIME`` column, the following is a good way of obtaining
     the row chunking strategy:

     .. code-block:: python

        import numpy as np
        import pyrap.tables as pt

        ms = pt.table("data.ms")
        times = ms.getcol("TIME")
        unique_times, chunks = np.unique(times, return_counts=True)

  4. Use :func:`~africanus.util.shapes.aggregate_chunks`
     to aggregate multiple ``row`` and ``time``
     chunks into chunks large enough such that functions operating
     on the resulting data can drop the GIL and spend time
     processing the data. Expanding the previous example:

     .. code-block:: python

        # Aggregate row
        utimes = unique_times.size

        # Single chunk for each unique time
        time_chunks = (1,)*utimes

        # Aggregate row chunks into chunks <= 10000
        aggregate_chunks((chunks, time_chunks), (10000, utimes))
"""

# Fill in the shared docstring template with the dask specifics.
try:
    predict_vis.__doc__ = PREDICT_DOCS.substitute(
        array_type=":class:`dask.array.Array`",
        get_time_index=":code:`time.map_blocks("
                       "lambda a: np.unique(a, "
                       "return_inverse=True)[1])`",
        extra_args=EXTRA_DASK_ARGS,
        extra_notes=EXTRA_DASK_NOTES)
except AttributeError:
    # Deliberate best-effort: documentation is left untouched when the
    # assignment is not possible.
    # NOTE(review): presumably guards against objects whose __doc__
    # cannot be set -- confirm which case this was added for.
    pass