Source code for africanus.model.spi.dask

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import wraps

from africanus.model.spi.component_spi import SPI_DOCSTRING
from africanus.model.spi.component_spi import (
                                fit_spi_components as np_fit_spi_components)

from africanus.util.requirements import requires_optional

import numpy as np

# dask is treated as optional here: an ImportError is captured instead of
# propagating at module import time. The captured value (or None on
# success) is later passed to the ``requires_optional`` decorator applied
# to ``fit_spi_components``.
try:
    from dask.array.core import blockwise
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None


@wraps(np_fit_spi_components)
def _fit_spi_components_wrapper(data, weights, freqs, freq0,
                                alphai, I0i, tol_, maxiter_,
                                dtype_):
    # Per-block adapter handed to ``blockwise`` by ``fit_spi_components``.
    # (No docstring here: ``wraps`` replaces ``__doc__`` with the one from
    # ``np_fit_spi_components``.)
    #
    # ``data``, ``weights`` and ``freqs`` are indexed with "chan", which is
    # absent from the output indices ("vars", "comps"). blockwise therefore
    # delivers each of them as a list of blocks along "chan"; only the first
    # block is used, so the caller is expected to supply a single chunk
    # along that dimension.
    #
    # ``alphai`` and ``I0i`` are indexed with "comps" only, which IS an
    # output index, so blockwise delivers them as plain array blocks (or as
    # the literal ``None`` when not supplied). They must be forwarded
    # untouched: indexing with ``[0]`` would pick out the first component's
    # initial guess rather than unwrap a list.
    return np_fit_spi_components(data[0],
                                 weights[0],
                                 freqs[0],
                                 freq0,
                                 alphai,
                                 I0i,
                                 tol=tol_,
                                 maxiter=maxiter_,
                                 dtype=dtype_)


@requires_optional('dask.array', opt_import_error)
def fit_spi_components(data, weights, freqs, freq0,
                       alphai=None, I0i=None,
                       tol=1e-6, maxiter=100,
                       dtype=np.float64):
    """ Dask wrapper fit_spi_components function """
    # NOTE: the docstring above is a placeholder; ``__doc__`` is overwritten
    # from ``SPI_DOCSTRING`` immediately after the definition.
    #
    # blockwise index layout:
    #   output            -> ("vars", "comps"), with "vars" a new axis of
    #                        length 4
    #   data              -> ("comps", "chan")
    #   weights, freqs    -> ("chan",)
    #   alphai, I0i       -> ("comps",) when given, otherwise passed as the
    #                        literal None with a None index
    #   freq0, tol, maxiter, dtype -> passed through as literals (None index)
    # "chan" does not appear in the output, so it is contracted and the
    # per-block wrapper receives lists of blocks along that dimension.
    return blockwise(_fit_spi_components_wrapper, ("vars", "comps"),
                     data, ("comps", "chan"),
                     weights, ("chan",),
                     freqs, ("chan",),
                     freq0, None,
                     alphai, ("comps",) if alphai is not None else None,
                     I0i, ("comps",) if I0i is not None else None,
                     tol, None,
                     maxiter, None,
                     dtype, None,
                     new_axes={"vars": 4},
                     dtype=dtype)


# Public documentation is generated from the shared template so that it
# names the dask array type.
fit_spi_components.__doc__ = SPI_DOCSTRING.substitute(
    array_type=":class:`dask.array.Array`")