# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from operator import getitem
from africanus.averaging.time_and_channel_mapping import (row_mapper,
channel_mapper)
from africanus.averaging.time_and_channel_avg import (row_average,
row_chan_average,
chan_average,
merge_flags,
AVERAGING_DOCS,
AverageOutput,
ChannelAverageOutput,
RowAverageOutput,
RowChanAverageOutput)
from africanus.util.requirements import requires_optional
import numpy as np
try:
from dask.base import tokenize
import dask.array as da
import dask.blockwise as db
from dask.highlevelgraph import HighLevelGraph
except ImportError as e:
dask_import_error = e
else:
dask_import_error = None
def _row_chan_metadata(arrays, chan_bin_size):
""" Create dask array with channel metadata for each chunk channel """
for array in arrays:
if array is None:
continue
# Create a dask channel mapping structure
name = "channel-mapper-" + tokenize(array.chunks[1], chan_bin_size)
layers = {(name, i): (channel_mapper, c, chan_bin_size)
for i, c in enumerate(array.chunks[1])}
graph = HighLevelGraph.from_collections(name, layers, ())
chunks = (array.chunks[1],)
chan_mapper = da.Array(graph, name, chunks, dtype=np.object)
return chan_mapper
return None
def _dask_row_mapper(time, interval, antenna1, antenna2,
flag_row=None, time_bin_secs=1.0):
""" Create a dask row mapping structure for each row chunk """
return da.blockwise(row_mapper, ("row",),
time, ("row",),
interval, ("row",),
antenna1, ("row",),
antenna2, ("row",),
flag_row, None if flag_row is None else ("row",),
time_bin_secs=time_bin_secs,
dtype=np.object)
def _getitem_row(avg, idx, dtype):
""" Extract row-like arrays from a dask array of tuples """
name = ("row-average-getitem-%d-" % idx) + tokenize(avg, idx)
layers = db.blockwise(getitem, name, ("row",),
avg.name, ("row",),
idx, None,
numblocks={avg.name: avg.numblocks})
graph = HighLevelGraph.from_collections(name, layers, (avg,))
return da.Array(graph, name, avg.chunks, dtype=dtype)
def _dask_row_average(row_meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
""" Average row-based dask arrays """
rd = ("row",)
rcd = ("row", "corr")
avg = da.blockwise(row_average, rd,
row_meta, rd,
ant1, rd,
ant2, rd,
flag_row, None if flag_row is None else rd,
time_centroid, None if time_centroid is None else rd,
exposure, None if exposure is None else rd,
uvw, None if uvw is None else ("row", "3"),
weight, None if weight is None else rcd,
sigma, None if sigma is None else rcd,
adjust_chunks={"row": lambda x: np.nan},
dtype=np.object)
tuple_gets = (None if a is None else _getitem_row(avg, i, a.dtype)
for i, a in enumerate([ant1, ant2, time_centroid, exposure,
uvw, weight, sigma]))
return RowAverageOutput(*tuple_gets)
def _getitem_row_chan(avg, idx, dtype):
""" Extract (row,chan,corr) arrays from dask array of tuples """
name = ("row-chan-average-getitem-%d-" % idx) + tokenize(avg, idx)
dim = ("row", "chan", "corr")
layers = db.blockwise(getitem, name, dim,
avg.name, dim,
idx, None,
numblocks={avg.name: avg.numblocks})
graph = HighLevelGraph.from_collections(name, layers, (avg,))
return da.Array(graph, name, avg.chunks, dtype=dtype)
_row_chan_avg_dims = ("row", "chan", "corr")
def _dask_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None,
vis=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
chan_bin_size=1):
""" Average (row,chan,corr)-based dask arrays """
# We don't know how many rows are in each row chunk,
# but we can simply divide each channel chunk size by the bin size
adjust_chunks = {
"row": lambda r: np.nan,
"chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size
}
flag_row_dims = None if flag_row is None else ("row",)
weight_dims = None if weight is None else ("row",)
vis_dims = None if vis is None else _row_chan_avg_dims
flag_dims = None if flag is None else _row_chan_avg_dims
ws_dims = None if weight_spectrum is None else _row_chan_avg_dims
ss_dims = None if sigma_spectrum is None else _row_chan_avg_dims
avg = da.blockwise(row_chan_average, _row_chan_avg_dims,
row_meta, ("row",),
chan_meta, ("chan",),
flag_row, flag_row_dims,
weight, weight_dims,
vis, vis_dims,
flag, flag_dims,
weight_spectrum, ws_dims,
sigma_spectrum, ss_dims,
adjust_chunks=adjust_chunks,
dtype=np.object)
tuple_gets = (None if a is None else _getitem_row_chan(avg, i, a.dtype)
for i, a in enumerate([vis, flag,
weight_spectrum,
sigma_spectrum]))
return RowChanAverageOutput(*tuple_gets)
def _getitem_chan(avg, idx, dtype):
""" Extract row-like arrays from a dask array of tuples """
name = ("chan-average-getitem-%d-" % idx) + tokenize(avg, idx)
layers = db.blockwise(getitem, name, ("chan",),
avg.name, ("chan",),
idx, None,
numblocks={avg.name: avg.numblocks})
graph = HighLevelGraph.from_collections(name, layers, (avg,))
return da.Array(graph, name, avg.chunks, dtype=dtype)
def _dask_chan_average(chan_meta, chan_freq=None, chan_width=None,
chan_bin_size=1):
adjust_chunks = {
"chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size
}
avg = da.blockwise(chan_average, ("chan",),
chan_meta, ("chan",),
chan_freq, None if chan_freq is None else ("chan",),
chan_width, None if chan_width is None else ("chan",),
adjust_chunks=adjust_chunks,
dtype=np.object)
tuple_gets = (None if a is None else _getitem_chan(avg, i, a.dtype)
for i, a in enumerate([chan_freq, chan_width]))
return ChannelAverageOutput(*tuple_gets)
def _dask_merge_flags(flag_row, flag):
""" Perform flag merging on dask arrays """
if flag_row is None and flag is not None:
return da.blockwise(merge_flags, "r",
flag_row, None,
flag, "rfc",
concatenate=True,
dtype=flag.dtype)
elif flag_row is not None and flag is None:
return da.blockwise(merge_flags, "r",
flag_row, "r",
None, None,
dtype=flag_row.dtype)
elif flag_row is not None and flag is not None:
return da.blockwise(merge_flags, "r",
flag_row, "r",
flag, "rfc",
concatenate=True,
dtype=flag_row.dtype)
else:
return None
[docs]@requires_optional("dask.array", dask_import_error)
def time_and_channel(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
vis=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
time_bin_secs=1.0, chan_bin_size=1):
row_chan_arrays = (vis, flag, weight_spectrum, sigma_spectrum)
# The flow of this function should match that of the numba
# time_and_channel implementation
# Merge flag_row and flag arrays
flag_row = _dask_merge_flags(flag_row, flag)
# Generate row mapping metadata
row_meta = _dask_row_mapper(time, interval,
antenna1, antenna2,
flag_row=flag_row,
time_bin_secs=time_bin_secs)
# Generate channel mapping metadata
chan_meta = _row_chan_metadata(row_chan_arrays, chan_bin_size)
# Average row data
row_data = _dask_row_average(row_meta, antenna1, antenna2,
flag_row=flag_row,
time_centroid=time_centroid,
exposure=exposure, uvw=uvw,
weight=weight, sigma=sigma)
# Average channel data
row_chan_data = _dask_row_chan_average(row_meta, chan_meta,
flag_row=flag_row, weight=weight,
vis=vis, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum,
chan_bin_size=chan_bin_size)
chan_data = _dask_chan_average(chan_meta, chan_freq=chan_freq,
chan_width=chan_width)
# Merge output tuples
return AverageOutput(_getitem_row(row_meta, 1, time.dtype),
_getitem_row(row_meta, 2, interval.dtype),
(_getitem_row(row_meta, 3, flag_row.dtype)
if flag_row is not None else None),
row_data.antenna1,
row_data.antenna2,
row_data.time_centroid,
row_data.exposure,
row_data.uvw,
row_data.weight,
row_data.sigma,
chan_data.chan_freq,
chan_data.chan_width,
row_chan_data.vis,
row_chan_data.flag,
row_chan_data.weight_spectrum,
row_chan_data.sigma_spectrum)
try:
time_and_channel.__doc__ = AVERAGING_DOCS.substitute(
array_type=":class:`dask.array.Array`")
except AttributeError:
pass