# -*- coding: utf-8 -*-
from functools import reduce
import logging
from operator import mul
from os.path import join as pjoin
import numpy as np
from africanus.model.coherency.conversion import (_element_indices_and_shape,
CONVERT_DOCS,
MissingConversionInputs)
from africanus.util.code import memoize_on_key, format_code
from africanus.util.cuda import cuda_type, grids
from africanus.util.jinja2 import jinja_env
from africanus.util.requirements import requires_optional
try:
import cupy as cp
from cupy.cuda.compiler import CompileException
except ImportError as e:
opt_import_error = e
else:
opt_import_error = None
log = logging.getLogger(__name__)

# Conversion table, keyed by the output element to be produced.
#
# Each entry maps a pair of required input elements to a
# ``(kind, template)`` tuple, where ``kind`` is "real" or "complex" and
# ``template`` is a Jinja2 expression into which the two input references
# (and ``out_type`` for the complex constructors) are substituted.
# Insertion order matters: consumers iterate the alternatives in order and
# take the first pair whose inputs are both available.
# NOTE(review): ``.x`` / ``.y`` presumably address the real / imaginary
# members of the CUDA complex type -- confirm against the kernel template.
stokes_conv = {
    # Circular correlations from Stokes parameters
    'RR': {
        ('I', 'V'): ("complex", "make_{{out_type}}({{I}} + {{V}}, 0)"),
    },
    'RL': {
        ('Q', 'U'): ("complex", "make_{{out_type}}({{Q}}, {{U}})"),
    },
    'LR': {
        ('Q', 'U'): ("complex", "make_{{out_type}}({{Q}}, -{{U}})"),
    },
    'LL': {
        ('I', 'V'): ("complex", "make_{{out_type}}({{I}} - {{V}}, 0)"),
    },

    # Linear correlations from Stokes parameters
    'XX': {
        ('I', 'Q'): ("complex", "make_{{out_type}}({{I}} + {{Q}}, 0)"),
    },
    'XY': {
        ('U', 'V'): ("complex", "make_{{out_type}}({{U}}, {{V}})"),
    },
    'YX': {
        ('U', 'V'): ("complex", "make_{{out_type}}({{U}}, -{{V}})"),
    },
    'YY': {
        ('I', 'Q'): ("complex", "make_{{out_type}}({{I}} - {{Q}}, 0)"),
    },

    # Stokes parameters from either linear or circular correlations
    'I': {
        ('XX', 'YY'): ("real", "(({{XX}}.x + {{YY}}.x) / 2)"),
        ('RR', 'LL'): ("real", "(({{RR}}.x + {{LL}}.x) / 2)"),
    },
    'Q': {
        ('XX', 'YY'): ("real", "(({{XX}}.x - {{YY}}.x) / 2)"),
        ('RL', 'LR'): ("real", "(({{RL}}.x + {{LR}}.x) / 2)"),
    },
    'U': {
        ('XY', 'YX'): ("real", "(({{XY}}.x + {{YX}}.x) / 2)"),
        ('RL', 'LR'): ("real", "(({{RL}}.y - {{LR}}.y) / 2)"),
    },
    'V': {
        ('XY', 'YX'): ("real", "(({{XY}}.y - {{YX}}.y) / 2)"),
        ('RR', 'LL'): ("real", "(({{RR}}.x - {{LL}}.x) / 2)"),
    },
}
def stokes_convert_setup(input, input_schema, output_schema):
    """
    Work out how every output schema element is produced from the inputs.

    Parameters
    ----------
    input : array-like
        Only ``input.shape`` is read; its trailing dimensions must equal
        the shape derived from ``input_schema``.
    input_schema : schema
        Passed to ``_element_indices_and_shape`` to obtain the index of
        each available input element and the schema shape.
    output_schema : schema
        Passed to ``_element_indices_and_shape`` to obtain the index of
        each requested output element and the schema shape.

    Returns
    -------
    mapping : list
        One ``((c1, c1_idx), (c2, c2_idx), out_idx, fn)`` tuple per output
        element, where ``fn`` is the expression template taken from
        ``stokes_conv``.
    input_shape
        Shape derived from ``input_schema``.
    output_shape
        Shape derived from ``output_schema``.
    dtype : str
        The kind shared by all selected conversions, as recorded in
        ``stokes_conv`` ("real" or "complex").

    Raises
    ------
    ValueError
        If the trailing input dimensions don't match the input schema,
        an output element is unknown, the output schema yields no
        elements, or the outputs mix "real" and "complex" kinds.
    MissingConversionInputs
        If no combination of the supplied inputs can produce an output.
    """
    input_indices, input_shape = _element_indices_and_shape(input_schema)
    output_indices, output_shape = _element_indices_and_shape(output_schema)

    if input.shape[-len(input_shape):] != input_shape:
        raise ValueError("Last dimension of input doesn't match input schema")

    mapping = []
    dtypes = []

    # Figure out how to produce an output from available inputs
    for okey, out_idx in output_indices.items():
        # Membership test rather than try/except, so the ValueError is not
        # raised with a KeyError attached as implicit context.
        if okey not in stokes_conv:
            raise ValueError("Unknown output '%s'. Known types '%s'"
                             % (okey, stokes_conv.keys()))

        deps = stokes_conv[okey]

        # Take the first conversion for which both inputs are available
        for (c1, c2), (dtype, fn) in deps.items():
            if c1 in input_indices and c2 in input_indices:
                dtypes.append(dtype)
                mapping.append(((c1, input_indices[c1]),
                                (c2, input_indices[c2]),
                                out_idx, fn))
                break
        else:
            # Loop finished without a break: nothing can produce okey
            raise MissingConversionInputs("None of the supplied inputs '%s' "
                                          "can produce output '%s'. It can be "
                                          "produced by the following "
                                          "combinations '%s'." % (
                                              input_schema,
                                              okey, deps.keys()))

    # Without any outputs there is no data type to report. Fail with a clear
    # message instead of an IndexError from dtypes[0] below.
    if not dtypes:
        raise ValueError("Output schema '%s' contains no elements"
                         % (output_schema,))

    # Output types must be all "real" or all "complex"
    if any(dt != dtypes[0] for dt in dtypes[1:]):
        raise ValueError("Output data types differ %s" % dtypes)

    return mapping, input_shape, output_shape, dtypes[0]
def schema_to_tuple(schema):
    """
    Recursively convert nested lists/tuples into nested tuples.

    Anything that is not a ``tuple`` or ``list`` is returned unchanged.
    """
    # Leaf value: nothing to convert
    if not isinstance(schema, (tuple, list)):
        return schema

    return tuple(map(schema_to_tuple, schema))
def _key_fn(inputs, input_schema, output_schema):
    """
    Build the key handed to ``memoize_on_key`` for ``_generate_kernel``.

    The key is the input dtype together with both schemas converted to
    nested tuples. The shape of ``inputs`` is not part of the key.
    """
    dtype = inputs.dtype
    in_schema = schema_to_tuple(input_schema)
    out_schema = schema_to_tuple(output_schema)

    return (dtype, in_schema, out_schema)
# Template path handed to ``jinja_env.get_template`` when
# ``_generate_kernel`` renders the kernel source.
_TEMPLATE_PATH = pjoin("model", "coherency", "cuda", "conversion.cu.j2")
@memoize_on_key(_key_fn)
def _generate_kernel(inputs, input_schema, output_schema):
    """
    Render the conversion kernel source and wrap it in a ``cp.RawKernel``.

    Parameters
    ----------
    inputs : array
        Supplies the input dtype; also forwarded to
        ``stokes_convert_setup``.
    input_schema, output_schema : schema
        Forwarded to ``stokes_convert_setup``.

    Returns
    -------
    tuple
        ``(kernel, block, in_shape, out_shape, out_dtype)`` where
        ``kernel`` is the ``cp.RawKernel``, ``block`` is the thread block
        shape, and ``out_dtype`` is the inferred output data type.

    Raises
    ------
    ValueError
        If the two schemas hold a different number of elements, or the
        setup reports a kind other than "real" or "complex".
    """
    setup = stokes_convert_setup(inputs, input_schema, output_schema)
    mapping, in_shape, out_shape, out_kind = setup

    # Element counts of the flattened schemas; the kernel template takes
    # a single ``elements`` value, so the two counts must agree
    in_elems = reduce(mul, in_shape, 1)
    out_elems = reduce(mul, out_shape, 1)

    if in_elems != out_elems:
        raise ValueError("Number of input_schema elements %s "
                         "and output schema elements %s "
                         "must match for CUDA kernel." %
                         (in_shape, out_shape))

    if out_kind not in ("real", "complex"):
        raise ValueError("Invalid setup dtype %s" % out_kind)

    # Turn the "real"/"complex" kind into a concrete output data type
    complex_input = np.iscomplexobj(inputs)

    if out_kind == "real":
        out_dtype = inputs.real.dtype if complex_input else inputs.dtype
    elif complex_input:
        out_dtype = inputs.dtype
    else:
        out_dtype = np.result_type(inputs.dtype, np.complex64)

    cuda_out_dtype = cuda_type(out_dtype)

    def input_ref(index):
        # Reference to an input element by its flattened schema index
        return "in[%d]" % np.ravel_multi_index(index, in_shape)

    # One "out[i] = <expr>;" statement per output element
    assign_exprs = []

    for (c1, c1_idx), (c2, c2_idx), out_idx, expr_template in mapping:
        flat_out = np.ravel_multi_index(out_idx, out_shape)
        substitutions = {c1: input_ref(c1_idx),
                         c2: input_ref(c2_idx),
                         "out_type": cuda_out_dtype}
        expr = jinja_env.from_string(expr_template).render(**substitutions)
        assign_exprs.append("out[%d] = %s;" % (flat_out, expr))

    # Render the main kernel template around the assignments
    kernel_name = "stokes_convert"
    template = jinja_env.get_template(_TEMPLATE_PATH)
    code = template.render(kernel_name=kernel_name,
                           input_type=cuda_type(inputs.dtype),
                           output_type=cuda_type(out_dtype),
                           assign_exprs=assign_exprs,
                           elements=in_elems)

    # Thread block of 512 along x only
    block = (512, 1, 1)

    return (cp.RawKernel(code, kernel_name), block,
            in_shape, out_shape, out_dtype)
# Convert ``inputs`` from ``input_schema`` to ``output_schema`` on the GPU.
#
# Returns an array whose leading (non-schema) dimensions are those of
# ``inputs`` and whose trailing dimensions are the output schema shape.
# Raises ValueError if the trailing dimensions of ``inputs`` don't match
# the input schema, and re-raises CompileException after logging the
# kernel source. Documented with comments rather than a docstring because
# ``convert.__doc__`` is assigned from CONVERT_DOCS after this definition.
@requires_optional('cupy', opt_import_error)
def convert(inputs, input_schema, output_schema):
    (kernel, block,
     in_shape, out_shape, dtype) = _generate_kernel(inputs,
                                                    input_schema,
                                                    output_schema)

    # _generate_kernel is memoized on a key built from the dtype and the
    # schemas only, so its shape validation presumably doesn't run again
    # when a cached kernel is returned. Re-check here so a mis-shaped
    # input can never be silently reshaped into the wrong layout.
    if inputs.shape[-len(in_shape):] != in_shape:
        raise ValueError("Last dimension of input doesn't match input schema")

    # Flatten non-schema input dimensions into a single source dimension
    lead_shape = inputs.shape[:-len(in_shape)]
    nsrc = reduce(mul, lead_shape, 1)
    nelems = reduce(mul, in_shape, 1)
    rinputs = inputs.reshape(nsrc, nelems)

    # reshape can hand back a strided view (e.g. for a sliced input).
    # The raw kernel is given the bare buffer, so make a contiguous copy
    # instead of asserting, which would reject valid inputs and is
    # stripped entirely under ``python -O``.
    if not rinputs.flags.c_contiguous:
        rinputs = cp.ascontiguousarray(rinputs)

    grid = grids((nsrc, 1, 1), block)
    # Freshly allocated with the default C order
    outputs = cp.empty(shape=rinputs.shape, dtype=dtype)

    try:
        kernel(grid, block, (rinputs, outputs))
    except CompileException:
        # Log the generated source to help diagnose the compile failure
        log.exception(format_code(kernel.code))
        raise

    return outputs.reshape(lead_shape + out_shape)
# Fill the shared CONVERT_DOCS template with the array type used by this
# implementation and install the result as the docstring of ``convert``.
try:
    convert.__doc__ = CONVERT_DOCS.substitute(
        array_type=":class:`cupy.ndarray`")
except AttributeError:
    # Best effort: leave ``convert`` undocumented rather than fail import.
    # NOTE(review): presumably covers CONVERT_DOCS lacking ``substitute``
    # or ``__doc__`` not being assignable -- confirm which case applies.
    pass