# Source code for africanus.experimental.rime.fused.terms.core

import inspect
from functools import partial

from numba.experimental import structref
from numba.core import types

from africanus.experimental.rime.fused.error import InvalidSignature


@structref.register
class StateStructRef(types.StructRef):
    """
    Numba ``StructRef`` type registered for term state.

    Field types handed to this type are stripped of any literal
    information before use (see :meth:`preprocess_fields`).
    """

    def preprocess_fields(self, fields):
        """
        Replace literal field types with their non-literal equivalents.

        Parameters
        ----------
        fields : iterable of (str, numba type)
            ``(name, type)`` pairs describing the struct fields.

        Returns
        -------
        tuple of (str, numba type)
            The same pairs, with each type passed through
            ``types.unliteral``.
        """
        strip_literal = types.unliteral
        cleaned = [(field_name, strip_literal(field_type)) for field_name, field_type in fields]
        return tuple(cleaned)


def sigcheck_factory(expected_sig):
    """
    Create a checker that compares a callable's signature to ``expected_sig``.

    The returned function takes ``(self, fn)`` so that it can be installed
    as a method on a class.

    Parameters
    ----------
    expected_sig : inspect.Signature
        The signature that checked callables must have.

    Returns
    -------
    callable
        ``check_constructor_signature(self, fn)``. It returns ``None`` when
        ``inspect.signature(fn) == expected_sig``, and raises ``ValueError``
        describing the mismatch otherwise.
    """

    def check_constructor_signature(self, fn):
        actual_sig = inspect.signature(fn)

        if actual_sig == expected_sig:
            return

        fn_name = fn.__name__
        raise ValueError(f"{fn_name}{actual_sig} should be {fn_name}{expected_sig}")

    return check_constructor_signature


class TermMetaClass(type):
    """
    Metaclass which checks that the appropriate methods are
    implemented on any subclass of `Term` and that their
    signatures agree with each other.

    Also sets `ARGS`, `KWARGS`, `ALL_ARGS` and `validate_constructor`
    class members on the subclass based on the above
    signatures.
    """

    REQUIRED = ("init_fields", "dask_schema", "sampler")

    @classmethod
    def _expand_namespace(cls, name, namespace):
        """
        Check that the expected implementations are in the namespace.

        Also assign the args and kwargs associated with the implementations
        into the namespace.

        ``init_fields`` must have the form ``init_fields(self, typingctx, ...)``.
        The parameters after ``typingctx`` are the term arguments:

        * parameters without a default are collected in ``ARGS``;
        * parameters with a default are collected in ``KWARGS``
          (a ``name -> default`` dict);
        * ``ALL_ARGS`` is ``ARGS`` followed by the ``KWARGS`` names.

        Parameters
        ----------
        name : str
            Name of the class being created (used in error messages).
        namespace : dict
            Class-body namespace of the class being created.

        Returns
        -------
        dict
            A copy of `namespace` with ``ARGS``, ``KWARGS``, ``ALL_ARGS``
            and ``validate_constructor`` assigned.

        Raises
        ------
        NotImplementedError
            If one of ``REQUIRED`` is missing from `namespace`.
        InvalidSignature
            If the signatures of ``init_fields``, ``dask_schema`` or
            ``sampler`` are unsupported or inconsistent with each other.
        """
        methods = {}

        for method_name in cls.REQUIRED:
            try:
                methods[method_name] = namespace[method_name]
            except KeyError:
                # The KeyError context adds nothing for the user
                raise NotImplementedError(f"{name}.{method_name}") from None

        init_fields_sig = inspect.signature(methods["init_fields"])
        field_params = list(init_fields_sig.parameters.values())

        # Covers both "fewer than two parameters" and "wrong leading names"
        if [p.name for p in field_params[:2]] != ["self", "typingctx"]:
            raise InvalidSignature(
                f"{name}.init_fields{init_fields_sig} "
                f"should be "
                f"{name}.init_fields(self, typingctx, ...)"
            )

        term_params = field_params[2:]

        for p in term_params:
            if p.kind == p.VAR_POSITIONAL:
                raise InvalidSignature(
                    f"*{p.name} in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

            if p.kind == p.VAR_KEYWORD:
                raise InvalidSignature(
                    f"**{p.name} in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

            # The two cases below fit neither the ARGS nor the KWARGS
            # categories computed further down, so they would otherwise be
            # silently omitted from ALL_ARGS. Reject them explicitly.
            if p.kind == p.KEYWORD_ONLY and p.default is p.empty:
                raise InvalidSignature(
                    f"Keyword-only argument {p.name} without a default in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

            if p.kind == p.POSITIONAL_ONLY and p.default is not p.empty:
                raise InvalidSignature(
                    f"Positional-only argument {p.name} with a default in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

        # dask_schema(self, ...) must mirror init_fields minus typingctx
        dask_schema_sig = inspect.signature(methods["dask_schema"])
        expected_dask_sig = init_fields_sig.replace(
            parameters=field_params[:1] + term_params
        )

        if dask_schema_sig != expected_dask_sig:
            raise InvalidSignature(
                f"{name}.dask_schema{dask_schema_sig} "
                f"should be "
                f"{name}.dask_schema{expected_dask_sig}"
            )

        # Constructors returned by init_fields must accept exactly the
        # term arguments (init_fields minus self and typingctx)
        expected_init_sig = init_fields_sig.replace(parameters=term_params)
        validator = sigcheck_factory(expected_init_sig)

        Parameter = inspect.Parameter
        sampler_sig = inspect.signature(methods["sampler"])
        expected_sampler_sig = inspect.Signature(
            parameters=[Parameter("self", kind=Parameter.POSITIONAL_OR_KEYWORD)]
        )

        if sampler_sig != expected_sampler_sig:
            raise InvalidSignature(
                f"{name}.sampler{sampler_sig} "
                f"should be "
                f"{name}.sampler{expected_sampler_sig}"
            )

        # After the checks above, every remaining parameter is either a
        # positional parameter without a default (ARGS), or a
        # positional-or-keyword / keyword-only parameter with a default (KWARGS)
        args = tuple(p.name for p in term_params if p.default is p.empty)
        kwargs = {p.name: p.default for p in term_params if p.default is not p.empty}

        namespace = namespace.copy()
        namespace["ARGS"] = args
        namespace["KWARGS"] = kwargs
        namespace["ALL_ARGS"] = args + tuple(kwargs)
        namespace["validate_constructor"] = validator

        return namespace

    @classmethod
    def term_in_bases(cls, bases):
        """
        Is `Term` in bases (or anywhere among their ancestors)?

        Walks ``__bases__`` recursively and compares each class to the
        module-level ``Term`` by identity.
        """
        for base in bases:
            if base is Term or cls.term_in_bases(base.__bases__):
                return True

        return False

    def __new__(mcls, name, bases, namespace):
        # Only subclasses of Term are validated and expanded; Term itself
        # has no bases and is created unchanged.
        if mcls.term_in_bases(bases):
            namespace = mcls._expand_namespace(name, namespace)

        return super().__new__(mcls, name, bases, namespace)


class Term(metaclass=TermMetaClass):
    """
    Base class for fused RIME terms.

    Subclasses must define ``init_fields``, ``dask_schema`` and ``sampler``.
    Their signatures are validated by :class:`TermMetaClass` when the
    subclass is created.
    """

    def __init__(self, configuration):
        self._configuration = configuration

    @property
    def configuration(self):
        """The configuration object supplied to the constructor."""
        return self._configuration

    def __eq__(self, rhs):
        """
        Terms are equal when they are of the same class and have equal
        configurations.
        """
        if not isinstance(rhs, Term):
            return NotImplemented

        # Two different term classes with equal configurations are
        # still distinct terms.
        return type(self) is type(rhs) and self._configuration == rhs._configuration

    def __hash__(self):
        # Consistent with __eq__: equal terms share class and configuration.
        # Raises TypeError if the configuration itself is unhashable.
        return hash((type(self), self._configuration))

    def __repr__(self):
        return self.__class__.__name__

    def __str__(self):
        return self.__class__.__name__

    @classmethod
    def validate_sampler(cls, sampler):
        """
        Validate the sampler implementation.

        Parameters
        ----------
        sampler : callable
            Must have the signature
            ``(state, s, r, t, f1, f2, a1, a2, c)``, with every parameter
            positional-or-keyword and no defaults.

        Raises
        ------
        InvalidSignature
            If the signature of `sampler` differs from the expected one.
        """
        sampler_sig = inspect.signature(sampler)
        Parameter = inspect.Parameter
        P = partial(Parameter, kind=Parameter.POSITIONAL_OR_KEYWORD)
        params = [P(n) for n in ("state", "s", "r", "t", "f1", "f2", "a1", "a2", "c")]
        expected_sig = inspect.Signature(params)

        if sampler_sig != expected_sig:
            raise InvalidSignature(
                f"{sampler.__name__}{sampler_sig} "
                f"should be "
                f"{sampler.__name__}{expected_sig}"
            )