Source code for africanus.experimental.rime.fused.core

from collections.abc import Mapping
from collections import defaultdict

import numba
from numba import generated_jit, types
import numpy as np

from africanus.util.patterns import Multiton
from africanus.experimental.rime.fused.arguments import ArgumentDependencies
from africanus.experimental.rime.fused.intrinsics import IntrinsicFactory
from africanus.experimental.rime.fused.specification import RimeSpecification


DATASET_TYPES = []

try:
    from daskms.dataset import Dataset as dmsds
except ImportError:
    pass
else:
    DATASET_TYPES.append(dmsds)

try:
    from xarray import Dataset as xrds
except ImportError:
    pass
else:
    DATASET_TYPES.append(xrds)


def rime_impl_factory(terms, transformers, ncorr):
    @generated_jit(nopython=True, nogil=True, cache=True)
    def rime(names, *inargs):
        if len(inargs) != 1 or not isinstance(inargs[0], types.BaseTuple):
            raise TypeError(f"{inargs[0]} must be be a Tuple")

        if not isinstance(names, types.BaseTuple):
            raise TypeError(f"{names} must be a Tuple of strings")

        if len(names) != len(inargs[0]):
            raise ValueError(f"len(names): {len(names)} "
                             f"!= {len(inargs[0])}")

        if not all(isinstance(n, types.Literal) for n in names):
            raise TypeError(f"{names} must be a Tuple of strings")

        if not all(n.literal_type is types.unicode_type for n in names):
            raise TypeError(f"{names} must be a Tuple of strings")

        # Get literal argument names
        names = tuple(n.literal_value for n in names)

        # Generate intrinsics
        argdeps = ArgumentDependencies(names, terms, transformers)
        factory = IntrinsicFactory(argdeps)
        out_names, pack_opts_indices = factory.pack_optionals_and_indices_fn()
        out_names, pack_transformed = factory.pack_transformed_fn(out_names)
        term_state = factory.term_state_fn(out_names)
        term_sampler = factory.term_sampler_fn()

        try:
            lm_i = out_names.index("lm")
            uvw_i = out_names.index("uvw")
            chan_freq_i = out_names.index("chan_freq")
        except ValueError as e:
            raise ValueError(f"{str(e)} is required")

        def impl(names, *inargs):
            args_opt_idx = pack_opts_indices(inargs)
            args = pack_transformed(args_opt_idx)
            state = term_state(args)

            nsrc, _ = args[lm_i].shape
            nrow, _ = args[uvw_i].shape
            nchan, = args[chan_freq_i].shape

            vis = np.zeros((nrow, nchan, ncorr), np.complex128)
            # Kahan summation compensation
            compensation = np.zeros_like(vis)

            for s in range(nsrc):
                for r in range(nrow):
                    t = state.time_index[r]
                    a1 = state.antenna1[r]
                    a2 = state.antenna2[r]
                    f1 = state.feed1[r]
                    f2 = state.feed2[r]

                    for ch in range(nchan):
                        X = term_sampler(state, s, r, t, f1, f2, a1, a2, ch)

                        for c, value in enumerate(numba.literal_unroll(X)):
                            # Kahan summation
                            y = value - compensation[r, ch, c]
                            current = vis[r, ch, c]
                            x = current + y
                            compensation[r, ch, c] = (x - current) - y
                            vis[r, ch, c] = x

            return vis

        return impl

    return rime


class RimeFactory(metaclass=Multiton):
    REQUIRED_ARGS = ArgumentDependencies.REQUIRED_ARGS
    REQUIRED_ARGS_LITERAL = tuple(types.literal(n) for n in REQUIRED_ARGS)
    DEFAULT_SPEC = "(Kpq, Bpq): [I, Q, U, V] -> [XX, XY, YX, YY]"

    def __reduce__(self):
        return (RimeFactory, (self.rime_spec,))

    def __hash__(self):
        return hash(self.rime_spec)

    def __eq__(self, rhs):
        return (isinstance(rhs, RimeFactory) and
                self.rime_spec == rhs.rime_spec)

    def __init__(self, rime_spec=DEFAULT_SPEC):
        if isinstance(rime_spec, RimeSpecification):
            pass
        elif isinstance(rime_spec, (list, tuple)):
            rime_spec = RimeSpecification(*rime_spec)
        elif isinstance(rime_spec, str):
            rime_spec = RimeSpecification(rime_spec)

        self.rime_spec = rime_spec
        self.impl = rime_impl_factory(
            rime_spec.terms,
            rime_spec.transformers,
            len(rime_spec.corrs))

    def dask_blockwise_args(self, **kwargs):
        """ Get the dask schema """
        argdeps = ArgumentDependencies(
            tuple(kwargs.keys()),
            self.rime_spec.terms,
            self.rime_spec.transformers)
        # Holds kwargs + any dummy outputs from transformations
        dummy_kw = kwargs.copy()

        dask_schema = defaultdict(list)
        for a in argdeps.REQUIRED_ARGS:
            dask_schema[a].append(("internal", ("row",)))

        POISON = object()

        for transformer in argdeps.can_create.values():
            kw = {}

            for a in transformer.ARGS:
                v = dummy_kw.get(a, None if a in argdeps.KEY_ARGS else POISON)
                kw[a] = v

            for a, d in transformer.KWARGS.items():
                kw[a] = dummy_kw.get(a, d)

            inputs, outputs = transformer.dask_schema(**kw)

            for k, schema in inputs.items():
                dask_schema[k].append((transformer, schema))

            dummy_kw.update(outputs)

        for term in self.rime_spec.terms:
            kw = {a: dummy_kw[a] for a in term.ALL_ARGS if a in dummy_kw}

            for k, v in term.dask_schema(**kw).items():
                dask_schema[k].append((term, v))

        merged_schema = {}

        for a, candidates in dask_schema.items():
            dims = set(pair[1] for pair in candidates)
            if len(dims) != 1:
                raise ValueError(
                    f"Multiple candidates provided conflicting "
                    f"dimension definitions for {a}: {candidates}.")

            merged_schema[a] = dims.pop()

        names = list(sorted(argdeps.valid_inputs & set(kwargs.keys())))
        blockwise_args = [e for n in names
                          for e in (kwargs[n], merged_schema.get(n, None))]

        assert 2 * len(names) == len(blockwise_args)
        return names, blockwise_args

    def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs):
        keys = (self.REQUIRED_ARGS_LITERAL +
                tuple(map(types.literal, kwargs.keys())))

        return self.impl(keys, time,
                         antenna1, antenna2,
                         feed1, feed2,
                         *kwargs.values())


def consolidate_args(args, kw):
    mapping = {}
    oargs = []

    for element in args:
        if isinstance(element, tuple(DATASET_TYPES)):
            mapping.update((k.lower(), v.data) for k, v in element.items())
        elif isinstance(element, Mapping):
            mapping.update(element)
        else:
            oargs.append(element)

    mapping.update(zip(oargs, RimeFactory.REQUIRED_ARGS))
    mapping.update(kw)

    return mapping


[docs]def rime(rime_spec, *args, **kw): """ Evaluates the Radio Interferometer Measurement Equation (RIME), given the Specification of the RIME :code:`rime_spec`, as well as the inputs to the RIME given in :code:`*args` and :code:`**kwargs`. """ mapping = consolidate_args(args, kw) factory = RimeFactory(rime_spec=rime_spec) return factory(**mapping)