# Source code for africanus.rime.cuda.beam

# -*- coding: utf-8 -*-

from functools import reduce
import logging
from operator import mul
from pathlib import Path

import numpy as np

from africanus.rime.fast_beam_cubes import BEAM_CUBE_DOCS
from africanus.util.code import format_code, memoize_on_key
from africanus.util.cuda import cuda_function, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional

try:
    import cupy as cp
    from cupy._core._scalar import get_typename as _get_typename
    from cupy.cuda.compiler import CompileException
except ImportError as e:
    opt_import_error = e
else:
    opt_import_error = None


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Locations of the Jinja2 templates for the two CUDA kernels. They are
# converted to strings and handed to ``jinja_env.get_template``.
_MAIN_TEMPLATE_PATH = Path("rime", "cuda") / "beam.cu.j2"
_INTERP_TEMPLATE_PATH = Path("rime", "cuda") / "beam_freq_interp.cu.j2"

# Bound on the third (frequency) dimension of the beam cube: beam_cube_dde
# rejects cubes with ``beam.shape[2] >= BEAM_NUD_LIMIT``, and the value is
# also passed to both kernel templates as ``beam_nud_limit``.
BEAM_NUD_LIMIT = 128


def _freq_interp_key(beam_freq_map, frequencies):
    """Build the key used to memoise the frequency interpolation kernel.

    Only the ``dtype`` attribute of each argument is read.

    Returns
    -------
    tuple
        ``(beam_freq_map.dtype, frequencies.dtype)``
    """
    return tuple(array.dtype for array in (beam_freq_map, frequencies))


@memoize_on_key(_freq_interp_key)
def _generate_interp_kernel(beam_freq_map, frequencies):
    """Render the frequency interpolation kernel source and wrap it.

    The function is wrapped by ``memoize_on_key`` with ``_freq_interp_key``,
    i.e. keyed on the dtypes of the two inputs.

    Parameters
    ----------
    beam_freq_map
        Array whose dtype selects ``beam_freq_type`` in the template.
    frequencies
        Array whose dtype selects ``freq_type`` in the template.

    Returns
    -------
    tuple
        ``(kernel, block, dtype)`` where ``kernel`` is a
        :class:`cupy.RawKernel`, ``block`` is the ``(1024, 1, 1)`` thread
        block shape and ``dtype`` is the numpy result type of the inputs.
    """
    kernel_name = "beam_cube_freq_interp"
    block = (1024, 1, 1)

    template = jinja_env.get_template(str(_INTERP_TEMPLATE_PATH))
    source = template.render(
        kernel_name=kernel_name,
        beam_nud_limit=BEAM_NUD_LIMIT,
        blockdimx=block[0],
        beam_freq_type=_get_typename(beam_freq_map.dtype),
        freq_type=_get_typename(frequencies.dtype),
    )

    result_dtype = np.result_type(beam_freq_map, frequencies)

    return cp.RawKernel(source, kernel_name), block, result_dtype


def _main_key_fn(
    beam,
    beam_lm_ext,
    beam_freq_map,
    lm,
    parangles,
    pointing_errors,
    antenna_scaling,
    frequencies,
    dde_dims,
    ncorr,
):
    """Build the key used to memoise the main beam cube kernel.

    The key holds the beam's dtype and number of dimensions, the dtype of
    every other array argument (in signature order), and finally
    ``dde_dims`` and ``ncorr`` unchanged. Array shapes are not part of
    the key.

    Returns
    -------
    tuple
        An 11-element tuple as described above.
    """
    beam_part = (beam.dtype, beam.ndim)

    remaining_arrays = (
        beam_lm_ext,
        beam_freq_map,
        lm,
        parangles,
        pointing_errors,
        antenna_scaling,
        frequencies,
    )
    dtype_part = tuple(array.dtype for array in remaining_arrays)

    return beam_part + dtype_part + (dde_dims, ncorr)


# Maps a supported number of correlations (4, 2 or 1) to its base-2
# logarithm: the bit shift used to recover the channel from a flattened
# channel/correlation index.
_corr_shifter = {ncorr: ncorr.bit_length() - 1 for ncorr in (4, 2, 1)}


@memoize_on_key(_main_key_fn)
def _generate_main_kernel(
    beam,
    beam_lm_ext,
    beam_freq_map,
    lm,
    parangles,
    pointing_errors,
    antenna_scaling,
    frequencies,
    dde_dims,
    ncorr,
):
    """Render the main beam cube kernel source and wrap it.

    The function is wrapped by ``memoize_on_key`` with ``_main_key_fn``,
    i.e. keyed on dtypes, ``beam.ndim``, ``dde_dims`` and ``ncorr``.

    NOTE(review): the key does not include ``beam.shape``, so if
    ``memoize_on_key`` caches purely on the key, the shape check below
    only runs when a kernel is first generated for that key -- confirm.

    Parameters
    ----------
    beam
        Complex beam cube; its first three dimensions are unpacked as
        ``(beam_lw, beam_mh, beam_nud)`` and must each be at least 2.
    beam_lm_ext, beam_freq_map, lm, parangles, pointing_errors, \
antenna_scaling, frequencies
        Arrays whose dtypes select the C types used in the template.
    dde_dims
        Passed through to the template as ``dde_dims``.
    ncorr
        Number of correlations; must be 4, 2 or 1.

    Returns
    -------
    tuple
        ``(kernel, block, dtype)`` where ``kernel`` is a
        :class:`cupy.RawKernel`, ``block`` is the thread block shape and
        ``dtype`` is the (complex) dtype of ``beam``.

    Raises
    ------
    ValueError
        If any of the first three beam dimensions is smaller than 2, or
        if ``ncorr`` is not a supported number of correlations.
    TypeError
        If ``beam`` is not complex64/complex128, or if the promoted
        coordinate type is not float32/float64.
    """
    beam_lw, beam_mh, beam_nud = beam.shape[:3]

    if beam_lw < 2 or beam_mh < 2 or beam_nud < 2:
        raise ValueError("(beam_lw, beam_mh, beam_nud) < 2 to linearly interpolate")

    # Create template
    render = jinja_env.get_template(str(_MAIN_TEMPLATE_PATH)).render
    name = "beam_cube_dde"
    dtype = beam.dtype

    # Thread block shape depends on the beam precision
    if dtype == np.complex64:
        block = (32, 32, 1)
    elif dtype == np.complex128:
        block = (32, 16, 1)
    else:
        raise TypeError("Need complex beam cube '%s'" % beam.dtype)

    try:
        corr_shift = _corr_shifter[ncorr]
    except KeyError:
        # The failed dict lookup is an implementation detail; do not
        # chain it onto the user-facing error.
        raise ValueError(
            "Number of Correlations not in %s" % list(_corr_shifter.keys())
        ) from None

    # Promote all coordinate-like inputs to a common type that is at
    # least single precision floating point.
    coord_type = np.result_type(
        beam_lm_ext, lm, parangles, pointing_errors, antenna_scaling, np.float32
    )

    # Explicit check rather than ``assert`` so that it is not stripped
    # when Python runs with -O.
    if coord_type not in (np.float32, np.float64):
        raise TypeError(
            "Coordinate type '%s' must be float32 or float64" % coord_type
        )

    # Real dtype matching the complex beam dtype
    real_dtype = beam.real.dtype

    code = render(
        kernel_name=name,
        blockdimx=block[0],
        blockdimy=block[1],
        blockdimz=block[2],
        corr_shift=corr_shift,
        ncorr=ncorr,
        beam_nud_limit=BEAM_NUD_LIMIT,
        # Beam type and manipulation functions
        beam_type=_get_typename(real_dtype),
        beam_dims=beam.ndim,
        make2_beam_fn=cuda_function("make2", real_dtype),
        beam_sqrt_fn=cuda_function("sqrt", real_dtype),
        beam_rsqrt_fn=cuda_function("rsqrt", real_dtype),
        # Coordinate type and manipulation functions
        FT=_get_typename(coord_type),
        floor_fn=cuda_function("floor", coord_type),
        min_fn=cuda_function("min", coord_type),
        max_fn=cuda_function("max", coord_type),
        cos_fn=cuda_function("cos", coord_type),
        sin_fn=cuda_function("sin", coord_type),
        lm_ext_type=_get_typename(beam_lm_ext.dtype),
        beam_freq_type=_get_typename(beam_freq_map.dtype),
        lm_type=_get_typename(lm.dtype),
        pa_type=_get_typename(parangles.dtype),
        pe_type=_get_typename(pointing_errors.dtype),
        as_type=_get_typename(antenna_scaling.dtype),
        freq_type=_get_typename(frequencies.dtype),
        dde_type=_get_typename(real_dtype),
        dde_dims=dde_dims,
    )

    # Complex output type
    return cp.RawKernel(code, name), block, dtype


@requires_optional("cupy", opt_import_error)
def beam_cube_dde(
    beam,
    beam_lm_ext,
    beam_freq_map,
    lm,
    parangles,
    pointing_errors,
    antenna_scaling,
    frequencies,
):
    # NOTE: no docstring is written here on purpose; ``__doc__`` is
    # assigned from BEAM_CUBE_DOCS directly after this definition.

    # Any dimensions after the first three are correlation dimensions
    corrs = beam.shape[3:]

    # NOTE(review): a cube with exactly BEAM_NUD_LIMIT frequency planes is
    # rejected too, although the message says "exceeds" -- confirm
    # against the kernel template whether the bound must be strict.
    if beam.shape[2] >= BEAM_NUD_LIMIT:
        raise ValueError("beam_nud exceeds %d" % BEAM_NUD_LIMIT)

    nsrc = lm.shape[0]
    ntime, na = parangles.shape
    nchan = frequencies.shape[0]
    ncorr = reduce(mul, corrs, 1)
    nchancorr = nchan * ncorr

    # Shape of the returned array, with the original correlation dims
    oshape = (nsrc, ntime, na, nchan) + corrs

    if len(corrs) > 1:
        # Flatten the beam correlation dims
        fbeam = beam.reshape(beam.shape[:3] + (ncorr,))
    else:
        fbeam = beam

    # Generate frequency interpolation kernel
    ikernel, iblock, _ = _generate_interp_kernel(beam_freq_map, frequencies)

    # Generate main beam cube kernel
    kernel, block, _ = _generate_main_kernel(
        fbeam,
        beam_lm_ext,
        beam_freq_map,
        lm,
        parangles,
        pointing_errors,
        antenna_scaling,
        frequencies,
        len(oshape),
        ncorr,
    )

    # Call frequency interpolation kernel
    igrid = grids((nchan, 1, 1), iblock)
    freq_data = cp.empty((3, nchan), dtype=frequencies.dtype)

    try:
        ikernel(igrid, iblock, (frequencies, beam_freq_map, freq_data))
    except CompileException:
        # Log the generated source to help diagnose the compile failure
        log.exception(format_code(ikernel.code))
        raise

    # Call main beam cube kernel; the output uses a single flattened
    # correlation dimension and is reshaped to ``oshape`` on return.
    out = cp.empty((nsrc, ntime, na, nchan) + (ncorr,), dtype=beam.dtype)
    grid = grids((nchancorr, na, ntime), block)

    try:
        kernel(
            grid,
            block,
            (
                fbeam,
                beam_lm_ext,
                beam_freq_map,
                lm,
                parangles,
                pointing_errors,
                antenna_scaling,
                frequencies,
                freq_data,
                nsrc,
                out,
            ),
        )
    except CompileException:
        # Log the generated source to help diagnose the compile failure
        log.exception(format_code(kernel.code))
        raise

    return out.reshape(oshape)
# Fill in the shared beam cube documentation template for cupy arrays
# and attach it as the docstring of beam_cube_dde.
try:
    beam_cube_dde.__doc__ = BEAM_CUBE_DOCS.substitute(
        array_type=":class:`cupy.ndarray`"
    )
except AttributeError:
    # Deliberately best-effort: if the template cannot be substituted or
    # the docstring cannot be set, the function is left undocumented
    # rather than failing the import.
    pass