# -*- coding: utf-8 -*-
from africanus.dft.kernels import im_to_vis_docs, vis_to_im_docs
from africanus.dft.kernels import im_to_vis as np_im_to_vis
from africanus.dft.kernels import vis_to_im as np_vis_to_im
from africanus.util.docs import doc_tuple_to_str
from africanus.util.requirements import requires_optional
import numpy as np
try:
import dask.array as da
except ImportError as e:
dask_import_error = e
else:
dask_import_error = None
def _im_to_vis_wrapper(image, uvw, lm, frequency, convention, dtype_):
return np_im_to_vis(
image[0], uvw[0], lm[0][0], frequency, convention=convention, dtype=dtype_
)
[docs]
@requires_optional("dask.array", dask_import_error)
def im_to_vis(image, uvw, lm, frequency, convention="fourier", dtype=np.complex128):
"""Dask wrapper for im_to_vis function"""
if lm.chunks[0][0] != lm.shape[0]:
raise ValueError("lm chunks must match lm shape " "on first axis")
if image.chunks[0][0] != image.shape[0]:
raise ValueError("Image chunks must match image " "shape on first axis")
if image.chunks[0][0] != lm.chunks[0][0]:
raise ValueError("Image chunks and lm chunks must " "match on first axis")
if image.chunks[1] != frequency.chunks[0]:
raise ValueError("Image chunks must match frequency " "chunks on second axis")
return da.core.blockwise(
_im_to_vis_wrapper,
("row", "chan", "corr"),
image,
("source", "chan", "corr"),
uvw,
("row", "(u,v,w)"),
lm,
("source", "(l,m)"),
frequency,
("chan",),
convention=convention,
dtype=dtype,
dtype_=dtype,
)
def _vis_to_im_wrapper(vis, uvw, lm, frequency, flags, convention, dtype_):
return np_vis_to_im(
vis, uvw[0], lm[0], frequency, flags, convention=convention, dtype=dtype_
)[None, :]
[docs]
@requires_optional("dask.array", dask_import_error)
def vis_to_im(vis, uvw, lm, frequency, flags, convention="fourier", dtype=np.float64):
"""Dask wrapper for vis_to_im function"""
if vis.chunks[0] != uvw.chunks[0]:
raise ValueError("Vis chunks and uvw chunks must " "match on first axis")
if vis.chunks[1] != frequency.chunks[0]:
raise ValueError("Vis chunks must match frequency " "chunks on second axis")
if vis.chunks != flags.chunks:
raise ValueError("Vis chunks must match flags " "chunks on all axes")
ims = da.core.blockwise(
_vis_to_im_wrapper,
("row", "source", "chan", "corr"),
vis,
("row", "chan", "corr"),
uvw,
("row", "(u,v,w)"),
lm,
("source", "(l,m)"),
frequency,
("chan",),
flags,
("row", "chan", "corr"),
adjust_chunks={"row": 1},
convention=convention,
dtype=dtype,
dtype_=dtype,
)
return ims.sum(axis=0)
im_to_vis.__doc__ = doc_tuple_to_str(
im_to_vis_docs, [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")]
)
vis_to_im.__doc__ = doc_tuple_to_str(
vis_to_im_docs, [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")]
)