# -*- coding: utf-8 -*-
try:
import dask.array as da
except ImportError as e:
dask_import_error = e
else:
dask_import_error = None
import numpy as np
from africanus.gridding.wgridder.vis2im import DIRTY_DOCS
from africanus.gridding.wgridder.im2vis import MODEL_DOCS
from africanus.gridding.wgridder.im2residim import RESIDUAL_DOCS
from africanus.gridding.wgridder.hessian import HESSIAN_DOCS
from africanus.gridding.wgridder.im2vis import _model_internal as model_np
from africanus.gridding.wgridder.vis2im import _dirty_internal as dirty_np
from africanus.gridding.wgridder.im2residim import _residual_internal as residual_np
from africanus.gridding.wgridder.hessian import _hessian_internal as hessian_np
from africanus.util.requirements import requires_optional
def _model_wrapper(
    uvw,
    freq,
    model,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights,
    flag,
    celly,
    epsilon,
    nthreads,
    do_wstacking,
):
    """Per-block adapter between ``da.blockwise`` and ``model_np``.

    Unwraps the block containers that dask passes for contracted indices
    and forwards everything else unchanged, in the same positional order.
    """
    # "three" does not appear in the blockwise output index, so uvw arrives
    # as a list of blocks along that axis; only the first is used
    # (assumes uvw is a single chunk along its second axis -- TODO confirm).
    uvw_block = uvw[0]
    # Likewise "nx" and "ny" are contracted, giving a nested list of image
    # blocks; only the first is used (assumes the image is unchunked in
    # nx/ny -- TODO confirm).
    image_block = model[0][0]
    forwarded = (
        uvw_block,
        freq,
        image_block,
        freq_bin_idx,
        freq_bin_counts,
        cell,
        weights,
        flag,
        celly,
        epsilon,
        nthreads,
        do_wstacking,
    )
    return model_np(*forwarded)
@requires_optional("dask.array", dask_import_error)
def model(
    uvw,
    freq,
    image,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights=None,
    flag=None,
    celly=None,
    epsilon=1e-5,
    nthreads=1,
    do_wstacking=True,
):
    # The public docstring is attached at module level from MODEL_DOCS.
    #
    # Builds a lazy dask graph that applies ``_model_wrapper`` block by
    # block; nothing is computed until the returned array is evaluated.

    # Output dtype: complex, promoted with the image dtype (complex64 at
    # minimum, complex128 for a float64 image).
    complex_type = da.result_type(image, np.complex64)

    # celly defaults to cell (same cell size in both directions).
    celly = cell if celly is None else celly

    # A falsy nthreads (0 / None) means "use every available CPU".
    if not nthreads:
        import multiprocessing

        nthreads = multiprocessing.cpu_count()

    # Optional arrays are laid out like the (row, chan) output. When absent,
    # None is passed through blockwise with a None index (a plain argument).
    weight_out = None if weights is None else ("row", "chan")
    flag_out = None if flag is None else ("row", "chan")

    # "three", "nx" and "ny" are absent from the output index, so dask
    # contracts them and hands the wrapper lists of blocks along those axes.
    vis = da.blockwise(
        _model_wrapper, ("row", "chan"),
        uvw, ("row", "three"),
        freq, ("chan",),
        image, ("chan", "nx", "ny"),
        freq_bin_idx, ("chan",),
        freq_bin_counts, ("chan",),
        cell, None,
        weights, weight_out,
        flag, flag_out,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        # Output channel chunks follow the frequency chunking.
        adjust_chunks={"chan": freq.chunks[0]},
        dtype=complex_type,
        align_arrays=False,
    )
    return vis
def _dirty_wrapper(
    uvw,
    freq,
    vis,
    freq_bin_idx,
    freq_bin_counts,
    nx,
    ny,
    cell,
    weights,
    flag,
    celly,
    epsilon,
    nthreads,
    do_wstacking,
    double_accum,
):
    """Per-block adapter between ``da.blockwise`` and ``dirty_np``.

    Only ``uvw`` needs unwrapping; all other arguments are forwarded in the
    same positional order.
    """
    # "three" is contracted in the blockwise call, so uvw is a list of
    # blocks along that axis; the first is used (assumes uvw is a single
    # chunk along its second axis -- TODO confirm).
    uvw_block = uvw[0]
    forwarded = (
        uvw_block,
        freq,
        vis,
        freq_bin_idx,
        freq_bin_counts,
        nx,
        ny,
        cell,
        weights,
        flag,
        celly,
        epsilon,
        nthreads,
        do_wstacking,
        double_accum,
    )
    return dirty_np(*forwarded)
@requires_optional("dask.array", dask_import_error)
def dirty(
    uvw,
    freq,
    vis,
    freq_bin_idx,
    freq_bin_counts,
    nx,
    ny,
    cell,
    weights=None,
    flag=None,
    celly=None,
    epsilon=1e-5,
    nthreads=1,
    do_wstacking=True,
    double_accum=False,
):
    # The public docstring is attached at module level from DIRTY_DOCS.
    #
    # Builds a lazy dask graph that applies ``_dirty_wrapper`` per
    # (row, chan) block and sums the per-row-chunk partial images.

    # The output image is real with the precision of the complex
    # visibilities. No input carries that dtype, so derive it here and fail
    # early (instead of leaving real_type unbound) for unsupported dtypes.
    if vis.dtype == np.complex128:
        real_type = np.float64
    elif vis.dtype == np.complex64:
        real_type = np.float32
    else:
        raise ValueError(
            f"vis must be complex64 or complex128, got {vis.dtype}"
        )

    # celly defaults to cell (same cell size in both directions).
    celly = cell if celly is None else celly

    # A falsy nthreads (0 / None) means "use every available CPU".
    if not nthreads:
        import multiprocessing

        nthreads = multiprocessing.cpu_count()

    # Optional arrays are laid out like vis. When absent, None is passed
    # through blockwise with a None index (a plain argument).
    weight_out = None if weights is None else ("row", "chan")
    flag_out = None if flag is None else ("row", "chan")

    img = da.blockwise(
        _dirty_wrapper, ("row", "chan", "nx", "ny"),
        uvw, ("row", "three"),
        freq, ("chan",),
        vis, ("row", "chan"),
        freq_bin_idx, ("chan",),
        freq_bin_counts, ("chan",),
        nx, None,
        ny, None,
        cell, None,
        weights, weight_out,
        flag, flag_out,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        adjust_chunks={
            # Output channel axis is chunked like freq_bin_idx.
            "chan": freq_bin_idx.chunks[0],
            # Each row chunk contributes a single partial image.
            "row": (1,) * len(vis.chunks[0]),
        },
        # nx / ny do not exist on any input, so they are introduced here.
        new_axes={"nx": nx, "ny": ny},
        dtype=real_type,
        align_arrays=False,
    )
    # Accumulate the partial images over row chunks.
    return img.sum(axis=0)
def _residual_wrapper(
    uvw,
    freq,
    model,
    vis,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights,
    flag,
    celly,
    epsilon,
    nthreads,
    do_wstacking,
    double_accum,
):
    """Per-block adapter between ``da.blockwise`` and ``residual_np``.

    Only ``uvw`` needs unwrapping; the image block is passed as-is because
    its nx/ny axes appear in the blockwise output index.
    """
    # "three" is contracted, so uvw arrives as a list of blocks; the first
    # is used (assumes a single chunk along uvw's second axis -- TODO confirm).
    uvw_block = uvw[0]
    forwarded = (
        uvw_block,
        freq,
        model,
        vis,
        freq_bin_idx,
        freq_bin_counts,
        cell,
        weights,
        flag,
        celly,
        epsilon,
        nthreads,
        do_wstacking,
        double_accum,
    )
    return residual_np(*forwarded)
@requires_optional("dask.array", dask_import_error)
def residual(
    uvw,
    freq,
    image,
    vis,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights=None,
    flag=None,
    celly=None,
    epsilon=1e-5,
    nthreads=1,
    do_wstacking=True,
    double_accum=False,
):
    # The public docstring is attached at module level from RESIDUAL_DOCS.
    #
    # Builds a lazy dask graph that applies ``_residual_wrapper`` per
    # (row, chan) block and sums the per-row-chunk partial images.

    # celly defaults to cell (same cell size in both directions).
    celly = cell if celly is None else celly

    # A falsy nthreads (0 / None) means "use every available CPU".
    if not nthreads:
        import multiprocessing

        nthreads = multiprocessing.cpu_count()

    # Optional arrays are laid out like vis. When absent, None is passed
    # through blockwise with a None index (a plain argument).
    weight_out = None if weights is None else ("row", "chan")
    flag_out = None if flag is None else ("row", "chan")

    img = da.blockwise(
        _residual_wrapper, ("row", "chan", "nx", "ny"),
        uvw, ("row", "three"),
        freq, ("chan",),
        image, ("chan", "nx", "ny"),
        vis, ("row", "chan"),
        freq_bin_idx, ("chan",),
        freq_bin_counts, ("chan",),
        cell, None,
        weights, weight_out,
        flag, flag_out,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        adjust_chunks={
            # Output channel axis is chunked like freq_bin_idx.
            "chan": freq_bin_idx.chunks[0],
            # Each row chunk contributes a single partial image.
            "row": (1,) * len(vis.chunks[0]),
        },
        # The residual image has the same dtype as the input image.
        dtype=image.dtype,
        align_arrays=False,
    )
    # Accumulate the partial images over row chunks.
    return img.sum(axis=0)
def _hessian_wrapper(
    uvw,
    freq,
    model,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights,
    flag,
    celly,
    epsilon,
    nthreads,
    do_wstacking,
    double_accum,
):
    """Per-block adapter between ``da.blockwise`` and ``hessian_np``.

    Only ``uvw`` needs unwrapping; the image block is passed as-is because
    its nx/ny axes appear in the blockwise output index.
    """
    # "three" is contracted, so uvw arrives as a list of blocks; the first
    # is used (assumes a single chunk along uvw's second axis -- TODO confirm).
    uvw_block = uvw[0]
    forwarded = (
        uvw_block,
        freq,
        model,
        freq_bin_idx,
        freq_bin_counts,
        cell,
        weights,
        flag,
        celly,
        epsilon,
        nthreads,
        do_wstacking,
        double_accum,
    )
    return hessian_np(*forwarded)
@requires_optional("dask.array", dask_import_error)
def hessian(
    uvw,
    freq,
    image,
    freq_bin_idx,
    freq_bin_counts,
    cell,
    weights=None,
    flag=None,
    celly=None,
    epsilon=1e-5,
    nthreads=1,
    do_wstacking=True,
    double_accum=False,
):
    # The public docstring is attached at module level from HESSIAN_DOCS.
    #
    # Builds a lazy dask graph that applies ``_hessian_wrapper`` per
    # (row, chan) block and sums the per-row-chunk partial images.

    # celly defaults to cell (same cell size in both directions).
    celly = cell if celly is None else celly

    # A falsy nthreads (0 / None) means "use every available CPU".
    if not nthreads:
        import multiprocessing

        nthreads = multiprocessing.cpu_count()

    # Optional arrays use a (row, chan) layout. When absent, None is passed
    # through blockwise with a None index (a plain argument).
    weight_out = None if weights is None else ("row", "chan")
    flag_out = None if flag is None else ("row", "chan")

    img = da.blockwise(
        _hessian_wrapper, ("row", "chan", "nx", "ny"),
        uvw, ("row", "three"),
        freq, ("chan",),
        image, ("chan", "nx", "ny"),
        freq_bin_idx, ("chan",),
        freq_bin_counts, ("chan",),
        cell, None,
        weights, weight_out,
        flag, flag_out,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        adjust_chunks={
            # Output channel axis is chunked like freq_bin_idx.
            "chan": freq_bin_idx.chunks[0],
            # No vis input here, so the row chunking is taken from uvw;
            # each row chunk contributes a single partial image.
            "row": (1,) * len(uvw.chunks[0]),
        },
        # The result has the same dtype as the input image.
        dtype=image.dtype,
        align_arrays=False,
    )
    # Accumulate the partial images over row chunks.
    return img.sum(axis=0)
# Render each shared docstring template with the dask array type so the
# public wrappers document dask arrays as their inputs/outputs.
for _func, _template in (
    (model, MODEL_DOCS),
    (dirty, DIRTY_DOCS),
    (residual, RESIDUAL_DOCS),
    (hessian, HESSIAN_DOCS),
):
    _func.__doc__ = _template.substitute(array_type=":class:`dask.array.Array`")
# Keep the module namespace unchanged by dropping the loop temporaries.
del _func, _template