Source code for africanus.rime.cuda.predict

# -*- coding: utf-8 -*-


from functools import reduce
import logging
from operator import mul
from os.path import join as pjoin

import numpy as np

from africanus.rime.predict import (PREDICT_DOCS, predict_checks)
from africanus.util.code import format_code, memoize_on_key
from africanus.util.cuda import cuda_type, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional

try:
    import cupy as cp
    from cupy.cuda.compiler import CompileException
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None


log = logging.getLogger(__name__)


_TEMPLATE_PATH = pjoin("rime", "cuda", "predict.cu.j2")


def _key_fn(*args):
    """ Hash on array datatypes and rank """
    return tuple((a.dtype, a.ndim)
                 if isinstance(a, (np.ndarray, cp.ndarray))
                 else a for a in args)


@memoize_on_key(_key_fn)
def _generate_kernel(time_index, antenna1, antenna2,
                     dde1_jones, source_coh, dde2_jones,
                     die1_jones, base_vis, die2_jones,
                     corrs, out_ndim):
    """
    Render the CUDA predict kernel for the given inputs and wrap it.

    Results are memoised on :func:`_key_fn`. That is, they are keyed on
    the dtype and rank of each array argument and on the value of every
    other argument, so each input configuration is rendered only once.

    Parameters
    ----------
    time_index, antenna1, antenna2 : :class:`cupy.ndarray`
        Row indexing arrays. Each must have dtype ``np.int32``.
    dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones
        Optional RIME terms, ``None`` when absent. Their presence, dtype
        and rank are baked into the generated kernel source.
    corrs : tuple of int
        Correlation shape. Its product is the flattened correlation count.
    out_ndim : int
        Rank of the output array the kernel writes to.

    Returns
    -------
    kernel : :class:`cupy.RawKernel`
        The wrapped kernel, named ``"predict_vis"``.
    block : tuple of int
        CUDA thread block dimensions ``(x, y, 1)``.
    out_dtype : :class:`numpy.dtype`
        Output visibility dtype, promoted from the supplied terms.

    Raises
    ------
    TypeError
        If ``time_index``, ``antenna1`` or ``antenna2`` is not ``np.int32``.
    ValueError
        If none of the optional RIME terms is supplied. Any errors raised
        by ``predict_checks`` also propagate.
    """
    tup = predict_checks(time_index, antenna1, antenna2,
                         dde1_jones, source_coh, dde2_jones,
                         die1_jones, base_vis, die2_jones)

    (have_ddes1, have_coh, have_ddes2, have_dies1, have_bvis, have_dies2) = tup

    # Check types: the generated kernel requires int32 indexing arrays
    for arg_name, arg in (("time_index", time_index),
                          ("antenna1", antenna1),
                          ("antenna2", antenna2)):
        if arg.dtype != np.int32:
            raise TypeError("%s.dtype != np.int32 '%s'"
                            % (arg_name, arg.dtype))

    # Create template
    render = jinja_env.get_template(_TEMPLATE_PATH).render
    name = "predict_vis"

    # Complex output type, promoted from the *supplied* terms only.
    # np.result_type interprets a None argument as the default dtype
    # (float64). Passing absent terms through would therefore promote
    # complex64 inputs to complex128 whenever any optional term is missing.
    terms = (dde1_jones, source_coh, dde2_jones,
             die1_jones, base_vis, die2_jones)
    term_dtypes = [t.dtype for t in terms if t is not None]

    if not term_dtypes:
        raise ValueError("At least one of dde1_jones, source_coh, "
                         "dde2_jones, die1_jones, base_vis or die2_jones "
                         "must be supplied to determine the output dtype")

    out_dtype = np.result_type(*term_dtypes)

    ncorrs = reduce(mul, corrs, 1)

    # Thread block layout: x spans corrs x channels, y spans rows.
    # complex128 output uses a smaller y dimension.
    blockdimx = 32
    blockdimy = 24 if out_dtype == np.complex128 else 32

    block = (blockdimx, blockdimy, 1)

    # Absent terms are rendered with a placeholder "int" type and rank 1.
    code = render(kernel_name=name, blockdimx=blockdimx, blockdimy=blockdimy,
                  have_dde1=have_ddes1,
                  dde1_type=cuda_type(dde1_jones) if have_ddes1 else "int",
                  dde1_ndim=dde1_jones.ndim if have_ddes1 else 1,
                  have_dde2=have_ddes2,
                  dde2_type=cuda_type(dde2_jones) if have_ddes2 else "int",
                  dde2_ndim=dde2_jones.ndim if have_ddes2 else 1,
                  have_coh=have_coh,
                  coh_type=cuda_type(source_coh) if have_coh else "int",
                  coh_ndim=source_coh.ndim if have_coh else 1,
                  have_die1=have_dies1,
                  die1_type=cuda_type(die1_jones) if have_dies1 else "int",
                  die1_ndim=die1_jones.ndim if have_dies1 else 1,
                  have_base_vis=have_bvis,
                  base_vis_type=cuda_type(base_vis) if have_bvis else "int",
                  base_vis_ndim=base_vis.ndim if have_bvis else 1,
                  have_die2=have_dies2,
                  die2_type=cuda_type(die2_jones) if have_dies2 else "int",
                  die2_ndim=die2_jones.ndim if have_dies2 else 1,
                  out_type=cuda_type(out_dtype),
                  corrs=ncorrs,
                  out_ndim=out_ndim,
                  warp_size=32).encode('utf-8')

    return cp.RawKernel(code, name), block, out_dtype


@requires_optional("cupy", opt_import_error)
def predict_vis(time_index, antenna1, antenna2,
                dde1_jones=None, source_coh=None, dde2_jones=None,
                die1_jones=None, base_vis=None, die2_jones=None):
    """
    Cupy implementation of the predict_vis kernel.

    Returns an array of shape ``(row, chan) + corrs``. ``row`` is the
    length of ``time_index``. ``chan`` and ``corrs`` come from the first
    supplied term, in the order DDEs, coherencies, DIEs, base visibilities.

    Raises
    ------
    ValueError
        If no term is available from which to infer the output shape, or
        if the correlation shape is neither 1D nor 2D.
    """
    have_ddes = dde1_jones is not None and dde2_jones is not None
    have_dies = die1_jones is not None and die2_jones is not None
    have_coh = source_coh is not None
    have_bvis = base_vis is not None

    # Infer the channel count and correlation shape from the first
    # available term. The channel axis sits at a different index in each.
    if have_ddes:
        chan = dde1_jones.shape[3]
        corrs = dde1_jones.shape[4:]
    elif have_coh:
        chan = source_coh.shape[2]
        corrs = source_coh.shape[3:]
    elif have_dies:
        chan = die1_jones.shape[2]
        corrs = die1_jones.shape[3:]
    elif have_bvis:
        chan = base_vis.shape[1]
        corrs = base_vis.shape[2:]
    else:
        raise ValueError("Insufficient inputs supplied for determining "
                         "the output shape")

    row = time_index.shape[0]
    ncorrs = len(corrs)

    # Collapse a 2D correlation shape into one trailing axis before
    # launching the kernel. 1D correlations are used as-is.
    if ncorrs == 2:
        flat_corrs = reduce(mul, corrs, 1)

        if have_ddes:
            dde_shape = dde1_jones.shape[:-ncorrs] + (flat_corrs,)
            dde1_jones = dde1_jones.reshape(dde_shape)
            dde2_jones = dde2_jones.reshape(dde_shape)

        if have_coh:
            coh_shape = source_coh.shape[:-ncorrs] + (flat_corrs,)
            source_coh = source_coh.reshape(coh_shape)

        if have_dies:
            die_shape = die1_jones.shape[:-ncorrs] + (flat_corrs,)
            die1_jones = die1_jones.reshape(die_shape)
            die2_jones = die2_jones.reshape(die_shape)

        if have_bvis:
            bvis_shape = base_vis.shape[:-ncorrs] + (flat_corrs,)
            base_vis = base_vis.reshape(bvis_shape)
    elif ncorrs == 1:
        flat_corrs = corrs[0]
    else:
        raise ValueError("Invalid correlation setup %s" % (corrs,))

    out_shape = (row, chan, flat_corrs)

    kernel, block, out_dtype = _generate_kernel(time_index,
                                                antenna1,
                                                antenna2,
                                                dde1_jones,
                                                source_coh,
                                                dde2_jones,
                                                die1_jones,
                                                base_vis,
                                                die2_jones,
                                                corrs,
                                                len(out_shape))

    out = cp.empty(shape=out_shape, dtype=out_dtype)

    # Nothing to predict for zero rows, channels or correlations. Return
    # early, because time_index.min() raises on an empty array and a
    # zero-sized grid is not a valid CUDA launch configuration.
    if out.size == 0:
        return out.reshape((row, chan) + corrs)

    grid = grids((chan*flat_corrs, row, 1), block)

    # Normalise the time index
    # TODO(sjperkins)
    # Normalise the time index with a device-wide reduction
    norm_time_index = time_index - time_index.min()

    args = (norm_time_index, antenna1, antenna2,
            dde1_jones, source_coh, dde2_jones,
            die1_jones, base_vis, die2_jones,
            out)

    # Absent (None) terms are dropped. The kernel was generated without
    # the corresponding parameters.
    try:
        kernel(grid, block, tuple(a for a in args if a is not None))
    except CompileException:
        # CompileException surfaces at launch. Log the generated source
        # to aid debugging, then re-raise.
        log.exception(format_code(kernel.code))
        raise

    return out.reshape((row, chan) + corrs)
try:
    predict_vis.__doc__ = PREDICT_DOCS.substitute(
        array_type=":class:`cupy.ndarray`",
        # The CUDA kernel rejects non-int32 time indices (see the dtype
        # checks in _generate_kernel). cp.unique's inverse indices are not
        # int32, so the documented recipe must include the cast.
        get_time_index=":code:`cp.unique(time, "
                       "return_inverse=True)[1].astype(np.int32)`",
        extra_args="",
        extra_notes="")
except AttributeError:
    # NOTE(review): deliberate best-effort. This presumably guards against
    # docstrings being unavailable (e.g. ``python -OO``); confirm. On
    # failure the short docstring defined on predict_vis is kept.
    pass