# -*- coding: utf-8 -*-
from africanus.rime.phase import (
phase_delay as np_phase_delay,
PHASE_DELAY_DOCS,
)
from africanus.rime.parangles import parallactic_angles as np_parangles
from africanus.rime.feeds import feed_rotation as np_feed_rotation
from africanus.rime.feeds import FEED_ROTATION_DOCS
from africanus.rime.transform import transform_sources as np_transform_sources
from africanus.rime.fast_beam_cubes import (
beam_cube_dde as np_beam_cube_dde,
BEAM_CUBE_DOCS,
)
from africanus.rime.dask_predict import predict_vis, wsclean_predict # noqa
from africanus.rime.zernike import zernike_dde as np_zernike_dde
from africanus.util.docs import mod_docs
from africanus.util.requirements import requires_optional
from africanus.util.type_inference import infer_complex_dtype
import numpy as np
# dask is an optional dependency.  Record why importing it failed (or None
# on success) so the ``requires_optional`` decorators below can be handed
# the original error (presumably re-raised when a dask wrapper is called
# without dask installed -- see africanus.util.requirements).
da_import_error = None
try:
    import dask.array as da
except ImportError as import_err:
    da_import_error = import_err
def _phase_delay_wrap(lm, uvw, frequency, convention):
    """Blockwise adapter around :func:`np_phase_delay`.

    The component axes of ``lm`` and ``uvw`` are absent from the output
    index used by :func:`phase_delay`, so dask hands them over as lists of
    blocks; the first (expected to be the only) block of each is used.
    ``convention`` is forwarded as a keyword.
    """
    lm_block, uvw_block = lm[0], uvw[0]
    return np_phase_delay(lm_block, uvw_block, frequency, convention=convention)
@requires_optional("dask.array", da_import_error)
def phase_delay(lm, uvw, frequency, convention="fourier"):
    """Dask wrapper for phase_delay function.

    Lazily maps :func:`np_phase_delay` over matching blocks of ``lm``
    (indexed ``(source, (l,m))``), ``uvw`` (``(row, (u,v,w))``) and
    ``frequency`` (``(chan,)``), producing a ``(source, row, chan)``
    array whose complex dtype is inferred from the three inputs.
    ``convention`` is forwarded unchanged to the numpy implementation.
    """
    # The "(l,m)" and "(u,v,w)" axes are contracted (absent from the output
    # index), so blockwise passes each task a *list* of blocks along them
    # and the wrapper uses only the first.  Collapse those axes into one
    # chunk so no components are silently dropped; inputs already
    # unchunked along them are left unchanged.
    lm = da.asarray(lm).rechunk({1: -1})
    uvw = da.asarray(uvw).rechunk({1: -1})

    return da.core.blockwise(
        _phase_delay_wrap,
        ("source", "row", "chan"),
        lm,
        ("source", "(l,m)"),
        uvw,
        ("row", "(u,v,w)"),
        frequency,
        ("chan",),
        convention=convention,
        dtype=infer_complex_dtype(lm, uvw, frequency),
    )
def _parangle_wrapper(t, ap, fc, **kw):
    """Blockwise adapter around :func:`np_parangles`.

    ``ap`` (antenna positions) and ``fc`` (field centre) arrive as lists of
    blocks because their ``xyz`` / ``fc`` axes are contracted in
    :func:`parallactic_angles`; the first block of each is used.  ``t`` and
    any extra keywords are passed through unchanged.
    """
    positions, centre = ap[0], fc[0]
    return np_parangles(t, positions, centre, **kw)
@requires_optional("dask.array", da_import_error)
def parallactic_angles(times, antenna_positions, field_centre, **kwargs):
    """Dask wrapper for parallactic_angles.

    Lazily maps :func:`np_parangles` over blocks of ``times`` (``(time,)``)
    and ``antenna_positions`` (``(ant, xyz)``) with the full
    ``field_centre`` (``(fc,)``), producing a ``(time, ant)`` array with
    the dtype of ``times``.

    Extra keyword arguments go through :func:`dask.array.core.blockwise`
    to the numpy implementation; note that keywords blockwise itself
    recognises (e.g. ``name``, ``concatenate``) are consumed by it instead.
    """
    # "xyz" and "fc" are contracted axes: blockwise passes lists of blocks
    # and the wrapper keeps only the first, so make each a single chunk.
    antenna_positions = da.asarray(antenna_positions).rechunk({1: -1})
    field_centre = da.asarray(field_centre).rechunk({0: -1})

    return da.core.blockwise(
        _parangle_wrapper,
        ("time", "ant"),
        times,
        ("time",),
        antenna_positions,
        ("ant", "xyz"),
        field_centre,
        ("fc",),
        dtype=times.dtype,
        **kwargs,
    )
@requires_optional("dask.array", da_import_error)
def feed_rotation(parallactic_angles, feed_type):
    """Dask wrapper for feed_rotation.

    Lazily maps :func:`np_feed_rotation` over every block of
    ``parallactic_angles`` (of any dimensionality), appending two new
    axes of length 2.  The result is complex64 for float32 angles and
    complex128 for float64 angles.

    Raises
    ------
    ValueError
        If ``parallactic_angles`` is neither float32 nor float64.
    """
    pa_dims = tuple("pa-%d" % i for i in range(parallactic_angles.ndim))
    corr_dims = ("corr-1", "corr-2")

    if parallactic_angles.dtype == np.float32:
        dtype = np.complex64
    elif parallactic_angles.dtype == np.float64:
        dtype = np.complex128
    else:
        # float16 / longdouble are floating point too, so report the
        # actual dtype rather than calling it non-floating point.
        raise ValueError(
            "parallactic_angles have unsupported dtype %s; "
            "expected float32 or float64" % parallactic_angles.dtype
        )

    return da.core.blockwise(
        np_feed_rotation,
        pa_dims + corr_dims,
        parallactic_angles,
        pa_dims,
        feed_type=feed_type,
        new_axes={"corr-1": 2, "corr-2": 2},
        dtype=dtype,
    )
def _xform_wrap(
    lm, parallactic_angles, pointing_errors, antenna_scaling, frequency, dtype_
):
    """Blockwise adapter around :func:`np_transform_sources`.

    ``lm`` and ``pointing_errors`` are unwrapped by taking their first
    element -- presumably lists of blocks from a blockwise contraction
    (the dask caller is not visible here; confirm).  The other arrays are
    passed through unchanged and ``dtype_`` becomes the ``dtype`` keyword.
    """
    lm_block = lm[0]
    pointing_block = pointing_errors[0]
    return np_transform_sources(
        lm_block,
        parallactic_angles,
        pointing_block,
        antenna_scaling,
        frequency,
        dtype=dtype_,
    )
def _beam_cube_dde_wrapper(
    beam,
    beam_lm_extents,
    beam_freq_map,
    lm,
    parallactic_angles,
    point_errors,
    antenna_scaling,
    frequencies,
):
    """Blockwise adapter around :func:`np_beam_cube_dde`.

    Every argument with axes contracted in :func:`beam_cube_dde` arrives
    as nested lists of blocks (one nesting level per contracted axis), so
    the first block is taken at each level.  ``parallactic_angles`` and
    ``frequencies`` have no contracted axes and pass straight through.
    """
    # Three contracted axes: beam-lw, beam-mh, beam-nud.
    beam_block = beam[0][0][0]
    # Two contracted axes: beam-lm, beam-ext.
    extents_block = beam_lm_extents[0][0]
    # One contracted axis each: beam-nud, source-comp, pt-comp, scale-comp.
    freq_map_block = beam_freq_map[0]
    lm_block = lm[0]
    point_block = point_errors[0]
    scaling_block = antenna_scaling[0]

    return np_beam_cube_dde(
        beam_block,
        extents_block,
        freq_map_block,
        lm_block,
        parallactic_angles,
        point_block,
        scaling_block,
        frequencies,
    )
@requires_optional("dask.array", da_import_error)
def beam_cube_dde(
    beam,
    beam_lm_extents,
    beam_freq_map,
    lm,
    parallactic_angles,
    point_errors,
    antenna_scaling,
    frequencies,
):
    """Dask wrapper for beam_cube_dde.

    Lazily maps :func:`np_beam_cube_dde` to produce a
    ``(source, time, ant, chan, *corr)`` array with the dtype of ``beam``,
    where the trailing correlation axes are ``beam.shape[3:]``.

    Raises
    ------
    ValueError
        If ``beam``, ``beam_freq_map`` or ``beam_lm_extents`` are split
        into more than one chunk along any axis.
    """
    if not all(len(c) == 1 for c in beam.chunks):
        raise ValueError("Beam chunking unsupported")

    if not all(len(c) == 1 for c in beam_freq_map.chunks):
        raise ValueError("Beam frequency map chunking unsupported")

    if not all(len(c) == 1 for c in beam_lm_extents.chunks):
        raise ValueError("Chunking of beam_lm_extents unsupported")

    # The component axes below are contracted (absent from the output
    # index): blockwise passes lists of blocks and the wrapper keeps only
    # the first.  Merge each into a single chunk so no components are
    # silently dropped; already-unchunked inputs are left unchanged.
    lm = da.asarray(lm).rechunk({1: -1})
    point_errors = da.asarray(point_errors).rechunk({3: -1})
    antenna_scaling = da.asarray(antenna_scaling).rechunk({2: -1})

    corr_shapes = beam.shape[3:]
    corr_dims = tuple("corr-%d" % i for i in range(len(corr_shapes)))
    dde_dims = ("source", "time", "ant", "chan") + corr_dims
    beam_dims = ("beam-lw", "beam-mh", "beam-nud") + corr_dims

    return da.core.blockwise(
        _beam_cube_dde_wrapper,
        dde_dims,
        beam,
        beam_dims,
        beam_lm_extents,
        ("beam-lm", "beam-ext"),
        beam_freq_map,
        ("beam-nud",),
        lm,
        ("source", "source-comp"),
        parallactic_angles,
        ("time", "ant"),
        point_errors,
        ("time", "ant", "chan", "pt-comp"),
        antenna_scaling,
        ("ant", "chan", "scale-comp"),
        frequencies,
        ("chan",),
        dtype=beam.dtype,
    )
def _zernike_wrapper(
    coords,
    coeffs,
    noll_index,
    parallactic_angle,
    frequency_scaling,
    antenna_scaling,
    pointing_errors,
):
    """Blockwise adapter around :func:`np_zernike_dde`.

    Arguments with an axis contracted in :func:`zernike_dde` arrive as
    lists of blocks, so the first block is extracted:

    * ``coords`` -- its leading "three" axis,
    * ``coeffs`` and ``noll_index`` -- their trailing "poly" axis,
    * ``antenna_scaling`` and ``pointing_errors`` -- their trailing "two" axis.

    ``parallactic_angle`` and ``frequency_scaling`` pass straight through.
    """
    coords_block = coords[0]
    coeffs_block = coeffs[0]
    noll_block = noll_index[0]
    scaling_block = antenna_scaling[0]
    pointing_block = pointing_errors[0]

    return np_zernike_dde(
        coords_block,
        coeffs_block,
        noll_block,
        parallactic_angle,
        frequency_scaling,
        scaling_block,
        pointing_block,
    )
@requires_optional("dask.array", da_import_error)
def zernike_dde(
    coords,
    coeffs,
    noll_index,
    parallactic_angle,
    frequency_scaling,
    antenna_scaling,
    pointing_errors,
):
    """Dask wrapper for zernike_dde.

    Lazily maps :func:`np_zernike_dde` to produce a
    ``(source, time, ant, chan, *corr)`` array with the dtype of
    ``coeffs``.  The number of correlation axes is taken from
    ``coeffs``, which is indexed ``(ant, chan, *corr, poly)``.
    """
    ncorrs = len(coeffs.shape[2:-1])
    corr_dims = tuple("corr-%d" % i for i in range(ncorrs))

    # "three", "poly" and "two" are contracted axes: blockwise passes lists
    # of blocks and the wrapper keeps only the first, so merge each into a
    # single chunk (a no-change for inputs already unchunked along them).
    coords = da.asarray(coords).rechunk({0: -1})
    coeffs = da.asarray(coeffs)
    coeffs = coeffs.rechunk({coeffs.ndim - 1: -1})
    noll_index = da.asarray(noll_index)
    noll_index = noll_index.rechunk({noll_index.ndim - 1: -1})
    antenna_scaling = da.asarray(antenna_scaling).rechunk({2: -1})
    pointing_errors = da.asarray(pointing_errors).rechunk({3: -1})

    return da.core.blockwise(
        _zernike_wrapper,
        ("source", "time", "ant", "chan") + corr_dims,
        coords,
        ("three", "source", "time", "ant", "chan"),
        coeffs,
        ("ant", "chan") + corr_dims + ("poly",),
        noll_index,
        ("ant", "chan") + corr_dims + ("poly",),
        parallactic_angle,
        ("time", "ant"),
        frequency_scaling,
        ("chan",),
        antenna_scaling,
        ("ant", "chan", "two"),
        pointing_errors,
        ("time", "ant", "chan", "two"),
        dtype=coeffs.dtype,
    )
# Give each dask wrapper the documentation of its numpy counterpart, with
# array-type references rewritten to dask arrays.  Each step swallows
# AttributeError -- presumably so a stripped/missing docstring (e.g. under
# ``python -OO`` ``__doc__`` is None) leaves the wrapper undocumented
# rather than breaking the import; TODO confirm.
try:
    phase_delay.__doc__ = PHASE_DELAY_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    parallactic_angles.__doc__ = mod_docs(
        np_parangles.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except AttributeError:
    pass

try:
    # feed_rotation is a dask wrapper, so document dask arrays like the
    # other wrappers in this module do.
    feed_rotation.__doc__ = FEED_ROTATION_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    transform_sources.__doc__ = mod_docs(
        np_transform_sources.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except (AttributeError, NameError):
    # NOTE(review): no dask ``transform_sources`` wrapper is defined in
    # this module as visible (only its helper ``_xform_wrap``); catching
    # NameError keeps the module importable.  Confirm whether the wrapper
    # should be defined here.
    pass

try:
    beam_cube_dde.__doc__ = BEAM_CUBE_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    pass

try:
    zernike_dde.__doc__ = mod_docs(
        np_zernike_dde.__doc__,
        [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")],
    )
except AttributeError:
    pass