# -*- coding: utf-8 -*-
from functools import reduce
import logging
from operator import mul
from pathlib import Path
import numpy as np
from africanus.rime.fast_beam_cubes import BEAM_CUBE_DOCS
from africanus.util.code import format_code, memoize_on_key
from africanus.util.cuda import cuda_function, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional
try:
import cupy as cp
from cupy._core._scalar import get_typename as _get_typename
from cupy.cuda.compiler import CompileException
except ImportError as e:
opt_import_error = e
else:
opt_import_error = None
log = logging.getLogger(__name__)
_MAIN_TEMPLATE_PATH = Path("rime", "cuda", "beam.cu.j2")
_INTERP_TEMPLATE_PATH = Path("rime", "cuda", "beam_freq_interp.cu.j2")
BEAM_NUD_LIMIT = 128
def _freq_interp_key(beam_freq_map, frequencies):
return (beam_freq_map.dtype, frequencies.dtype)
@memoize_on_key(_freq_interp_key)
def _generate_interp_kernel(beam_freq_map, frequencies):
    """Render and build the beam cube frequency interpolation kernel.

    Cached through ``memoize_on_key`` with ``_freq_interp_key``, i.e. on
    the dtypes of ``beam_freq_map`` and ``frequencies``.

    Returns
    -------
    tuple
        ``(kernel, block, dtype)``:

        - ``kernel`` is a :class:`cupy.RawKernel` named
          ``beam_cube_freq_interp``.
        - ``block`` is the ``(1024, 1, 1)`` thread block the template was
          rendered for.
        - ``dtype`` is ``np.result_type(beam_freq_map, frequencies)``.
    """
    template = jinja_env.get_template(str(_INTERP_TEMPLATE_PATH))
    kernel_name = "beam_cube_freq_interp"
    threads = (1024, 1, 1)

    source = template.render(
        kernel_name=kernel_name,
        beam_nud_limit=BEAM_NUD_LIMIT,
        blockdimx=threads[0],
        beam_freq_type=_get_typename(beam_freq_map.dtype),
        freq_type=_get_typename(frequencies.dtype),
    )

    result_dtype = np.result_type(beam_freq_map, frequencies)
    return cp.RawKernel(source, kernel_name), threads, result_dtype
def _main_key_fn(beam, beam_lm_ext, beam_freq_map,
lm, parangles, pointing_errors,
antenna_scaling, frequencies,
dde_dims, ncorr):
return (beam.dtype, beam.ndim, beam_lm_ext.dtype, beam_freq_map.dtype,
lm.dtype, parangles.dtype, pointing_errors.dtype,
antenna_scaling.dtype, frequencies.dtype, dde_dims, ncorr)
# Value to use in a bit shift to recover channel from flattened
# channel/correlation index
_corr_shifter = {4: 2, 2: 1, 1: 0}
@memoize_on_key(_main_key_fn)
def _generate_main_kernel(beam, beam_lm_ext, beam_freq_map,
                          lm, parangles, pointing_errors,
                          antenna_scaling, frequencies,
                          dde_dims, ncorr):
    """Render and build the main beam cube DDE CUDA kernel.

    Results are cached through ``memoize_on_key(_main_key_fn)``. The key
    is built from the input dtypes, ``beam.ndim``, ``dde_dims`` and
    ``ncorr``.

    Parameters
    ----------
    beam : array
        Complex beam cube whose first three dimensions are
        ``(beam_lw, beam_mh, beam_nud)``. ``beam_cube_dde`` flattens any
        correlation dimensions into a single trailing axis before calling.
    beam_lm_ext, beam_freq_map, lm, parangles, pointing_errors, \
antenna_scaling, frequencies : array
        Only their dtypes are used, to select the types rendered into the
        kernel template.
    dde_dims : int
        Number of output dimensions, forwarded to the template.
    ncorr : int
        Number of (flattened) correlations. Must be a key of
        ``_corr_shifter`` (4, 2 or 1).

    Returns
    -------
    tuple
        ``(kernel, block, dtype)``:

        - ``kernel`` is a :class:`cupy.RawKernel`.
        - ``block`` is the thread block shape the template was rendered for.
        - ``dtype`` is ``beam.dtype``, the complex output type.

    Raises
    ------
    ValueError
        If ``beam_lw``, ``beam_mh`` or ``beam_nud`` is less than 2, or if
        ``ncorr`` is not supported.
    TypeError
        If ``beam`` is not complex64/complex128, or if the coordinate
        inputs do not promote to float32/float64.
    """
    beam_lw, beam_mh, beam_nud = beam.shape[:3]

    # Linear interpolation needs at least two samples along each axis
    if beam_lw < 2 or beam_mh < 2 or beam_nud < 2:
        raise ValueError("(beam_lw, beam_mh, beam_nud) < 2 "
                         "to linearly interpolate")

    # Create template
    render = jinja_env.get_template(str(_MAIN_TEMPLATE_PATH)).render
    name = "beam_cube_dde"
    dtype = beam.dtype

    # The thread block shape depends on beam precision. Double precision
    # halves the y extent, presumably to bound per-block resource usage.
    # NOTE(review): confirm against the kernel template.
    if dtype == np.complex64:
        block = (32, 32, 1)
    elif dtype == np.complex128:
        block = (32, 16, 1)
    else:
        raise TypeError("Need complex beam cube '%s'" % beam.dtype)

    try:
        corr_shift = _corr_shifter[ncorr]
    except KeyError:
        # The KeyError itself carries no useful context for the caller
        raise ValueError("Number of Correlations not in %s"
                         % list(_corr_shifter.keys())) from None

    # Promote the coordinate inputs together, with at least single precision
    coord_type = np.result_type(beam_lm_ext, lm, parangles,
                                pointing_errors, antenna_scaling,
                                np.float32)

    # Check explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would let an unsupported (e.g. complex)
    # coordinate type reach the code generator.
    if coord_type not in (np.float32, np.float64):
        raise TypeError("Coordinate inputs must promote to float32 or "
                        "float64, got '%s'" % coord_type)

    code = render(kernel_name=name,
                  blockdimx=block[0],
                  blockdimy=block[1],
                  blockdimz=block[2],
                  corr_shift=corr_shift,
                  ncorr=ncorr,
                  beam_nud_limit=BEAM_NUD_LIMIT,
                  # Beam type and manipulation functions
                  beam_type=_get_typename(beam.real.dtype),
                  beam_dims=beam.ndim,
                  make2_beam_fn=cuda_function('make2', beam.real.dtype),
                  beam_sqrt_fn=cuda_function('sqrt', beam.real.dtype),
                  beam_rsqrt_fn=cuda_function('rsqrt', beam.real.dtype),
                  # Coordinate type and manipulation functions
                  FT=_get_typename(coord_type),
                  floor_fn=cuda_function('floor', coord_type),
                  min_fn=cuda_function('min', coord_type),
                  max_fn=cuda_function('max', coord_type),
                  cos_fn=cuda_function('cos', coord_type),
                  sin_fn=cuda_function('sin', coord_type),
                  lm_ext_type=_get_typename(beam_lm_ext.dtype),
                  beam_freq_type=_get_typename(beam_freq_map.dtype),
                  lm_type=_get_typename(lm.dtype),
                  pa_type=_get_typename(parangles.dtype),
                  pe_type=_get_typename(pointing_errors.dtype),
                  as_type=_get_typename(antenna_scaling.dtype),
                  freq_type=_get_typename(frequencies.dtype),
                  dde_type=_get_typename(beam.real.dtype),
                  dde_dims=dde_dims)

    # Complex output type
    return cp.RawKernel(code, name), block, dtype
@requires_optional('cupy', opt_import_error)
def beam_cube_dde(beam, beam_lm_ext, beam_freq_map,
                  lm, parangles, pointing_errors,
                  antenna_scaling, frequencies):
    """Evaluate the beam cube direction-dependent effect on the GPU.

    The full parameter documentation is substituted from
    ``BEAM_CUBE_DOCS`` at import time when available. This text is only a
    fallback.

    The output has shape ``(nsrc, ntime, na, nchan) + beam.shape[3:]``,
    where:

    - ``nsrc = lm.shape[0]``
    - ``(ntime, na) = parangles.shape``
    - ``nchan = frequencies.shape[0]``

    Raises
    ------
    ValueError
        If ``beam.shape[2]`` (beam_nud) is not below ``BEAM_NUD_LIMIT``, or
        if the kernel generators reject the inputs.
    CompileException
        Re-raised after logging the generated source if a kernel fails to
        compile.
    """
    corrs = beam.shape[3:]

    # BEAM_NUD_LIMIT is rendered into both kernels, so beam_nud must stay
    # strictly below it.
    if beam.shape[2] >= BEAM_NUD_LIMIT:
        raise ValueError("beam_nud (%d) must be less than %d"
                         % (beam.shape[2], BEAM_NUD_LIMIT))

    nsrc = lm.shape[0]
    ntime, na = parangles.shape
    nchan = frequencies.shape[0]
    ncorr = reduce(mul, corrs, 1)
    nchancorr = nchan * ncorr
    oshape = (nsrc, ntime, na, nchan) + corrs

    # The kernel sees a single flattened correlation axis
    if len(corrs) > 1:
        fbeam = beam.reshape(beam.shape[:3] + (ncorr,))
    else:
        fbeam = beam

    # Generate frequency interpolation kernel (its result dtype is not needed)
    ikernel, iblock, _ = _generate_interp_kernel(beam_freq_map, frequencies)

    # Generate main beam cube kernel; dtype is the complex output type
    kernel, block, dtype = _generate_main_kernel(fbeam, beam_lm_ext,
                                                 beam_freq_map,
                                                 lm, parangles,
                                                 pointing_errors,
                                                 antenna_scaling,
                                                 frequencies,
                                                 len(oshape),
                                                 ncorr)

    # Call frequency interpolation kernel over the nchan channels.
    # NOTE(review): freq_data (3 x nchan) is presumably filled here and read
    # by the main kernel; verify against the templates.
    igrid = grids((nchan, 1, 1), iblock)
    freq_data = cp.empty((3, nchan), dtype=frequencies.dtype)

    try:
        ikernel(igrid, iblock, (frequencies, beam_freq_map, freq_data))
    except CompileException:
        log.exception(format_code(ikernel.code))
        raise

    # Call main beam cube kernel with a flattened correlation axis; the
    # grid covers (nchan * ncorr, na, ntime).
    out = cp.empty((nsrc, ntime, na, nchan, ncorr), dtype=dtype)
    grid = grids((nchancorr, na, ntime), block)

    try:
        kernel(grid, block, (fbeam, beam_lm_ext, beam_freq_map,
                             lm, parangles, pointing_errors,
                             antenna_scaling, frequencies, freq_data,
                             nsrc, out))
    except CompileException:
        log.exception(format_code(kernel.code))
        raise

    # Restore the original correlation dimensions
    return out.reshape(oshape)
# Attach the shared beam cube documentation, specialised for CuPy arrays.
# An AttributeError (e.g. BEAM_CUBE_DOCS lacking ``substitute``) is
# deliberately ignored so the module still imports.
try:
    _cupy_beam_doc = BEAM_CUBE_DOCS.substitute(
        array_type=":class:`cupy.ndarray`",
    )
    beam_cube_dde.__doc__ = _cupy_beam_doc
except AttributeError:
    pass
else:
    del _cupy_beam_doc