Source code for africanus.gridding.wgridder.dask

# -*- coding: utf-8 -*-

try:
    import dask.array as da
except ImportError as e:
    dask_import_error = e
else:
    dask_import_error = None

import numpy as np
from africanus.gridding.wgridder.vis2im import DIRTY_DOCS
from africanus.gridding.wgridder.im2vis import MODEL_DOCS
from africanus.gridding.wgridder.im2residim import RESIDUAL_DOCS
from africanus.gridding.wgridder.hessian import HESSIAN_DOCS
from africanus.gridding.wgridder.im2vis import _model_internal as model_np
from africanus.gridding.wgridder.vis2im import _dirty_internal as dirty_np
from africanus.gridding.wgridder.im2residim import (_residual_internal
                                                    as residual_np)
from africanus.gridding.wgridder.hessian import (_hessian_internal
                                                 as hessian_np)
from africanus.util.requirements import requires_optional


def _model_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts, cell,
                   weights, flag, celly, epsilon, nthreads, do_wstacking):
    """Per-block adapter between ``da.blockwise`` and ``model_np``.

    ``uvw`` and ``model`` are indexed before being forwarded: ``uvw[0]``
    and ``model[0][0]`` are what ``model_np`` receives. Every other
    argument is handed over untouched, in the same order, and the value
    returned by ``model_np`` is returned unchanged.
    """
    # NOTE(review): the indexing presumably strips the list nesting that
    # blockwise adds for indices missing from the output ('three' for uvw,
    # 'nx' and 'ny' for the image) -- confirm against the dask docs.
    uvw_block = uvw[0]
    image_block = model[0][0]

    return model_np(uvw_block, freq, image_block, freq_bin_idx,
                    freq_bin_counts, cell, weights, flag, celly, epsilon,
                    nthreads, do_wstacking)


@requires_optional('dask.array', dask_import_error)
def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell,
          weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
          do_wstacking=True):
    # No inline docstring: ``model.__doc__`` is assigned from MODEL_DOCS
    # further down in this module.

    # Output dtype is whatever ``image`` promotes to alongside complex64.
    out_dtype = da.result_type(image, np.complex64)

    # A missing y cell size falls back to the x cell size.
    celly = cell if celly is None else celly

    # Any falsy ``nthreads`` (e.g. 0 or None) is replaced by the CPU count.
    if not nthreads:
        from multiprocessing import cpu_count
        nthreads = cpu_count()

    # The optional arrays only receive ('row', 'chan') indices when they
    # are actually supplied; otherwise they are paired with a None index.
    row_chan = ('row', 'chan')
    weights_idx = None if weights is None else row_chan
    flag_idx = None if flag is None else row_chan

    # One (operand, index) pair per line, in the order _model_wrapper
    # expects its positional arguments.
    return da.blockwise(
        _model_wrapper, row_chan,
        uvw, ('row', 'three'),
        freq, ('chan',),
        image, ('chan', 'nx', 'ny'),
        freq_bin_idx, ('chan',),
        freq_bin_counts, ('chan',),
        cell, None,
        weights, weights_idx,
        flag, flag_idx,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        # Output channel chunks are declared to match those of ``freq``.
        adjust_chunks={'chan': freq.chunks[0]},
        dtype=out_dtype,
        align_arrays=False)
def _dirty_wrapper(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny,
                   cell, weights, flag, celly, epsilon, nthreads,
                   do_wstacking, double_accum):
    """Per-block adapter between ``da.blockwise`` and ``dirty_np``.

    Only ``uvw`` is modified on the way through: its first element is
    forwarded. All remaining arguments reach ``dirty_np`` unchanged and in
    the same order, and its return value is returned as is.
    """
    # NOTE(review): uvw[0] presumably removes the list wrapping blockwise
    # adds for the 'three' index, which is absent from the output indices
    # -- confirm against the dask docs.
    uvw_block = uvw[0]

    return dirty_np(uvw_block, freq, vis, freq_bin_idx, freq_bin_counts,
                    nx, ny, cell, weights, flag, celly, epsilon, nthreads,
                    do_wstacking, double_accum)
@requires_optional('dask.array', dask_import_error)
def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell,
          weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
          do_wstacking=True, double_accum=False):
    # No inline docstring: ``dirty.__doc__`` is assigned from DIRTY_DOCS
    # further down in this module.
    #
    # Raises TypeError when ``vis`` is neither complex64 nor complex128.

    # The real output dtype is derived from the complex dtype of ``vis``.
    # Only the two complex dtypes below are mapped; anything else used to
    # fall through and fail later with an opaque UnboundLocalError on
    # ``real_type``, so it is rejected explicitly here instead.
    if vis.dtype == np.complex128:
        real_type = np.float64
    elif vis.dtype == np.complex64:
        real_type = np.float32
    else:
        raise TypeError(
            "vis must have dtype complex64 or complex128, got %s"
            % (vis.dtype,))

    # A missing y cell size falls back to the x cell size.
    if celly is None:
        celly = cell

    # Any falsy ``nthreads`` (e.g. 0 or None) is replaced by the CPU count.
    if not nthreads:
        import multiprocessing
        nthreads = multiprocessing.cpu_count()

    # The optional arrays only receive ('row', 'chan') indices when they
    # are actually supplied; otherwise they are paired with a None index.
    weight_out = None if weights is None else ('row', 'chan')
    flag_out = None if flag is None else ('row', 'chan')

    img = da.blockwise(
        _dirty_wrapper, ('row', 'chan', 'nx', 'ny'),
        uvw, ('row', 'three'),
        freq, ('chan',),
        vis, ('row', 'chan'),
        freq_bin_idx, ('chan',),
        freq_bin_counts, ('chan',),
        nx, None,
        ny, None,
        cell, None,
        weights, weight_out,
        flag, flag_out,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        # Output channel chunks follow ``freq_bin_idx``; every row chunk of
        # ``vis`` is declared to produce a block of length 1 along 'row'.
        adjust_chunks={'chan': freq_bin_idx.chunks[0],
                       'row': (1,)*len(vis.chunks[0])},
        # 'nx' and 'ny' appear in no input, so they are created here.
        new_axes={"nx": nx, "ny": ny},
        dtype=real_type,
        align_arrays=False)

    # Collapse the leading 'row' axis by summing the per-row-chunk blocks.
    return img.sum(axis=0)
def _residual_wrapper(uvw, freq, model, vis, freq_bin_idx, freq_bin_counts,
                      cell, weights, flag, celly, epsilon, nthreads,
                      do_wstacking, double_accum):
    """Per-block adapter between ``da.blockwise`` and ``residual_np``.

    Only ``uvw`` is modified on the way through: its first element is
    forwarded. All remaining arguments reach ``residual_np`` unchanged and
    in the same order, and its return value is returned as is.
    """
    # NOTE(review): uvw[0] presumably removes the list wrapping blockwise
    # adds for the 'three' index, which is absent from the output indices
    # -- confirm against the dask docs.
    uvw_block = uvw[0]

    return residual_np(uvw_block, freq, model, vis, freq_bin_idx,
                       freq_bin_counts, cell, weights, flag, celly,
                       epsilon, nthreads, do_wstacking, double_accum)
@requires_optional('dask.array', dask_import_error)
def residual(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, cell,
             weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
             do_wstacking=True, double_accum=False):
    # No inline docstring: ``residual.__doc__`` is assigned from
    # RESIDUAL_DOCS further down in this module.

    # A missing y cell size falls back to the x cell size.
    celly = cell if celly is None else celly

    # Any falsy ``nthreads`` (e.g. 0 or None) is replaced by the CPU count.
    if not nthreads:
        from multiprocessing import cpu_count
        nthreads = cpu_count()

    # The optional arrays only receive ('row', 'chan') indices when they
    # are actually supplied; otherwise they are paired with a None index.
    row_chan = ('row', 'chan')
    weights_idx = None if weights is None else row_chan
    flag_idx = None if flag is None else row_chan

    partial_images = da.blockwise(
        _residual_wrapper, ('row', 'chan', 'nx', 'ny'),
        uvw, ('row', 'three'),
        freq, ('chan',),
        image, ('chan', 'nx', 'ny'),
        vis, row_chan,
        freq_bin_idx, ('chan',),
        freq_bin_counts, ('chan',),
        cell, None,
        weights, weights_idx,
        flag, flag_idx,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        # Output channel chunks follow ``freq_bin_idx``; every row chunk of
        # ``vis`` is declared to produce a block of length 1 along 'row'.
        adjust_chunks={'chan': freq_bin_idx.chunks[0],
                       'row': (1,)*len(vis.chunks[0])},
        dtype=image.dtype,
        align_arrays=False)

    # Collapse the leading 'row' axis by summing the per-row-chunk blocks.
    return partial_images.sum(axis=0)
def _hessian_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts, cell,
                     weights, flag, celly, epsilon, nthreads, do_wstacking,
                     double_accum):
    """Per-block adapter between ``da.blockwise`` and ``hessian_np``.

    Only ``uvw`` is modified on the way through: its first element is
    forwarded. All remaining arguments reach ``hessian_np`` unchanged and
    in the same order, and its return value is returned as is.
    """
    # NOTE(review): uvw[0] presumably removes the list wrapping blockwise
    # adds for the 'three' index, which is absent from the output indices
    # -- confirm against the dask docs.
    uvw_block = uvw[0]

    return hessian_np(uvw_block, freq, model, freq_bin_idx,
                      freq_bin_counts, cell, weights, flag, celly, epsilon,
                      nthreads, do_wstacking, double_accum)
@requires_optional('dask.array', dask_import_error)
def hessian(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell,
            weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1,
            do_wstacking=True, double_accum=False):
    # No inline docstring: ``hessian.__doc__`` is assigned from
    # HESSIAN_DOCS further down in this module.

    # A missing y cell size falls back to the x cell size.
    celly = cell if celly is None else celly

    # Any falsy ``nthreads`` (e.g. 0 or None) is replaced by the CPU count.
    if not nthreads:
        from multiprocessing import cpu_count
        nthreads = cpu_count()

    # The optional arrays only receive ('row', 'chan') indices when they
    # are actually supplied; otherwise they are paired with a None index.
    row_chan = ('row', 'chan')
    weights_idx = None if weights is None else row_chan
    flag_idx = None if flag is None else row_chan

    partial_images = da.blockwise(
        _hessian_wrapper, ('row', 'chan', 'nx', 'ny'),
        uvw, ('row', 'three'),
        freq, ('chan',),
        image, ('chan', 'nx', 'ny'),
        freq_bin_idx, ('chan',),
        freq_bin_counts, ('chan',),
        cell, None,
        weights, weights_idx,
        flag, flag_idx,
        celly, None,
        epsilon, None,
        nthreads, None,
        do_wstacking, None,
        double_accum, None,
        # Output channel chunks follow ``freq_bin_idx``; every row chunk of
        # ``uvw`` is declared to produce a block of length 1 along 'row'.
        adjust_chunks={'chan': freq_bin_idx.chunks[0],
                       'row': (1,)*len(uvw.chunks[0])},
        dtype=image.dtype,
        align_arrays=False)

    # Collapse the leading 'row' axis by summing the per-row-chunk blocks.
    return partial_images.sum(axis=0)
# Attach the public docstrings: each shared template is rendered with the
# dask array type filled in and assigned to the matching function, in the
# order model, dirty, residual, hessian.
for _func, _template in ((model, MODEL_DOCS),
                         (dirty, DIRTY_DOCS),
                         (residual, RESIDUAL_DOCS),
                         (hessian, HESSIAN_DOCS)):
    _func.__doc__ = _template.substitute(
        array_type=":class:`dask.array.Array`")

# Drop the loop names so the module namespace is left as it was.
del _func, _template