import inspect
from functools import partial
from numba.experimental import structref
from numba.core import types
from africanus.experimental.rime.fused.error import InvalidSignature
@structref.register
class StateStructRef(types.StructRef):
    """Numba ``StructRef`` type whose field types are never literal types."""

    def preprocess_fields(self, fields):
        """Disallow literal types in field definitions.

        Every ``(name, type)`` pair in ``fields`` is returned with its type
        passed through ``types.unliteral``; the result is a tuple of pairs
        in the original order.
        """
        unliteral_fields = []

        for field_name, field_type in fields:
            unliteral_fields.append((field_name, types.unliteral(field_type)))

        return tuple(unliteral_fields)
def sigcheck_factory(expected_sig):
    """Create a method that checks a callable against ``expected_sig``.

    Parameters
    ----------
    expected_sig : inspect.Signature
        Signature that the checked callable must match exactly.

    Returns
    -------
    callable
        ``check_constructor_signature(self, fn)``. It returns ``None`` when
        ``inspect.signature(fn)`` equals ``expected_sig`` and raises
        ``ValueError`` describing both signatures otherwise.
    """

    def check_constructor_signature(self, fn):
        """Raise ``ValueError`` if ``fn``'s signature is not the expected one."""
        sig = inspect.signature(fn)

        if sig != expected_sig:
            # Not every callable has a __name__ (functools.partial objects
            # do not); fall back to repr so that a mismatch is always
            # reported as a ValueError rather than an AttributeError.
            fn_name = getattr(fn, "__name__", repr(fn))
            raise ValueError(f"{fn_name}{sig} should be {fn_name}{expected_sig}")

    return check_constructor_signature
class TermMetaClass(type):
    """
    Metaclass which checks that the appropriate methods are
    implemented on any subclass of `Term` and that their
    signatures agree with each other.

    Also sets `ARGS`, `KWARGS` and `ALL_ARGS`
    class members on the subclass based on the above
    signatures, together with a `validate_constructor` method
    produced by `sigcheck_factory`.
    """

    REQUIRED_METHODS = ("init_fields", "dask_schema", "sampler")
    INIT_FIELDS_REQUIRED_ARGS = ("self", "typingctx", "init_state")

    @classmethod
    def _expand_namespace(cls, name, namespace):
        """
        Check that the expected implementations are in the namespace.
        Also assign the args and kwargs associated with the implementations
        into the namespace

        Parameters
        ----------
        name : str
            Name of the class being created, used in error messages.
        namespace : dict
            Class namespace of the class being created.

        Returns
        -------
        dict
            A copy of `namespace` with args and kwargs assigned

        Raises
        ------
        NotImplementedError
            If one of `REQUIRED_METHODS` is missing from `namespace`.
        InvalidSignature
            If `init_fields`, `dask_schema` or `sampler` do not have the
            expected signatures.
        """
        methods = {}

        for method_name in cls.REQUIRED_METHODS:
            try:
                methods[method_name] = namespace[method_name]
            except KeyError:
                # The KeyError is an implementation detail of the lookup,
                # so it is suppressed from the reported traceback.
                raise NotImplementedError(f"{name}.{method_name}") from None

        init_fields_sig = inspect.signature(methods["init_fields"])
        field_params = list(init_fields_sig.parameters.values())

        # init_fields must start with exactly (self, typingctx, init_state).
        # A signature with fewer than three parameters yields a shorter
        # tuple here and therefore fails the same comparison.
        leading_names = tuple(p.name for p in field_params[:3])

        if leading_names != cls.INIT_FIELDS_REQUIRED_ARGS:
            raise InvalidSignature(
                f"{name}.init_fields{init_fields_sig} "
                f"should be "
                f"{name}.init_fields({', '.join(cls.INIT_FIELDS_REQUIRED_ARGS)}, ...)"
            )

        # Variadic parameters are rejected after the three required ones.
        for param in field_params[3:]:
            if param.kind == param.VAR_POSITIONAL:
                raise InvalidSignature(
                    f"*{param.name} in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

            if param.kind == param.VAR_KEYWORD:
                raise InvalidSignature(
                    f"**{param.name} in "
                    f"{name}.init_fields{init_fields_sig} "
                    f"is not supported"
                )

        # dask_schema takes `self` plus everything after `init_state`.
        dask_schema_sig = inspect.signature(methods["dask_schema"])
        expected_dask_params = field_params[0:1] + field_params[3:]
        expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params)

        if dask_schema_sig != expected_dask_sig:
            raise InvalidSignature(
                f"{name}.dask_schema{dask_schema_sig} "
                f"should be "
                f"{name}.dask_schema{expected_dask_sig}"
            )

        # The constructor checked by `validate_constructor` takes
        # `init_state` followed by the remaining init_fields parameters.
        expected_init_sig = init_fields_sig.replace(parameters=field_params[2:])
        validator = sigcheck_factory(expected_init_sig)

        # sampler must take `self` only.
        Parameter = inspect.Parameter
        sampler_sig = inspect.signature(methods["sampler"])
        expected_sampler_sig = inspect.Signature(
            parameters=[Parameter("self", kind=Parameter.POSITIONAL_OR_KEYWORD)]
        )

        if sampler_sig != expected_sampler_sig:
            raise InvalidSignature(
                f"{name}.sampler{sampler_sig} "
                f"should be "
                f"{name}.sampler{expected_sampler_sig}"
            )

        # Build the exclusion set once instead of once per parameter.
        required_names = frozenset(cls.INIT_FIELDS_REQUIRED_ARGS)
        user_params = [p for p in field_params if p.name not in required_names]

        # Parameters without defaults become ARGS ...
        args = tuple(
            p.name
            for p in user_params
            if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}
            and p.default is p.empty
        )

        # ... and parameters with defaults become KWARGS.
        kwargs = {
            p.name: p.default
            for p in user_params
            if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
            and p.default is not p.empty
        }

        namespace = namespace.copy()
        namespace["ARGS"] = args
        namespace["KWARGS"] = kwargs
        namespace["ALL_ARGS"] = args + tuple(kwargs)
        namespace["validate_constructor"] = validator

        return namespace

    @classmethod
    def term_in_bases(cls, bases):
        """Is `Term` in bases, or in the bases of any of them (recursively)?"""
        return any(
            base is Term or cls.term_in_bases(base.__bases__) for base in bases
        )

    def __new__(mcls, name, bases, namespace):
        # Check methods on any subclasses of Term
        # and expand the subclass namespace
        if mcls.term_in_bases(bases):
            namespace = mcls._expand_namespace(name, namespace)

        return super().__new__(mcls, name, bases, namespace)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
class Term(metaclass=TermMetaClass):
    """Base class for terms.

    Subclasses are checked by `TermMetaClass` when they are created.

    Parameters
    ----------
    configuration
        Value stored on the instance and exposed through the
        `configuration` property. It is also what `__eq__` compares.
    """

    def __init__(self, configuration):
        self._configuration = configuration

    @property
    def configuration(self):
        """The configuration passed to the constructor."""
        return self._configuration

    # NOTE: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, rhs):
        """Equal when `rhs` is a `Term` with an equal configuration."""
        return isinstance(rhs, Term) and self._configuration == rhs._configuration

    def __repr__(self):
        return self.__class__.__name__

    def __str__(self):
        return self.__class__.__name__

    @classmethod
    def validate_sampler(cls, sampler):
        """Validate the sampler implementation

        Parameters
        ----------
        sampler : callable
            Must have the signature
            ``(state, s, r, t, f1, f2, a1, a2, c)`` with all parameters
            positional-or-keyword and without defaults or annotations.

        Raises
        ------
        InvalidSignature
            If the signature of `sampler` differs from the expected one.
        """
        sampler_sig = inspect.signature(sampler)
        Parameter = inspect.Parameter
        expected_names = ("state", "s", "r", "t", "f1", "f2", "a1", "a2", "c")
        expected_sig = inspect.Signature(
            [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in expected_names]
        )

        if sampler_sig != expected_sig:
            # The space after the actual signature keeps it from running
            # straight into "should be" in the message.
            raise InvalidSignature(
                f"{sampler.__name__}{sampler_sig} "
                f"should be "
                f"{sampler.__name__}{expected_sig}"
            )