# -*- coding: utf-8 -*-
try:
import dask.array as da
except ImportError as e:
dask_import_error = e
else:
dask_import_error = None
import numpy as np
from africanus.gridding.wgridder.vis2im import DIRTY_DOCS
from africanus.gridding.wgridder.im2vis import MODEL_DOCS
from africanus.gridding.wgridder.im2residim import RESIDUAL_DOCS
from africanus.gridding.wgridder.hessian import HESSIAN_DOCS
from africanus.gridding.wgridder.im2vis import _model_internal as model_np
from africanus.gridding.wgridder.vis2im import _dirty_internal as dirty_np
from africanus.gridding.wgridder.im2residim import (_residual_internal
as residual_np)
from africanus.gridding.wgridder.hessian import (_hessian_internal
as hessian_np)
from africanus.util.requirements import requires_optional
def _model_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts, cell,
                   weights, flag, celly, epsilon, nthreads, do_wstacking):
    """Blockwise adaptor that degrids one (row, chan) block via ``model_np``.

    In the calling ``da.blockwise`` the 'three' axis of ``uvw`` and the
    'nx'/'ny' axes of the image are absent from the output index. They are
    therefore contracted, and dask passes those arguments as (nested) lists of
    blocks. The first block along each contracted axis is unpacked here before
    delegating to the NumPy implementation.
    """
    # NOTE(review): only the first block along each contracted axis is used,
    # so uvw and the image are assumed to be unchunked along those axes.
    uvw_block = uvw[0]
    image_block = model[0][0]
    return model_np(uvw_block, freq, image_block, freq_bin_idx,
                    freq_bin_counts, cell, weights, flag, celly, epsilon,
                    nthreads, do_wstacking)
@requires_optional('dask.array', dask_import_error)
def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell,
          weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
          do_wstacking=True):
    # The public docstring is attached at module level from MODEL_DOCS.

    # Output visibilities are complex, with precision promoted from the image.
    complex_type = da.result_type(image, np.complex64)

    if celly is None:
        celly = cell

    # A falsy nthreads (0 / None) means "use every available core".
    if not nthreads:
        import multiprocessing
        nthreads = multiprocessing.cpu_count()

    # Optional per-visibility arrays share the (row, chan) layout of the
    # output. When absent they are passed with a None index, so blockwise
    # forwards them as plain constants.
    weight_out = None if weights is None else ('row', 'chan')
    flag_out = None if flag is None else ('row', 'chan')

    # 'three', 'nx' and 'ny' are not in the output index, so they are
    # contracted (see _model_wrapper). freq and freq_bin_idx share the 'chan'
    # index but are chunked differently (channels vs. bands). Hence
    # align_arrays=False, and the output channel chunking is set explicitly
    # from freq.
    vis = da.blockwise(_model_wrapper, ('row', 'chan'),
                       uvw, ('row', 'three'),
                       freq, ('chan',),
                       image, ('chan', 'nx', 'ny'),
                       freq_bin_idx, ('chan',),
                       freq_bin_counts, ('chan',),
                       cell, None,
                       weights, weight_out,
                       flag, flag_out,
                       celly, None,
                       epsilon, None,
                       nthreads, None,
                       do_wstacking, None,
                       adjust_chunks={'chan': freq.chunks[0]},
                       dtype=complex_type,
                       align_arrays=False)
    return vis
def _dirty_wrapper(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny,
                   cell, weights, flag, celly, epsilon, nthreads,
                   do_wstacking, double_accum):
    """Blockwise adaptor that grids one (row, chan) block via ``dirty_np``.

    ``uvw`` arrives as a list of blocks, because its 'three' axis is
    contracted in the calling ``da.blockwise``; the first block is used.
    The calling blockwise declares a 4-d ``('row', 'chan', 'nx', 'ny')``
    output with a row extent of 1 per row chunk. A leading singleton row axis
    is therefore prepended to the image from ``dirty_np``, so the per-row-chunk
    partial images can be stacked and summed over axis 0.
    """
    image = dirty_np(uvw[0], freq, vis, freq_bin_idx, freq_bin_counts,
                     nx, ny, cell, weights, flag, celly, epsilon,
                     nthreads, do_wstacking, double_accum)
    # NOTE(review): assumes dirty_np returns a 3-d (band, nx, ny) image,
    # matching the ('chan', 'nx', 'ny') tail of the declared output.
    return image[None]
@requires_optional('dask.array', dask_import_error)
def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell,
          weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
          do_wstacking=True, double_accum=False):
    # The public docstring is attached at module level from DIRTY_DOCS.

    # The image is real-valued. Its precision is derived from vis, because
    # the real output type is not available from any input directly.
    if vis.dtype == np.complex128:
        real_type = np.float64
    elif vis.dtype == np.complex64:
        real_type = np.float32
    else:
        # Without this branch real_type would be unbound further down.
        raise ValueError("vis must be of type complex64 or complex128, "
                         "got {}".format(vis.dtype))

    if celly is None:
        celly = cell

    # A falsy nthreads (0 / None) means "use every available core".
    if not nthreads:
        import multiprocessing
        nthreads = multiprocessing.cpu_count()

    # Optional per-visibility arrays share the (row, chan) layout of vis.
    # When absent they are passed with a None index, so blockwise forwards
    # them as plain constants.
    weight_out = None if weights is None else ('row', 'chan')
    flag_out = None if flag is None else ('row', 'chan')

    # Each (row chunk, band chunk) produces one partial image. Output
    # 'chan' chunks follow the band chunking of freq_bin_idx, each row chunk
    # collapses to length 1, and nx/ny are new axes. The partial images are
    # then summed over the row axis.
    img = da.blockwise(_dirty_wrapper, ('row', 'chan', 'nx', 'ny'),
                       uvw, ('row', 'three'),
                       freq, ('chan',),
                       vis, ('row', 'chan'),
                       freq_bin_idx, ('chan',),
                       freq_bin_counts, ('chan',),
                       nx, None,
                       ny, None,
                       cell, None,
                       weights, weight_out,
                       flag, flag_out,
                       celly, None,
                       epsilon, None,
                       nthreads, None,
                       do_wstacking, None,
                       double_accum, None,
                       adjust_chunks={'chan': freq_bin_idx.chunks[0],
                                      'row': (1,) * len(vis.chunks[0])},
                       new_axes={"nx": nx, "ny": ny},
                       dtype=real_type,
                       align_arrays=False)
    return img.sum(axis=0)
def _residual_wrapper(uvw, freq, model, vis, freq_bin_idx, freq_bin_counts,
                      cell, weights, flag, celly, epsilon, nthreads,
                      do_wstacking, double_accum):
    """Blockwise adaptor computing one residual-image block via ``residual_np``.

    ``uvw`` arrives as a list of blocks, because its 'three' axis is
    contracted in the calling ``da.blockwise``; the first block is used.
    The calling blockwise declares a 4-d ``('row', 'chan', 'nx', 'ny')``
    output with a row extent of 1 per row chunk. A leading singleton row axis
    is therefore prepended to the result, so the per-row-chunk partial images
    can be stacked and summed over axis 0.
    """
    image = residual_np(uvw[0], freq, model, vis, freq_bin_idx,
                        freq_bin_counts, cell, weights, flag, celly, epsilon,
                        nthreads, do_wstacking, double_accum)
    # NOTE(review): assumes residual_np returns a 3-d (band, nx, ny) image,
    # the same layout as the ('chan', 'nx', 'ny') model block it receives.
    return image[None]
@requires_optional('dask.array', dask_import_error)
def residual(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, cell,
             weights=None, flag=None, celly=None, epsilon=1e-5,
             nthreads=1, do_wstacking=True, double_accum=False):
    # The public docstring is attached at module level from RESIDUAL_DOCS.
    if celly is None:
        celly = cell

    # A falsy nthreads (0 / None) means "use every available core".
    if not nthreads:
        import multiprocessing
        nthreads = multiprocessing.cpu_count()

    # Optional per-visibility arrays share the (row, chan) layout of vis.
    # When absent they are passed with a None index, so blockwise forwards
    # them as plain constants.
    weight_out = None if weights is None else ('row', 'chan')
    flag_out = None if flag is None else ('row', 'chan')

    # Each (row chunk, band chunk) produces one partial residual image.
    # Output 'chan' chunks follow the band chunking of freq_bin_idx, each row
    # chunk collapses to length 1, and nx/ny come from the image. The
    # partial images are then summed over the row axis.
    img = da.blockwise(_residual_wrapper, ('row', 'chan', 'nx', 'ny'),
                       uvw, ('row', 'three'),
                       freq, ('chan',),
                       image, ('chan', 'nx', 'ny'),
                       vis, ('row', 'chan'),
                       freq_bin_idx, ('chan',),
                       freq_bin_counts, ('chan',),
                       cell, None,
                       weights, weight_out,
                       flag, flag_out,
                       celly, None,
                       epsilon, None,
                       nthreads, None,
                       do_wstacking, None,
                       double_accum, None,
                       adjust_chunks={'chan': freq_bin_idx.chunks[0],
                                      'row': (1,) * len(vis.chunks[0])},
                       dtype=image.dtype,
                       align_arrays=False)
    return img.sum(axis=0)
def _hessian_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                     cell, weights, flag, celly, epsilon, nthreads,
                     do_wstacking, double_accum):
    """Blockwise adaptor applying the Hessian to one block via ``hessian_np``.

    ``uvw`` arrives as a list of blocks, because its 'three' axis is
    contracted in the calling ``da.blockwise``; the first block is used.
    The calling blockwise declares a 4-d ``('row', 'chan', 'nx', 'ny')``
    output with a row extent of 1 per row chunk. A leading singleton row axis
    is therefore prepended to the result, so the per-row-chunk partial images
    can be stacked and summed over axis 0.
    """
    image = hessian_np(uvw[0], freq, model, freq_bin_idx,
                       freq_bin_counts, cell, weights, flag, celly, epsilon,
                       nthreads, do_wstacking, double_accum)
    # NOTE(review): assumes hessian_np returns a 3-d (band, nx, ny) image,
    # the same layout as the ('chan', 'nx', 'ny') model block it receives.
    return image[None]
@requires_optional('dask.array', dask_import_error)
def hessian(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell,
            weights=None, flag=None, celly=None, epsilon=1e-5,
            nthreads=1, do_wstacking=True, double_accum=False):
    # The public docstring is attached at module level from HESSIAN_DOCS.
    if celly is None:
        celly = cell

    # A falsy nthreads (0 / None) means "use every available core".
    if not nthreads:
        import multiprocessing
        nthreads = multiprocessing.cpu_count()

    # Optional per-visibility arrays use a (row, chan) layout. When absent
    # they are passed with a None index, so blockwise forwards them as plain
    # constants.
    weight_out = None if weights is None else ('row', 'chan')
    flag_out = None if flag is None else ('row', 'chan')

    # Each (row chunk, band chunk) produces one partial image. Output 'chan'
    # chunks follow the band chunking of freq_bin_idx, and each row chunk of
    # uvw collapses to length 1 (there is no vis input here). The partial
    # images are then summed over the row axis.
    img = da.blockwise(_hessian_wrapper, ('row', 'chan', 'nx', 'ny'),
                       uvw, ('row', 'three'),
                       freq, ('chan',),
                       image, ('chan', 'nx', 'ny'),
                       freq_bin_idx, ('chan',),
                       freq_bin_counts, ('chan',),
                       cell, None,
                       weights, weight_out,
                       flag, flag_out,
                       celly, None,
                       epsilon, None,
                       nthreads, None,
                       do_wstacking, None,
                       double_accum, None,
                       adjust_chunks={'chan': freq_bin_idx.chunks[0],
                                      'row': (1,) * len(uvw.chunks[0])},
                       dtype=image.dtype,
                       align_arrays=False)
    return img.sum(axis=0)
# Render each shared docstring template for dask array inputs and attach it
# to the corresponding public function, in the original order.
for _fn, _docs in ((model, MODEL_DOCS),
                   (dirty, DIRTY_DOCS),
                   (residual, RESIDUAL_DOCS),
                   (hessian, HESSIAN_DOCS)):
    _fn.__doc__ = _docs.substitute(
        array_type=":class:`dask.array.Array`")
# Keep the module namespace identical to the unrolled form.
del _fn, _docs