# Source code for africanus.rime.cuda.phase

# -*- coding: utf-8 -*-


import logging
from os.path import join as pjoin

import numpy as np

from africanus.constants import minus_two_pi_over_c
from africanus.util.jinja2 import jinja_env
from africanus.rime.phase import PHASE_DELAY_DOCS
from africanus.util.code import memoize_on_key, format_code
from africanus.util.cuda import cuda_function, grids
from africanus.util.requirements import requires_optional

try:
    import cupy as cp
    from cupy._core._scalar import get_typename as _get_typename
    from cupy.cuda.compiler import CompileException
except ImportError:
    pass

# Module-level logger; phase_delay uses it to log the kernel source when
# compilation fails.
log = logging.getLogger(__name__)


def _key_fn(lm, uvw, frequency):
    """Return the memoization key for kernel generation.

    The key is the tuple of the ``dtype`` attributes of ``lm``, ``uvw``
    and ``frequency``, in that order.
    """
    return tuple(array.dtype for array in (lm, uvw, frequency))


# Name of the phase delay kernel template, passed to jinja_env.get_template
# when the kernel is generated.
_TEMPLATE_PATH = pjoin("rime", "cuda", "phase.cu.j2")


@memoize_on_key(_key_fn)
def _generate_kernel(lm, uvw, frequency):
    """Render the phase delay CUDA template and wrap it in a raw kernel.

    The template is specialised on the dtypes of ``lm``, ``uvw`` and
    ``frequency``.

    Returns
    -------
    tuple
        ``(kernel, block, out_dtype)`` where ``kernel`` is a
        :class:`cupy.RawKernel` built from the rendered source, ``block``
        is the ``(blockdimx, blockdimy, 1)`` thread block shape and
        ``out_dtype`` is the inputs' result type promoted with complex64.
    """
    kernel_name = "phase_delay"

    # Real floating point type shared by the three inputs
    real_dtype = np.result_type(lm, uvw, frequency)

    # Thread block extents: 32 for float32 data, 16 otherwise
    if frequency.dtype == np.float32:
        blockdimx = 32
    else:
        blockdimx = 16

    if uvw.dtype == np.float32:
        blockdimy = 32
    else:
        blockdimy = 16

    block = (blockdimx, blockdimy, 1)

    # Fetch the template first, then assemble the values it is rendered with
    template = jinja_env.get_template(_TEMPLATE_PATH)
    context = {
        "kernel_name": kernel_name,
        "lm_type": _get_typename(lm.dtype),
        "uvw_type": _get_typename(uvw.dtype),
        "freq_type": _get_typename(frequency.dtype),
        "out_type": _get_typename(real_dtype),
        "sqrt_fn": cuda_function('sqrt', lm.dtype),
        "sincos_fn": cuda_function('sincos', real_dtype),
        "minus_two_pi_over_c": minus_two_pi_over_c,
        "blockdimx": blockdimx,
        "blockdimy": blockdimy,
    }
    code = template.render(**context)

    # Promote the real type to its complex counterpart for the output array
    complex_dtype = np.result_type(real_dtype, np.complex64)
    return cp.RawKernel(code, kernel_name), block, complex_dtype


@requires_optional("cupy")
def phase_delay(lm, uvw, frequency):
    # NOTE: the public docstring is assigned to ``phase_delay.__doc__`` right
    # after this definition (from PHASE_DELAY_DOCS), so only comments are
    # used here to avoid shadowing it.

    # Kernel, thread block shape and complex output dtype for these inputs
    kernel, block, out_dtype = _generate_kernel(lm, uvw, frequency)

    # Launch grid sized from the frequency and uvw leading dimensions
    grid = grids((frequency.shape[0], uvw.shape[0], 1), block)

    # Output is indexed by the leading dimensions of (lm, uvw, frequency)
    out = cp.empty(shape=(lm.shape[0], uvw.shape[0], frequency.shape[0]),
                   dtype=out_dtype)

    try:
        kernel(grid, block, (lm, uvw, frequency, out))
    except CompileException:
        # Log the formatted kernel source to help diagnose the failure,
        # then let the original exception propagate.
        log.exception(format_code(kernel.code))
        raise

    return out


# Attach the shared phase delay documentation, specialised for cupy arrays.
try:
    phase_delay.__doc__ = PHASE_DELAY_DOCS.substitute(
        array_type=':class:`cupy.ndarray`')
except AttributeError:
    # Best effort: leave the function undocumented if the docstring cannot
    # be built or assigned.
    # NOTE(review): presumably guards against a missing docs template or a
    # non-writable __doc__ -- confirm which case this is meant to cover.
    pass