Source code for africanus.rime.dask

# -*- coding: utf-8 -*-


from africanus.rime.phase import (
    phase_delay as np_phase_delay,
    PHASE_DELAY_DOCS,
)
from africanus.rime.parangles import parallactic_angles as np_parangles
from africanus.rime.feeds import feed_rotation as np_feed_rotation
from africanus.rime.feeds import FEED_ROTATION_DOCS
from africanus.rime.transform import transform_sources as np_transform_sources
from africanus.rime.fast_beam_cubes import (
    beam_cube_dde as np_beam_cube_dde,
    BEAM_CUBE_DOCS,
)
from africanus.rime.dask_predict import predict_vis, wsclean_predict  # noqa
from africanus.rime.zernike import zernike_dde as np_zernike_dde


from africanus.util.docs import mod_docs
from africanus.util.requirements import requires_optional
from africanus.util.type_inference import infer_complex_dtype

import numpy as np

# dask is an optional dependency.  Record the ImportError (or None when the
# import succeeds) in ``da_import_error`` so it can be handed to the
# ``requires_optional`` decorator applied to the wrappers in this module.
try:
    import dask.array as da
except ImportError as e:
    da_import_error = e
else:
    da_import_error = None


def _phase_delay_wrap(lm, uvw, frequency, convention):
    """Per-block kernel handed to ``blockwise`` by :func:`phase_delay`.

    ``lm`` and ``uvw`` are unwrapped with ``[0]`` before being forwarded to
    the NumPy implementation: their ``"(l,m)"`` / ``"(u,v,w)"`` indices are
    absent from the output indices, so (per dask's ``blockwise`` contraction
    behaviour) they arrive wrapped in a list.  ``frequency`` and
    ``convention`` are passed through unchanged.
    """
    lm_block = lm[0]
    uvw_block = uvw[0]
    return np_phase_delay(lm_block, uvw_block, frequency, convention=convention)


@requires_optional("dask.array", da_import_error)
def phase_delay(lm, uvw, frequency, convention="fourier"):
    """Dask wrapper around the NumPy ``phase_delay`` function.

    Applies ``_phase_delay_wrap`` block-wise, producing an array indexed
    ``("source", "row", "chan")`` whose dtype is inferred from the inputs
    with ``infer_complex_dtype``.  ``convention`` is forwarded as a keyword
    argument.
    """
    out_dims = ("source", "row", "chan")
    out_dtype = infer_complex_dtype(lm, uvw, frequency)

    # (array, index) pairs in the flat layout expected by blockwise.
    array_args = (
        lm, ("source", "(l,m)"),
        uvw, ("row", "(u,v,w)"),
        frequency, ("chan",),
    )

    return da.core.blockwise(
        _phase_delay_wrap,
        out_dims,
        *array_args,
        convention=convention,
        dtype=out_dtype,
    )
def _parangle_wrapper(t, ap, fc, **kw):
    """Per-block kernel handed to ``blockwise`` by :func:`parallactic_angles`.

    ``ap`` and ``fc`` are unwrapped with ``[0]`` because their ``"xyz"`` and
    ``"fc"`` indices are not part of the output indices.  Any extra keyword
    arguments are forwarded to the NumPy implementation.
    """
    antenna_block = ap[0]
    centre_block = fc[0]
    return np_parangles(t, antenna_block, centre_block, **kw)
@requires_optional("dask.array", da_import_error)
def parallactic_angles(times, antenna_positions, field_centre, **kwargs):
    """Dask wrapper around the NumPy ``parallactic_angles`` function.

    Applies ``_parangle_wrapper`` block-wise, producing an array indexed
    ``("time", "ant")`` with the dtype of ``times``.  ``**kwargs`` are
    forwarded to ``blockwise``.
    """
    out_dims = ("time", "ant")

    # (array, index) pairs in the flat layout expected by blockwise.
    array_args = (
        times, ("time",),
        antenna_positions, ("ant", "xyz"),
        field_centre, ("fc",),
    )

    return da.core.blockwise(
        _parangle_wrapper,
        out_dims,
        *array_args,
        dtype=times.dtype,
        **kwargs
    )
@requires_optional("dask.array", da_import_error)
def feed_rotation(parallactic_angles, feed_type):
    """Dask wrapper around the NumPy ``feed_rotation`` function.

    The output keeps every dimension of ``parallactic_angles`` and appends
    two new correlation axes of length 2.  The output dtype is the complex
    counterpart of the input dtype (float32 -> complex64,
    float64 -> complex128).

    Raises
    ------
    ValueError
        If ``parallactic_angles`` is neither float32 nor float64.
    """
    # One index label per input dimension, followed by the two new axes.
    pa_dims = tuple(["pa-%d" % axis for axis in range(parallactic_angles.ndim)])
    out_dims = pa_dims + ("corr-1", "corr-2")

    in_dtype = parallactic_angles.dtype

    if in_dtype == np.float32:
        out_dtype = np.complex64
    elif in_dtype == np.float64:
        out_dtype = np.complex128
    else:
        raise ValueError("parallactic_angles have non-floating point dtype")

    return da.core.blockwise(
        np_feed_rotation,
        out_dims,
        parallactic_angles,
        pa_dims,
        feed_type=feed_type,
        new_axes={"corr-1": 2, "corr-2": 2},
        dtype=out_dtype,
    )
def _xform_wrap(
    lm, parallactic_angles, pointing_errors, antenna_scaling, frequency, dtype_
):
    """Per-block kernel handed to ``blockwise`` by :func:`transform_sources`.

    ``lm`` and ``pointing_errors`` are unwrapped with ``[0]`` because their
    ``"lm"`` index is not part of the output indices.  ``dtype_`` is passed
    to the NumPy implementation as its ``dtype`` keyword.
    """
    lm_block = lm[0]
    pointing_block = pointing_errors[0]

    return np_transform_sources(
        lm_block,
        parallactic_angles,
        pointing_block,
        antenna_scaling,
        frequency,
        dtype=dtype_,
    )
@requires_optional("dask.array", da_import_error)
def transform_sources(
    lm,
    parallactic_angles,
    pointing_errors,
    antenna_scaling,
    frequency,
    dtype=None,
):
    """Dask wrapper around the NumPy ``transform_sources`` function.

    Applies ``_xform_wrap`` block-wise, producing an array indexed
    ``("comp", "src", "time", "ant", "chan")`` where ``"comp"`` is a new
    axis of length 3.  ``dtype`` defaults to ``numpy.float64`` and is used
    both as the dask output dtype and as the dtype requested from the
    NumPy implementation.
    """
    out_dtype = np.float64 if dtype is None else dtype
    out_dims = ("comp", "src", "time", "ant", "chan")

    # (array, index) pairs in the flat layout expected by blockwise.
    array_args = (
        lm, ("src", "lm"),
        parallactic_angles, ("time", "ant"),
        pointing_errors, ("time", "ant", "lm"),
        antenna_scaling, ("ant", "chan"),
        frequency, ("chan",),
    )

    return da.core.blockwise(
        _xform_wrap,
        out_dims,
        *array_args,
        new_axes={"comp": 3},
        dtype=out_dtype,
        dtype_=out_dtype,
    )
def _beam_cube_dde_wrapper(
    beam,
    beam_lm_extents,
    beam_freq_map,
    lm,
    parallactic_angles,
    point_errors,
    antenna_scaling,
    frequencies,
):
    """Per-block kernel handed to ``blockwise`` by :func:`beam_cube_dde`.

    Arguments whose indices are missing from the output indices are
    unwrapped before the NumPy implementation is called: three levels for
    ``beam``, two for ``beam_lm_extents`` and one each for
    ``beam_freq_map``, ``lm``, ``point_errors`` and ``antenna_scaling``.
    ``parallactic_angles`` and ``frequencies`` are passed through unchanged.
    """
    beam_block = beam[0][0][0]
    extents_block = beam_lm_extents[0][0]
    freq_map_block = beam_freq_map[0]
    lm_block = lm[0]
    point_errors_block = point_errors[0]
    scaling_block = antenna_scaling[0]

    return np_beam_cube_dde(
        beam_block,
        extents_block,
        freq_map_block,
        lm_block,
        parallactic_angles,
        point_errors_block,
        scaling_block,
        frequencies,
    )
@requires_optional("dask.array", da_import_error)
def beam_cube_dde(
    beam,
    beam_lm_extents,
    beam_freq_map,
    lm,
    parallactic_angles,
    point_errors,
    antenna_scaling,
    frequencies,
):
    """Dask wrapper around the NumPy ``beam_cube_dde`` function.

    Applies ``_beam_cube_dde_wrapper`` block-wise, producing an array
    indexed ``("source", "time", "ant", "chan")`` plus one correlation
    index per ``beam`` dimension beyond the third, with the dtype of
    ``beam``.

    Raises
    ------
    ValueError
        If ``beam``, ``beam_freq_map`` or ``beam_lm_extents`` has more than
        one chunk along any dimension.
    """
    # Each of these arrays must consist of a single chunk per dimension.
    # They are checked in this order, so the first offender decides the
    # message.
    single_chunk_checks = (
        (beam, "Beam chunking unsupported"),
        (beam_freq_map, "Beam frequency map chunking unsupported"),
        (beam_lm_extents, "Chunking of beam_lm_extents unsupported"),
    )

    for array, message in single_chunk_checks:
        if any(len(dim_chunks) != 1 for dim_chunks in array.chunks):
            raise ValueError(message)

    # Every beam dimension after the first three is a correlation dimension.
    ncorr = len(beam.shape[3:])
    corr_dims = tuple("corr-%d" % c for c in range(ncorr))

    out_dims = ("source", "time", "ant", "chan") + corr_dims
    beam_dims = ("beam-lw", "beam-mh", "beam-nud") + corr_dims

    # (array, index) pairs in the flat layout expected by blockwise.
    array_args = (
        beam, beam_dims,
        beam_lm_extents, ("beam-lm", "beam-ext"),
        beam_freq_map, ("beam-nud",),
        lm, ("source", "source-comp"),
        parallactic_angles, ("time", "ant"),
        point_errors, ("time", "ant", "chan", "pt-comp"),
        antenna_scaling, ("ant", "chan", "scale-comp"),
        frequencies, ("chan",),
    )

    return da.core.blockwise(
        _beam_cube_dde_wrapper,
        out_dims,
        *array_args,
        dtype=beam.dtype,
    )
def _zernike_wrapper(
    coords,
    coeffs,
    noll_index,
    parallactic_angle,
    frequency_scaling,
    antenna_scaling,
    pointing_errors,
):
    """Per-block kernel handed to ``blockwise`` by :func:`zernike_dde`.

    Arguments whose indices are missing from the output indices are
    unwrapped with ``[0]`` before the NumPy implementation is called.
    """
    coords_block = coords[0]  # coords loses "three" dim
    coeffs_block = coeffs[0]  # coeffs loses "poly" dim
    noll_block = noll_index[0]  # noll_index loses "poly" dim
    scaling_block = antenna_scaling[0]  # antenna_scaling loses "two" dim
    pointing_block = pointing_errors[0]  # pointing_errors loses "two" dim

    return np_zernike_dde(
        coords_block,
        coeffs_block,
        noll_block,
        parallactic_angle,
        frequency_scaling,
        scaling_block,
        pointing_block,
    )
@requires_optional("dask.array", da_import_error)
def zernike_dde(
    coords,
    coeffs,
    noll_index,
    parallactic_angle,
    frequency_scaling,
    antenna_scaling,
    pointing_errors,
):
    """Dask wrapper around the NumPy ``zernike_dde`` function.

    Applies ``_zernike_wrapper`` block-wise, producing an array indexed
    ``("source", "time", "ant", "chan")`` plus one correlation index for
    each ``coeffs`` dimension between the first two and the last, with the
    dtype of ``coeffs``.
    """
    # Dimensions of coeffs between ("ant", "chan") and the trailing "poly"
    # dimension are treated as correlation dimensions.
    corr_dims = tuple("corr-%d" % c for c in range(len(coeffs.shape[2:-1])))

    out_dims = ("source", "time", "ant", "chan") + corr_dims
    poly_dims = ("ant", "chan") + corr_dims + ("poly",)

    # (array, index) pairs in the flat layout expected by blockwise.
    array_args = (
        coords, ("three", "source", "time", "ant", "chan"),
        coeffs, poly_dims,
        noll_index, poly_dims,
        parallactic_angle, ("time", "ant"),
        frequency_scaling, ("chan",),
        antenna_scaling, ("ant", "chan", "two"),
        pointing_errors, ("time", "ant", "chan", "two"),
    )

    return da.core.blockwise(
        _zernike_wrapper,
        out_dims,
        *array_args,
        dtype=coeffs.dtype,
    )
# Attach user-facing documentation to the dask wrappers, derived from the
# NumPy implementations' docs with the array type rewritten to
# :class:`dask.array.Array`.
#
# Each assignment is best-effort: an AttributeError is deliberately ignored
# so that a documentation problem never prevents the module from importing.
# NOTE(review): presumably this covers source docstrings being ``None``
# (e.g. when running under ``python -OO``) -- confirm.
try:
    phase_delay.__doc__ = PHASE_DELAY_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    parallactic_angles.__doc__ = mod_docs(
        np_parangles.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except AttributeError:
    pass

try:
    # The dask wrapper returns a dask array, so document it as such; this
    # previously substituted :class:`numpy.ndarray`, unlike its siblings.
    feed_rotation.__doc__ = FEED_ROTATION_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    transform_sources.__doc__ = mod_docs(
        np_transform_sources.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except AttributeError:
    pass

try:
    beam_cube_dde.__doc__ = BEAM_CUBE_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    zernike_dde.__doc__ = mod_docs(
        np_zernike_dde.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except AttributeError:
    pass