Source code for africanus.model.coherency.dask
# -*- coding: utf-8 -*-
from africanus.model.coherency.conversion import (
convert_setup,
convert_impl,
CONVERT_DOCS,
)
from africanus.util.requirements import requires_optional
try:
import dask.array as da
except ImportError as e:
da_import_error = e
else:
da_import_error = None
def convert_wrapper(np_input, mapping=None, in_shape=None, out_shape=None, dtype_=None):
    """Run ``convert_impl`` on one block and pad the result with dummy axes.

    All arguments are forwarded positionally to ``convert_impl``. The array
    it returns is reshaped so that ``len(in_shape)`` trailing axes of length
    one follow its existing axes; no data is added or removed.
    """
    converted = convert_impl(np_input, mapping, in_shape, out_shape, dtype_)
    # One length-1 axis per entry of in_shape, appended after the real axes.
    padded_shape = converted.shape + (1,) * len(in_shape)
    return converted.reshape(padded_shape)
[docs]
@requires_optional("dask.array", da_import_error)
def convert(input, input_schema, output_schema):
    # NOTE: no inline docstring here -- ``convert.__doc__`` is assigned right
    # after this definition from the shared CONVERT_DOCS template.
    setup = convert_setup(input, input_schema, output_schema)
    mapping, in_shape, out_shape, dtype = setup

    # Every axis of ``input`` that is not one of the trailing input
    # correlation axes is carried through unchanged.
    n_leading = len(input.shape) - len(in_shape)
    n_in_corrs = len(in_shape)
    n_out_corrs = len(out_shape)

    # Symbolic index labels handed to blockwise.
    leading_idx = tuple("dim-%d" % k for k in range(n_leading))
    in_corr_idx = tuple("icorr-%d" % k for k in range(n_in_corrs))
    out_corr_idx = tuple("ocorr-%d" % k for k in range(n_out_corrs))

    # The input correlation labels are deliberately repeated at the end of
    # the output index: leaving them out would make blockwise contract over
    # them, and those input dimensions can be arbitrary. convert_wrapper
    # returns length-1 axes in those positions.
    blocked = da.core.blockwise(
        convert_wrapper,
        leading_idx + out_corr_idx + in_corr_idx,
        input,
        leading_idx + in_corr_idx,
        mapping=mapping,
        in_shape=in_shape,
        out_shape=out_shape,
        # The output correlation axes do not exist on the input, so they
        # are declared as new axes with their sizes taken from out_shape.
        new_axes=dict(zip(out_corr_idx, out_shape)),
        dtype_=dtype,
        dtype=dtype,
    )

    # Collapse the trailing dummy axes again by summing over them.
    first_dummy = n_leading + n_out_corrs
    dummy_axes = list(range(first_dummy, first_dummy + n_in_corrs))
    return blocked.sum(axis=dummy_axes)
# Fill in the docstring of ``convert`` from the shared CONVERT_DOCS template,
# naming the dask array class as the array type handled by this variant.
try:
    convert.__doc__ = CONVERT_DOCS.substitute(
        array_type=":class:`dask.array.Array`"
    )
except AttributeError:
    # Best effort only: an AttributeError here leaves ``convert`` without the
    # templated docstring.
    # NOTE(review): presumably this guards against CONVERT_DOCS lacking
    # ``substitute`` or ``__doc__`` not being assignable -- confirm.
    pass