Source code for africanus.experimental.rime.fused.core

from collections.abc import Mapping
from collections import defaultdict

import numba
from numba import types
import numpy as np

from africanus.util.patterns import Multiton
from africanus.util.numba import overload, njit, JIT_OPTIONS
from africanus.experimental.rime.fused.arguments import ArgumentDependencies
from africanus.experimental.rime.fused.intrinsics import IntrinsicFactory
from africanus.experimental.rime.fused.specification import RimeSpecification


DATASET_TYPES = []

try:
    from daskms.dataset import Dataset as dmsds
except ImportError:
    pass
else:
    DATASET_TYPES.append(dmsds)

try:
    from xarray import Dataset as xrds
except ImportError:
    pass
else:
    DATASET_TYPES.append(xrds)


def rime_impl_factory(terms, transformers, ncorr):
    @njit(**JIT_OPTIONS)
    def rime(*args):
        return rime_impl(*args)

    def rime_impl(*args):
        raise NotImplementedError

    @overload(rime_impl, jit_options=JIT_OPTIONS, prefer_literal=True)
    def nb_rime(*args):
        if not len(args) > 0:
            raise TypeError(
                "rime must be at least be called with the signature argument"
            )

        if not isinstance(args[0], types.Literal):
            raise TypeError(f"Signature hash ({args[0]}) must be a literal")

        if not len(args) % 2 == 1:
            raise TypeError(
                f"Length of named arguments {len(args)} " f"is not divisible by 2"
            )

        argstart = 1 + (len(args) - 1) // 2
        names = args[1:argstart]

        if not all(isinstance(n, types.Literal) for n in names):
            raise TypeError(f"{names} must be a Tuple of Literal strings")

        if not all(n.literal_type is types.unicode_type for n in names):
            raise TypeError(f"{names} must be a Tuple of Literal strings")

        # Get literal argument names
        names = tuple(n.literal_value for n in names)

        # Generate intrinsics
        argdeps = ArgumentDependencies(names, terms, transformers)
        factory = IntrinsicFactory(argdeps)
        out_names, pack_opts_indices = factory.pack_optionals_and_indices_fn()
        out_names, pack_transformed = factory.pack_transformed_fn(out_names)
        term_state = factory.term_state_fn(out_names)
        term_sampler = factory.term_sampler_fn()

        try:
            lm_i = out_names.index("lm")
            uvw_i = out_names.index("uvw")
            chan_freq_i = out_names.index("chan_freq")
        except ValueError as e:
            raise ValueError(f"{str(e)} is required")

        def impl(*args):
            args_opt_idx = pack_opts_indices(args[argstart:])
            args = pack_transformed(args_opt_idx)
            state = term_state(args)

            nsrc, _ = args[lm_i].shape
            nrow, _ = args[uvw_i].shape
            (nchan,) = args[chan_freq_i].shape

            vis = np.zeros((nrow, nchan, ncorr), np.complex128)
            # Kahan summation compensation
            compensation = np.zeros_like(vis)

            for s in range(nsrc):
                for r in range(nrow):
                    t = state.time_inverse[r]
                    a1 = state.antenna1[r]
                    a2 = state.antenna2[r]
                    f1 = state.feed1[r]
                    f2 = state.feed2[r]

                    for ch in range(nchan):
                        X = term_sampler(state, s, r, t, f1, f2, a1, a2, ch)

                        for co, value in enumerate(numba.literal_unroll(X)):
                            # Kahan summation
                            y = value - compensation[r, ch, co]
                            current = vis[r, ch, co]
                            x = current + y
                            compensation[r, ch, co] = (x - current) - y
                            vis[r, ch, co] = x

            return vis

        return impl

    return rime


class RimeFactory(metaclass=Multiton):
    REQUIRED_ARGS = ArgumentDependencies.REQUIRED_ARGS
    REQUIRED_ARGS_LITERAL = tuple(types.literal(n) for n in REQUIRED_ARGS)
    DEFAULT_SPEC = "(Kpq, Bpq): [I, Q, U, V] -> [XX, XY, YX, YY]"

    def __reduce__(self):
        return (RimeFactory, (self.rime_spec,))

    def __hash__(self):
        return hash(self.rime_spec)

    def __eq__(self, rhs):
        return isinstance(rhs, RimeFactory) and self.rime_spec == rhs.rime_spec

    def __init__(self, rime_spec=DEFAULT_SPEC):
        if isinstance(rime_spec, RimeSpecification):
            pass
        elif isinstance(rime_spec, (list, tuple)):
            rime_spec = RimeSpecification(*rime_spec)
        elif isinstance(rime_spec, str):
            rime_spec = RimeSpecification(rime_spec)

        self.rime_spec = rime_spec
        self.impl = rime_impl_factory(
            rime_spec.terms, rime_spec.transformers, len(rime_spec.corrs)
        )

    def dask_blockwise_args(self, **kwargs):
        """Get the dask schema"""
        argdeps = ArgumentDependencies(
            tuple(kwargs.keys()), self.rime_spec.terms, self.rime_spec.transformers
        )
        # Holds kwargs + any dummy outputs from transformations
        dummy_kw = kwargs.copy()

        dask_schema = defaultdict(list)
        for a in argdeps.REQUIRED_ARGS:
            dask_schema[a].append(("internal", ("row",)))

        POISON = object()

        for transformer in argdeps.can_create.values():
            kw = {}

            for a in transformer.ARGS:
                v = dummy_kw.get(a, None if a in argdeps.KEY_ARGS else POISON)
                kw[a] = v

            for a, d in transformer.KWARGS.items():
                kw[a] = dummy_kw.get(a, d)

            inputs, outputs = transformer.dask_schema(**kw)

            for k, schema in inputs.items():
                dask_schema[k].append((transformer, schema))

            dummy_kw.update(outputs)

        for term in self.rime_spec.terms:
            kw = {a: dummy_kw[a] for a in term.ALL_ARGS if a in dummy_kw}

            for k, v in term.dask_schema(**kw).items():
                dask_schema[k].append((term, v))

        merged_schema = {}

        for a, candidates in dask_schema.items():
            dims = set(pair[1] for pair in candidates)
            if len(dims) != 1:
                raise ValueError(
                    f"Multiple candidates provided conflicting "
                    f"dimension definitions for {a}: {candidates}."
                )

            merged_schema[a] = dims.pop()

        names = list(sorted(argdeps.valid_inputs & set(kwargs.keys())))
        blockwise_args = [
            e for n in names for e in (kwargs[n], merged_schema.get(n, None))
        ]

        assert 2 * len(names) == len(blockwise_args)
        return names, blockwise_args

    def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs):
        keys = self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys()))
        args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values())
        return self.impl(types.literal(self.rime_spec.spec_hash), *args)


def consolidate_args(args, kw):
    mapping = {}
    oargs = []

    for element in args:
        if isinstance(element, tuple(DATASET_TYPES)):
            mapping.update((k.lower(), v.data) for k, v in element.items())
        elif isinstance(element, Mapping):
            mapping.update(element)
        else:
            oargs.append(element)

    mapping.update(zip(oargs, RimeFactory.REQUIRED_ARGS))
    mapping.update(kw)

    return mapping


[docs] def rime(rime_spec, *args, **kw): """ Evaluates the Radio Interferometer Measurement Equation (RIME), given the Specification of the RIME :code:`rime_spec`, as well as the inputs to the RIME given in :code:`*args` and :code:`**kwargs`. """ mapping = consolidate_args(args, kw) factory = RimeFactory(rime_spec=rime_spec) return factory(**mapping)