# Source code for africanus.rime.cuda.feeds

# -*- coding: utf-8 -*-

import logging
from os.path import join as pjoin

import numpy as np

from africanus.rime.feeds import FEED_ROTATION_DOCS
from africanus.util.code import format_code, memoize_on_key
from africanus.util.cuda import cuda_function, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional

    import cupy as cp
    from cupy._core._scalar import get_typename as _get_typename
    from cupy.cuda.compiler import CompileException
except ImportError as e:
    opt_import_error = e
    opt_import_error = None

log = logging.getLogger(__name__)

# Path of the Jinja2 template that is rendered into the CUDA kernel source.
# The final component must name a template file: an empty component makes the
# path end in a separator, which cannot be loaded as a template.
# NOTE(review): "feeds.cu.j2" follows this module's name -- confirm that the
# template file shipped alongside this module is called exactly that.
_TEMPLATE_PATH = pjoin("rime", "cuda", "feeds.cu.j2")

def _key_fn(parallactic_angles, feed_type):
    """Return the ``(dtype, feed_type)`` pair used as a key.

    Only the ``dtype`` attribute of ``parallactic_angles`` is read;
    ``feed_type`` is passed through unchanged.
    """
    angle_dtype = parallactic_angles.dtype
    return angle_dtype, feed_type

# NOTE(review): ``memoize_on_key`` and ``_key_fn`` are otherwise unused in this
# module, so the kernel is cached per ``(dtype, feed_type)`` here -- confirm
# this matches the intended decorator.
@memoize_on_key(_key_fn)
def _generate_kernel(parallactic_angles, feed_type):
    """Render the feed rotation CUDA source and wrap it in a raw kernel.

    Parameters
    ----------
    parallactic_angles : array
        Only its ``dtype`` is inspected. ``numpy.float32`` and
        ``numpy.float64`` are handled.
    feed_type : str
        Forwarded to the kernel template.

    Returns
    -------
    kernel : :class:`cupy.RawKernel`
        Kernel built from the rendered template, named ``"feed_rotation"``.
    block : tuple of int
        Launch block dimensions: ``(1024, 1, 1)`` for float32 and
        ``(512, 1, 1)`` for float64.
    out_dtype : :class:`numpy.dtype`
        ``numpy.result_type(dtype, numpy.complex64)``.

    Raises
    ------
    TypeError
        If the dtype is neither float32 nor float64.
    """
    dtype = parallactic_angles.dtype

    # Block sizes
    if dtype == np.float32:
        block = (1024, 1, 1)
    elif dtype == np.float64:
        block = (512, 1, 1)
    else:
        # Any other dtype has no block size defined: fail early and clearly.
        raise TypeError("Unhandled type %s" % dtype)

    # Create template
    render = jinja_env.get_template(_TEMPLATE_PATH).render
    name = "feed_rotation"

    # NOTE(review): apart from ``sincos_fn``, the keyword names below are
    # reconstructed from values this module computes or imports but does not
    # otherwise use -- confirm them against the variables the template expects.
    code = render(
        kernel_name=name,
        feed_type=feed_type,
        sincos_fn=cuda_function("sincos", dtype),
        real_type=_get_typename(dtype),
        blockdimx=block[0],
    )

    # Complex output type
    out_dtype = np.result_type(dtype, np.complex64)
    return cp.RawKernel(code, name), block, out_dtype

@requires_optional("cupy", opt_import_error)
def feed_rotation(parallactic_angles, feed_type="linear"):
    """Cupy implementation of the feed_rotation kernel.

    Parameters
    ----------
    parallactic_angles : array
        Input angles. The array is flattened before the kernel is launched;
        its original shape is used to shape the result.
    feed_type : str, optional
        Forwarded to ``_generate_kernel``. Defaults to ``"linear"``.

    Returns
    -------
    array
        Array of shape ``parallactic_angles.shape + (2, 2)`` with the complex
        dtype selected by ``_generate_kernel``.

    Raises
    ------
    CompileException
        Re-raised after the kernel source has been logged.
    """
    kernel, block, out_dtype = _generate_kernel(parallactic_angles, feed_type)

    # Remember the caller's shape, then work on a flat view of the input.
    in_shape = parallactic_angles.shape
    parallactic_angles = parallactic_angles.ravel()

    grid = grids((parallactic_angles.shape[0], 1, 1), block)
    # Four output values per input angle, reshaped to (2, 2) on return.
    out = cp.empty(shape=(parallactic_angles.shape[0], 4), dtype=out_dtype)

    try:
        kernel(grid, block, (parallactic_angles, out))
    except CompileException:
        # Log the formatted kernel source to help diagnose the failure,
        # then let the exception propagate to the caller.
        log.exception(format_code(kernel.code))
        raise

    return out.reshape(in_shape + (2, 2))
# Replace the short docstring of ``feed_rotation`` with the shared feed
# rotation documentation, specialised for cupy arrays.
try:
    feed_rotation.__doc__ = FEED_ROTATION_DOCS.substitute(
        array_type=":class:`cupy.ndarray`"
    )
except AttributeError:
    # Deliberate best-effort: keep the existing docstring if the substitution
    # is not possible.
    # NOTE(review): presumably covers the case where the shared docs object
    # has no ``substitute`` (e.g. docstrings stripped) -- confirm.
    pass