Source code for africanus.model.coherency.cuda.conversion

# -*- coding: utf-8 -*-


from functools import reduce
import logging
from operator import mul
from os.path import join as pjoin

import numpy as np

from africanus.model.coherency.conversion import (
    _element_indices_and_shape,
    CONVERT_DOCS,
    MissingConversionInputs,
)
from africanus.util.code import memoize_on_key, format_code
from africanus.util.cuda import cuda_type, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional

try:
    import cupy as cp
    from cupy.cuda.compiler import CompileException
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None

log = logging.getLogger(__name__)

stokes_conv = {
    "RR": {("I", "V"): ("complex", "make_{{out_type}}({{I}} + {{V}}, 0)")},
    "RL": {("Q", "U"): ("complex", "make_{{out_type}}({{Q}}, {{U}})")},
    "LR": {("Q", "U"): ("complex", "make_{{out_type}}({{Q}}, -{{U}})")},
    "LL": {("I", "V"): ("complex", "make_{{out_type}}({{I}} - {{V}}, 0)")},
    "XX": {("I", "Q"): ("complex", "make_{{out_type}}({{I}} + {{Q}}, 0)")},
    "XY": {("U", "V"): ("complex", "make_{{out_type}}({{U}}, {{V}})")},
    "YX": {("U", "V"): ("complex", "make_{{out_type}}({{U}}, -{{V}})")},
    "YY": {("I", "Q"): ("complex", "make_{{out_type}}({{I}} - {{Q}}, 0)")},
    "I": {
        ("XX", "YY"): ("real", "(({{XX}}.x + {{YY}}.x) / 2)"),
        ("RR", "LL"): ("real", "(({{RR}}.x + {{LL}}.x) / 2)"),
    },
    "Q": {
        ("XX", "YY"): ("real", "(({{XX}}.x - {{YY}}.x) / 2)"),
        ("RL", "LR"): ("real", "(({{RL}}.x + {{LR}}.x) / 2)"),
    },
    "U": {
        ("XY", "YX"): ("real", "(({{XY}}.x + {{YX}}.x) / 2)"),
        ("RL", "LR"): ("real", "(({{RL}}.y - {{LR}}.y) / 2)"),
    },
    "V": {
        ("XY", "YX"): ("real", "(({{XY}}.y - {{YX}}.y) / 2)"),
        ("RR", "LL"): ("real", "(({{RR}}.x - {{LL}}.x) / 2)"),
    },
}


def stokes_convert_setup(input, input_schema, output_schema):
    input_indices, input_shape = _element_indices_and_shape(input_schema)
    output_indices, output_shape = _element_indices_and_shape(output_schema)

    if input.shape[-len(input_shape) :] != input_shape:
        raise ValueError("Last dimension of input doesn't match input schema")

    mapping = []
    dtypes = []

    # Figure out how to produce an output from available inputs
    for okey, out_idx in output_indices.items():
        try:
            deps = stokes_conv[okey]
        except KeyError:
            raise ValueError(
                "Unknown output '%s'. Known types '%s'" % (okey, stokes_conv.keys())
            )

        found_conv = False

        # Find a mapping for which we have inputs
        for (c1, c2), (dtype, fn) in deps.items():
            # Get indices for both correlations
            try:
                c1_idx = input_indices[c1]
            except KeyError:
                continue

            try:
                c2_idx = input_indices[c2]
            except KeyError:
                continue

            found_conv = True
            dtypes.append(dtype)

            mapping.append(((c1, c1_idx), (c2, c2_idx), out_idx, fn))
            break

        # We must find a conversion
        if not found_conv:
            raise MissingConversionInputs(
                "None of the supplied inputs '%s' "
                "can produce output '%s'. It can be "
                "produced by the following "
                "combinations '%s'." % (input_schema, okey, deps.keys())
            )

    # Output types must be all "real" or all "complex"
    if not all(dtypes[0] == dt for dt in dtypes[1:]):
        raise ValueError("Output data types differ %s" % dtypes)

    return mapping, input_shape, output_shape, dtypes[0]


def schema_to_tuple(schema):
    if isinstance(schema, (tuple, list)):
        return tuple(schema_to_tuple(s) for s in schema)
    else:
        return schema


def _key_fn(inputs, input_schema, output_schema):
    return (inputs.dtype, schema_to_tuple(input_schema), schema_to_tuple(output_schema))


_TEMPLATE_PATH = pjoin("model", "coherency", "cuda", "conversion.cu.j2")


@memoize_on_key(_key_fn)
def _generate_kernel(inputs, input_schema, output_schema):
    mapping, in_shape, out_shape, out_dtype = stokes_convert_setup(
        inputs, input_schema, output_schema
    )

    # Flatten input and output shapes
    # Check that number elements are the same
    in_elems = reduce(mul, in_shape, 1)
    out_elems = reduce(mul, out_shape, 1)

    if in_elems != out_elems:
        raise ValueError(
            "Number of input_schema elements %s "
            "and output schema elements %s "
            "must match for CUDA kernel." % (in_shape, out_shape)
        )

    # Infer the output data type
    if out_dtype == "real":
        if np.iscomplexobj(inputs):
            out_dtype = inputs.real.dtype
        else:
            out_dtype = inputs.dtype
    elif out_dtype == "complex":
        if np.iscomplexobj(inputs):
            out_dtype = inputs.dtype
        else:
            out_dtype = np.result_type(inputs.dtype, np.complex64)
    else:
        raise ValueError("Invalid setup dtype %s" % out_dtype)

    cuda_out_dtype = cuda_type(out_dtype)
    assign_exprs = []

    # Render the assignment expression for each element
    for (c1, c1i), (c2, c2i), outi, template_fn in mapping:
        # Flattened indices
        flat_outi = np.ravel_multi_index(outi, out_shape)
        render = jinja_env.from_string(template_fn).render
        kwargs = {
            c1: "in[%d]" % np.ravel_multi_index(c1i, in_shape),
            c2: "in[%d]" % np.ravel_multi_index(c2i, in_shape),
            "out_type": cuda_out_dtype,
        }

        expr_str = render(**kwargs)
        assign_exprs.append("out[%d] = %s;" % (flat_outi, expr_str))

    # Now render the main template
    render = jinja_env.get_template(_TEMPLATE_PATH).render
    name = "stokes_convert"
    code = render(
        kernel_name=name,
        input_type=cuda_type(inputs.dtype),
        output_type=cuda_type(out_dtype),
        assign_exprs=assign_exprs,
        elements=in_elems,
    )

    # cuda block, flatten non-schema dims into a single source dim
    blockdimx = 512
    block = (blockdimx, 1, 1)

    return (cp.RawKernel(code, name), block, in_shape, out_shape, out_dtype)


[docs] @requires_optional("cupy", opt_import_error) def convert(inputs, input_schema, output_schema): (kernel, block, in_shape, out_shape, dtype) = _generate_kernel( inputs, input_schema, output_schema ) # Flatten non-schema input dimensions, # from inspection of the cupy reshape code, # this incurs a copy when inputs is non-contiguous nsrc = reduce(mul, inputs.shape[: -len(in_shape)], 1) nelems = reduce(mul, in_shape, 1) rinputs = inputs.reshape(nsrc, nelems) assert rinputs.flags.c_contiguous grid = grids((nsrc, 1, 1), block) outputs = cp.empty(shape=rinputs.shape, dtype=dtype) try: kernel(grid, block, (rinputs, outputs)) except CompileException: log.exception(format_code(kernel.code)) raise shape = inputs.shape[: -len(in_shape)] + out_shape outputs = outputs.reshape(shape) assert outputs.flags.c_contiguous return outputs
try: convert.__doc__ = CONVERT_DOCS.substitute(array_type=":class:`cupy.ndarray`") except AttributeError: pass