# Source code for africanus.calibration.utils.dask

# -*- coding: utf-8 -*-

from africanus.calibration.utils.correct_vis import CORRECT_VIS_DOCS
from africanus.calibration.utils.corrupt_vis import CORRUPT_VIS_DOCS
from africanus.calibration.utils.residual_vis import RESIDUAL_VIS_DOCS
from africanus.calibration.utils.compute_and_corrupt_vis import (
    COMPUTE_AND_CORRUPT_VIS_DOCS,
)
from africanus.calibration.utils import correct_vis as np_correct_vis
from africanus.calibration.utils import (
    compute_and_corrupt_vis as np_compute_and_corrupt_vis,
)
from africanus.calibration.utils import corrupt_vis as np_corrupt_vis
from africanus.calibration.utils import residual_vis as np_residual_vis
from africanus.calibration.utils import check_type
from africanus.calibration.utils.utils import DIAG_DIAG, DIAG, FULL
from africanus.util.requirements import requires_optional

try:
    from dask.array.core import blockwise
except ImportError as e:
    dask_import_error = e
else:
    dask_import_error = None


def _corrupt_vis_wrapper(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model
):
    """Adapt the blocks delivered by ``blockwise`` to ``np_corrupt_vis``.

    ``corrupt_vis`` indexes ``jones`` with ``"ant"`` and ``"dir"`` and
    ``model`` with ``"dir"``, none of which appear in the output index.
    ``blockwise`` therefore passes these arguments as nested lists of blocks.
    ``corrupt_vis`` requires the first chunk along those axes to span the
    whole axis, so the first list entry holds the complete data.
    """
    jones_block = jones[0][0]  # unwrap the "ant" and "dir" block lists
    model_block = model[0]  # unwrap the "dir" block list
    return np_corrupt_vis(
        time_bin_indices,
        time_bin_counts,
        antenna1,
        antenna2,
        jones_block,
        model_block,
    )


@requires_optional("dask.array", dask_import_error)
def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model):
    # The public docstring is attached at module level from CORRUPT_VIS_DOCS.
    mode = check_type(jones, model, vis_type="model")

    # _corrupt_vis_wrapper only reads the first block along each contracted
    # axis, so the first chunk along that axis must span the whole axis.
    for array, axis, message in (
        (jones, 1, "Cannot chunk jones over antenna"),
        (jones, 3, "Cannot chunk jones over direction"),
        (model, 2, "Cannot chunk model over direction"),
    ):
        if array.chunks[axis][0] != array.shape[axis]:
            raise ValueError(message)

    # The mode decides how many correlation axes the model/output carry and
    # how many the jones terms carry.
    if mode == DIAG_DIAG:
        vis_corrs = ("corr1",)
        jones_corrs = ("corr1",)
    elif mode == DIAG:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1",)
    elif mode == FULL:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corrs
    model_shape = ("row", "chan", "dir") + vis_corrs
    jones_shape = ("row", "ant", "chan", "dir") + jones_corrs

    return blockwise(
        _corrupt_vis_wrapper,
        out_shape,
        time_bin_indices, ("row",),
        time_bin_counts, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        jones, jones_shape,
        model, model_shape,
        # Output row chunks follow antenna1 rather than the other inputs
        # that share the "row" index (e.g. time_bin_indices).
        adjust_chunks={"row": antenna1.chunks[0]},
        dtype=model.dtype,
        # NOTE(review): presumably disabled because inputs sharing "row" have
        # different lengths and must not be rechunked to match — confirm.
        align_arrays=False,
    )
def _compute_and_corrupt_vis_wrapper(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm
):
    """Adapt the blocks delivered by ``blockwise`` to ``np_compute_and_corrupt_vis``.

    ``compute_and_corrupt_vis`` gives several arguments indices that are
    absent from the output: ``jones`` ("ant", "dir"), ``model`` ("dir"),
    ``uvw`` ("three") and ``lm`` ("dir", "two"). ``blockwise`` passes these
    as nested lists of blocks. The caller requires the first chunk along each
    such axis to span the full axis, so element ``[0]`` is the complete data.
    """
    jones_block = jones[0][0]  # "ant", "dir"
    model_block = model[0]  # "dir"
    uvw_block = uvw[0]  # "three"
    lm_block = lm[0][0]  # "dir", "two"
    return np_compute_and_corrupt_vis(
        time_bin_indices,
        time_bin_counts,
        antenna1,
        antenna2,
        jones_block,
        model_block,
        uvw_block,
        freq,
        lm_block,
    )
@requires_optional("dask.array", dask_import_error)
def compute_and_corrupt_vis(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm
):
    # The public docstring is attached at module level from
    # COMPUTE_AND_CORRUPT_VIS_DOCS.
    #
    # _compute_and_corrupt_vis_wrapper only reads the first block along every
    # axis contracted by blockwise, so each such axis must be spanned by its
    # first chunk. Raises ValueError otherwise.
    if jones.chunks[1][0] != jones.shape[1]:
        raise ValueError("Cannot chunk jones over antenna")
    if jones.chunks[3][0] != jones.shape[3]:
        raise ValueError("Cannot chunk jones over direction")
    if model.chunks[2][0] != model.shape[2]:
        raise ValueError("Cannot chunk model over direction")
    if uvw.chunks[1][0] != uvw.shape[1]:
        raise ValueError("Cannot chunk uvw over last axis")
    if lm.chunks[1][0] != lm.shape[1]:
        raise ValueError("Cannot chunk lm over direction")
    if lm.chunks[2][0] != lm.shape[2]:
        raise ValueError("Cannot chunk lm over last axis")

    mode = check_type(jones, model, vis_type="model")

    # The mode decides how many correlation axes the model/output carry and
    # how many the jones terms carry.
    if mode == DIAG_DIAG:
        vis_corrs = ("corr1",)
        jones_corrs = ("corr1",)
    elif mode == DIAG:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1",)
    elif mode == FULL:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corrs
    model_shape = ("row", "chan", "dir") + vis_corrs
    jones_shape = ("row", "ant", "chan", "dir") + jones_corrs

    return blockwise(
        _compute_and_corrupt_vis_wrapper,
        out_shape,
        time_bin_indices, ("row",),
        time_bin_counts, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        jones, jones_shape,
        model, model_shape,
        uvw, ("row", "three"),
        freq, ("chan",),
        lm, ("row", "dir", "two"),
        # Output row chunks follow antenna1 rather than the other inputs
        # that share the "row" index (e.g. time_bin_indices).
        adjust_chunks={"row": antenna1.chunks[0]},
        dtype=model.dtype,
        # NOTE(review): presumably disabled because inputs sharing "row" have
        # different lengths and must not be rechunked to match — confirm.
        align_arrays=False,
    )
def _correct_vis_wrapper(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag
):
    """Adapt the blocks delivered by ``blockwise`` to ``np_correct_vis``.

    ``correct_vis`` indexes ``jones`` with ``"ant"`` and ``"dir"``, which are
    absent from the output index. ``blockwise`` therefore hands ``jones``
    over as a nested list of blocks. ``correct_vis`` requires the first chunk
    along those axes to span the whole axis, so ``jones[0][0]`` is the
    complete array. ``vis`` and ``flag`` share the output index and arrive
    as plain blocks.
    """
    jones_block = jones[0][0]  # unwrap the "ant" and "dir" block lists
    return np_correct_vis(
        time_bin_indices,
        time_bin_counts,
        antenna1,
        antenna2,
        jones_block,
        vis,
        flag,
    )
@requires_optional("dask.array", dask_import_error)
def correct_vis(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag
):
    # The public docstring is attached at module level from CORRECT_VIS_DOCS.
    #
    # _correct_vis_wrapper only reads the first block along the contracted
    # jones axes, so the first chunk along each must span the whole axis.
    for axis, message in (
        (1, "Cannot chunk jones over antenna"),
        (3, "Cannot chunk jones over direction"),
    ):
        if jones.chunks[axis][0] != jones.shape[axis]:
            raise ValueError(message)

    mode = check_type(jones, vis)

    # The mode decides how many correlation axes vis/flag/output carry and
    # how many the jones terms carry.
    if mode == DIAG_DIAG:
        vis_corrs = ("corr1",)
        jones_corrs = ("corr1",)
    elif mode == DIAG:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1",)
    elif mode == FULL:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corrs
    jones_shape = ("row", "ant", "chan", "dir") + jones_corrs

    return blockwise(
        _correct_vis_wrapper,
        out_shape,
        time_bin_indices, ("row",),
        time_bin_counts, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        jones, jones_shape,
        # vis and flag are laid out exactly like the output.
        vis, out_shape,
        flag, out_shape,
        adjust_chunks={"row": antenna1.chunks[0]},
        dtype=vis.dtype,
        align_arrays=False,
    )
def _residual_vis_wrapper(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model
):
    """Adapt the blocks delivered by ``blockwise`` to ``np_residual_vis``.

    ``residual_vis`` indexes ``jones`` with ``"ant"`` and ``"dir"`` and
    ``model`` with ``"dir"``, which are absent from the output index.
    ``blockwise`` therefore passes them as nested lists of blocks.
    ``residual_vis`` requires the first chunk along those axes to span the
    whole axis, so the first entry of each list holds the complete data.
    """
    jones_block = jones[0][0]  # unwrap the "ant" and "dir" block lists
    model_block = model[0]  # unwrap the "dir" block list
    return np_residual_vis(
        time_bin_indices,
        time_bin_counts,
        antenna1,
        antenna2,
        jones_block,
        vis,
        flag,
        model_block,
    )
@requires_optional("dask.array", dask_import_error)
def residual_vis(
    time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model
):
    # The public docstring is attached at module level from RESIDUAL_VIS_DOCS.
    #
    # _residual_vis_wrapper only reads the first block along each contracted
    # axis, so the first chunk along that axis must span the whole axis.
    for array, axis, message in (
        (jones, 1, "Cannot chunk jones over antenna"),
        (jones, 3, "Cannot chunk jones over direction"),
        (model, 2, "Cannot chunk model over direction"),
    ):
        if array.chunks[axis][0] != array.shape[axis]:
            raise ValueError(message)

    mode = check_type(jones, vis)

    # The mode decides how many correlation axes vis/model/output carry and
    # how many the jones terms carry.
    if mode == DIAG_DIAG:
        vis_corrs = ("corr1",)
        jones_corrs = ("corr1",)
    elif mode == DIAG:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1",)
    elif mode == FULL:
        vis_corrs = ("corr1", "corr2")
        jones_corrs = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corrs
    model_shape = ("row", "chan", "dir") + vis_corrs
    jones_shape = ("row", "ant", "chan", "dir") + jones_corrs

    return blockwise(
        _residual_vis_wrapper,
        out_shape,
        time_bin_indices, ("row",),
        time_bin_counts, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        jones, jones_shape,
        # vis and flag are laid out exactly like the output.
        vis, out_shape,
        flag, out_shape,
        model, model_shape,
        adjust_chunks={"row": antenna1.chunks[0]},
        dtype=vis.dtype,
        align_arrays=False,
    )
# Render the shared docstring templates with the dask array type so the
# dask wrappers document dask inputs/outputs.
for _func, _template in (
    (compute_and_corrupt_vis, COMPUTE_AND_CORRUPT_VIS_DOCS),
    (corrupt_vis, CORRUPT_VIS_DOCS),
    (correct_vis, CORRECT_VIS_DOCS),
    (residual_vis, RESIDUAL_VIS_DOCS),
):
    _func.__doc__ = _template.substitute(array_type=":class:`dask.array.Array`")

# Do not leak the loop temporaries into the module namespace.
del _func, _template