Source code for africanus.dft.dask

# -*- coding: utf-8 -*-

from africanus.dft.kernels import im_to_vis_docs, vis_to_im_docs
from africanus.dft.kernels import im_to_vis as np_im_to_vis
from africanus.dft.kernels import vis_to_im as np_vis_to_im

from africanus.util.docs import doc_tuple_to_str
from africanus.util.requirements import requires_optional

import numpy as np

# dask is an optional dependency: record whether the import succeeded so the
# outcome can be handed to ``requires_optional`` on the public functions below.
try:
    import dask.array as da
except ImportError as e:
    # Keep the exception; presumably ``requires_optional`` reports it when a
    # decorated function is used without dask installed -- TODO confirm.
    dask_import_error = e
else:
    # Import succeeded, so there is no error to pass along.
    dask_import_error = None


def _im_to_vis_wrapper(image, uvw, lm, frequency, convention, dtype_):
    """
    Unwrap the per-block inputs and forward them to the NumPy kernel.

    ``image`` and ``uvw`` are indexed once and ``lm`` twice before the
    call; ``frequency`` is forwarded untouched. ``dtype_`` is handed to
    the kernel as its ``dtype`` keyword.
    """
    # NOTE(review): the indexing presumably strips the list nesting that
    # dask's blockwise adds for contracted indices -- confirm against the
    # blockwise call in ``im_to_vis``.
    image_block = image[0]
    uvw_block = uvw[0]
    lm_block = lm[0][0]

    return np_im_to_vis(
        image_block,
        uvw_block,
        lm_block,
        frequency,
        convention=convention,
        dtype=dtype_,
    )


@requires_optional('dask.array', dask_import_error)
def im_to_vis(image, uvw, lm, frequency,
              convention='fourier', dtype=np.complex128):
    """
    Dask wrapper for the im_to_vis function.

    Checks that the inputs are chunked compatibly and then applies
    ``_im_to_vis_wrapper`` block-by-block, producing an array indexed by
    ``(row, chan, corr)``.

    This text is a placeholder: ``im_to_vis.__doc__`` is reassigned at
    module level from ``im_to_vis_docs``.

    Raises
    ------
    ValueError
        If ``lm`` or ``image`` is split into more than one chunk along
        its first axis, if their first-axis chunks differ, or if the
        second-axis chunks of ``image`` differ from the chunks of
        ``frequency``.
    """
    # The first chunk on the first axis must cover that whole axis.
    lm_first_chunk = lm.chunks[0][0]
    if lm_first_chunk != lm.shape[0]:
        raise ValueError("lm chunks must match lm shape on first axis")

    image_first_chunk = image.chunks[0][0]
    if image_first_chunk != image.shape[0]:
        raise ValueError("Image chunks must match image shape on first axis")

    if image_first_chunk != lm_first_chunk:
        raise ValueError("Image chunks and lm chunks must match on first axis")

    if image.chunks[1] != frequency.chunks[0]:
        raise ValueError("Image chunks must match frequency "
                         "chunks on second axis")

    # "source", "(u,v,w)" and "(l,m)" are absent from the output index,
    # so blockwise treats them as contracted dimensions.
    output_index = ("row", "chan", "corr")

    return da.core.blockwise(
        _im_to_vis_wrapper, output_index,
        image, ("source", "chan", "corr"),
        uvw, ("row", "(u,v,w)"),
        lm, ("source", "(l,m)"),
        frequency, ("chan",),
        convention=convention,
        dtype=dtype,
        dtype_=dtype,
    )


def _vis_to_im_wrapper(vis, uvw, lm, frequency, flags, convention, dtype_):
    """
    Unwrap the per-block inputs, call the NumPy kernel and add a leading axis.

    ``uvw`` and ``lm`` are indexed once before the call; ``vis``,
    ``frequency`` and ``flags`` are forwarded untouched. The kernel result
    is returned with a new length-one axis in front.
    """
    # NOTE(review): the ``[0]`` indexing presumably strips the list that
    # dask's blockwise adds for a contracted index -- confirm against the
    # blockwise call in ``vis_to_im``.
    partial_image = np_vis_to_im(
        vis,
        uvw[0],
        lm[0],
        frequency,
        flags,
        convention=convention,
        dtype=dtype_,
    )

    # The new leading axis lines up with the "row" output index used by
    # ``vis_to_im``.
    return partial_image[np.newaxis, :]


@requires_optional('dask.array', dask_import_error)
def vis_to_im(vis, uvw, lm, frequency, flags,
              convention='fourier', dtype=np.float64):
    """
    Dask wrapper for the vis_to_im function.

    Checks that the inputs are chunked compatibly, applies
    ``_vis_to_im_wrapper`` block-by-block and sums the resulting array
    over its first ("row") axis.

    This text is a placeholder: ``vis_to_im.__doc__`` is reassigned at
    module level from ``vis_to_im_docs``.

    Raises
    ------
    ValueError
        If the first-axis chunks of ``vis`` and ``uvw`` differ, if the
        second-axis chunks of ``vis`` differ from the chunks of
        ``frequency``, or if ``vis`` and ``flags`` are not chunked
        identically.
    """
    vis_chunks = vis.chunks

    if vis_chunks[0] != uvw.chunks[0]:
        raise ValueError("Vis chunks and uvw chunks must match on first axis")

    if vis_chunks[1] != frequency.chunks[0]:
        raise ValueError("Vis chunks must match frequency "
                         "chunks on second axis")

    if vis_chunks != flags.chunks:
        raise ValueError("Vis chunks must match flags chunks on all axes")

    # vis and flags share the same index labels.
    row_chan_corr = ("row", "chan", "corr")

    # The wrapper returns a length-one leading axis per block, which is
    # why the "row" chunks are declared as size 1 in the output.
    partial_images = da.core.blockwise(
        _vis_to_im_wrapper, ("row", "source", "chan", "corr"),
        vis, row_chan_corr,
        uvw, ("row", "(u,v,w)"),
        lm, ("source", "(l,m)"),
        frequency, ("chan",),
        flags, row_chan_corr,
        adjust_chunks={"row": 1},
        convention=convention,
        dtype=dtype,
        dtype_=dtype,
    )

    # Combine the per-row-chunk results.
    return partial_images.sum(axis=0)


# Build the public docstrings from the kernel documentation tuples,
# substituting the dask array class reference for the NumPy one.
_NUMPY_TO_DASK_CLASS = (":class:`numpy.ndarray`", ":class:`dask.array.Array`")

im_to_vis.__doc__ = doc_tuple_to_str(im_to_vis_docs, [_NUMPY_TO_DASK_CLASS])
vis_to_im.__doc__ = doc_tuple_to_str(vis_to_im_docs, [_NUMPY_TO_DASK_CLASS])