# -*- coding: utf-8 -*-
from numba import types
import numpy as np
from africanus.util.numba import overload, JIT_OPTIONS, njit
from africanus.util.docs import DocstringTemplate
def numpy_spectral_model(stokes, spi, ref_freq, frequency, base):
    """
    Pure NumPy reference implementation of the spectral model.

    Parameters
    ----------
    stokes : :class:`numpy.ndarray`
        Stokes parameters of shape ``(source,)`` or ``(source, pol)``.
    spi : :class:`numpy.ndarray`
        Spectral indices of shape ``(source, spi-comps)`` or
        ``(source, spi-comps, pol)``.
    ref_freq : :class:`numpy.ndarray`
        Reference frequencies of shape ``(source,)``.
    frequency : :class:`numpy.ndarray`
        Frequencies of shape ``(chan,)``.
    base : {"std", "log", "log10"} or {0, 1, 2} or list
        Polynomial base. A list specifies one base per polarisation: a
        short list is padded with its last entry and surplus entries
        are ignored.

    Returns
    -------
    :class:`numpy.ndarray`
        Spectral model of shape ``(source, chan) + stokes.shape[1:]``,
        with the common dtype of all array inputs.

    Raises
    ------
    ValueError
        If a base entry is not one of the recognised values.
    """
    out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:]

    # Use the common dtype of all inputs, as the numba implementation does.
    # Using stokes.dtype alone truncates results for integer stokes and
    # makes the in-place multiply of the "std" branch raise a casting error.
    dtype = np.result_type(
        stokes.dtype, spi.dtype, ref_freq.dtype, frequency.dtype
    )

    # Add in missing pol dimensions
    if stokes.ndim == 1:
        stokes = stokes[:, None]

    if spi.ndim == 2:
        spi = spi[:, :, None]

    npol = spi.shape[2]

    # Promote base to exactly one entry per polarisation. Pad a short list
    # with its last entry. Drop surplus entries, which would otherwise
    # index past the pol axis of spi.
    if isinstance(base, list):
        base = (base + [base[-1]] * (npol - len(base)))[:npol]
    else:
        base = [base] * npol

    # Polynomial exponents 1..nspi for the log/log10 bases
    spi_exps = np.arange(1, spi.shape[1] + 1)

    spectral_model = np.empty(
        (stokes.shape[0], frequency.shape[0], npol), dtype=dtype
    )
    spectral_model[:, :, :] = stokes[:, None, :]

    for p, b in enumerate(base):
        if b in ("std", 0):
            freq_ratio = frequency[None, :] / ref_freq[:, None]
            term = freq_ratio[:, None, :] ** spi[:, :, p, None]
            spectral_model[:, :, p] *= term.prod(axis=1)
        elif b in ("log", 1):
            freq_ratio = np.log(frequency[None, :] / ref_freq[:, None])
            term = freq_ratio[:, None, :] ** spi_exps[None, :, None]
            term = spi[:, :, p, None] * term
            spectral_model[:, :, p] = stokes[:, p, None] * np.exp(
                term.sum(axis=1)
            )
        elif b in ("log10", 2):
            freq_ratio = np.log10(frequency[None, :] / ref_freq[:, None])
            term = freq_ratio[:, None, :] ** spi_exps[None, :, None]
            term = spi[:, :, p, None] * term
            spectral_model[:, :, p] = stokes[:, p, None] * 10 ** (
                term.sum(axis=1)
            )
        else:
            # Report the offending entry, not the whole promoted list
            raise ValueError("Invalid base %s" % (b,))

    return spectral_model.reshape(out_shape)
def pol_getter_factory(npoldims):
    """
    Build a jitted function that counts polarisations from a shape tuple.

    Parameters
    ----------
    npoldims : int
        Number of polarisation dimensions on the input arrays.

    Returns
    -------
    callable
        ``impl(pol_shape)``, which returns ``1`` when ``npoldims`` is zero
        and otherwise the product of the extents in ``pol_shape``.
    """
    if npoldims != 0:
        def impl(pol_shape):
            total = 1

            for extent in pol_shape:
                total = total * extent

            return total
    else:
        def impl(pol_shape):
            return 1

    return njit(nogil=True, cache=True)(impl)
def promote_base_factory(is_base_list):
    """
    Build a jitted function expanding ``base`` into a per-polarisation list.

    Parameters
    ----------
    is_base_list : bool
        Whether the ``base`` argument is a list (as opposed to a scalar).

    Returns
    -------
    callable
        ``impl(base, npol)``. For a scalar base it returns ``npol`` copies
        of ``base``. For a list it appends copies of the final entry until
        the list has ``npol`` entries.
    """
    if not is_base_list:
        def impl(base, npol):
            return [base] * npol
    else:
        def impl(base, npol):
            return base + [base[-1]] * (npol - len(base))

    return njit(nogil=True, cache=True)(impl)
def add_pol_dim_factory(have_pol_dim):
    """
    Build a jitted function that guarantees a trailing polarisation axis.

    Parameters
    ----------
    have_pol_dim : bool
        Whether the input arrays already carry a polarisation dimension.

    Returns
    -------
    callable
        ``impl(array)``. It returns ``array`` untouched when a pol
        dimension is present, and otherwise a reshape of ``array`` with a
        trailing axis of length 1.
    """
    if not have_pol_dim:
        def impl(array):
            return array.reshape(array.shape + (1,))
    else:
        def impl(array):
            return array

    return njit(nogil=True, cache=True)(impl)
# Public jitted entry point. The argument-type dispatch lives in the numba
# overload registered for spectral_model_impl. The user-facing docstring is
# assigned to ``spectral_model.__doc__`` once SPECTRAL_MODEL_DOC is defined.
@njit(**JIT_OPTIONS)
def spectral_model(stokes, spi, ref_freq, frequency, base=0):
    return spectral_model_impl(stokes, spi, ref_freq, frequency, base=base)
def spectral_model_impl(stokes, spi, ref_freq, frequency, base=0):
    """
    Placeholder sharing the signature of ``spectral_model``.

    The working implementation is provided by the numba ``overload``
    registered against this function. Calling it directly from Python
    always raises :class:`NotImplementedError`.
    """
    raise NotImplementedError
@overload(spectral_model_impl, jit_options=JIT_OPTIONS)
def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0):
    """
    Numba overload supplying the implementation of ``spectral_model_impl``.

    The arguments received here are numba *types* (note the ``types.*``
    isinstance checks and ``.dtype`` accesses), not runtime values. This
    function picks type-specialised helpers and returns ``impl``, which is
    the code compiled for those argument types.

    Raises
    ------
    TypeError
        If ``base`` (or the element type of a ``base`` list) is neither an
        integer nor a unicode string type.
    ValueError
        If the number of polarisation dimensions on ``stokes`` and ``spi``
        disagree. At runtime, ``impl`` also raises ValueError if the
        polarisation counts disagree or a base value is unrecognised.
    """
    # Output dtype is the common NumPy dtype of all array arguments
    arg_dtypes = tuple(
        np.dtype(a.dtype.name) for a in (stokes, spi, ref_freq, frequency)
    )
    dtype = np.result_type(*arg_dtypes)

    # For a list of bases, the checks below operate on the element type
    if isinstance(base, types.containers.List):
        is_base_list = True
        base = base.dtype
    else:
        is_base_list = False

    promote_base = promote_base_factory(is_base_list)

    # Select comparison predicates matching the base (element) type:
    # integer enumerations 0/1/2 or the strings "std"/"log"/"log10"
    if isinstance(base, types.scalars.Integer):

        def is_std(base):
            return base == 0

        def is_log(base):
            return base == 1

        def is_log10(base):
            return base == 2

    elif isinstance(base, types.misc.UnicodeType):

        def is_std(base):
            return base == "std"

        def is_log(base):
            return base == "log"

        def is_log10(base):
            return base == "log10"

    else:
        raise TypeError("base '%s' should be a string or integer" % base)

    # Compile the selected predicates so impl can call them
    is_std = njit(nogil=True, cache=True)(is_std)
    is_log = njit(nogil=True, cache=True)(is_log)
    is_log10 = njit(nogil=True, cache=True)(is_log10)

    # Number of dimensions on stokes after the leading source axis
    npoldims = stokes.ndim - 1
    pol_get_fn = pol_getter_factory(npoldims)
    add_pol_dim = add_pol_dim_factory(npoldims > 0)

    # spi carries one extra (spi-comps) axis relative to stokes
    if spi.ndim - 2 != npoldims:
        raise ValueError("Dimensions on stokes and spi don't agree")

    def impl(stokes, spi, ref_freq, frequency, base=0):
        nsrc = stokes.shape[0]
        nchan = frequency.shape[0]
        nspi = spi.shape[1]
        # Total polarisation count: 1, or the product of the pol dimensions
        npol = pol_get_fn(stokes.shape[1:])

        if npol != pol_get_fn(spi.shape[2:]):
            raise ValueError("Correlations on stokes and spi don't agree")

        # Promote base argument to a per-polarisation list
        list_base = promote_base(base, npol)

        # Reshape adding a polarisation dimension if necessary
        estokes = add_pol_dim(stokes)
        espi = add_pol_dim(spi)

        spectral_model = np.empty((nsrc, nchan, npol), dtype=dtype)

        # TODO(sjperkins)
        # Polarisation + associated base on the outer loop
        # The output cache patterns could be improved.
        # Slicing to npol ignores any surplus list entries
        for p, b in enumerate(list_base[:npol]):
            if is_std(b):
                # I = I0 * prod_si (f / f0) ** spi[si]
                for s in range(nsrc):
                    rf = ref_freq[s]

                    for f in range(nchan):
                        freq_ratio = frequency[f] / rf
                        spec_model = estokes[s, p]

                        for si in range(0, nspi):
                            term = freq_ratio ** espi[s, si, p]
                            spec_model *= term

                        spectral_model[s, f, p] = spec_model

            elif is_log(b):
                # I = I0 * exp(sum_si spi[si] * ln(f / f0) ** (si + 1))
                for s in range(nsrc):
                    rf = ref_freq[s]

                    for f in range(nchan):
                        freq_ratio = np.log(frequency[f] / rf)
                        spec_model = 0

                        for si in range(0, nspi):
                            term = espi[s, si, p] * freq_ratio ** (si + 1)
                            spec_model += term

                        spectral_model[s, f, p] = estokes[s, p] * np.exp(spec_model)

            elif is_log10(b):
                # I = I0 * 10 ** (sum_si spi[si] * log10(f / f0) ** (si + 1))
                for s in range(nsrc):
                    rf = ref_freq[s]

                    for f in range(nchan):
                        freq_ratio = np.log10(frequency[f] / rf)
                        spec_model = 0

                        for si in range(0, nspi):
                            term = espi[s, si, p] * freq_ratio ** (si + 1)
                            spec_model += term

                        spectral_model[s, f, p] = estokes[s, p] * 10**spec_model
            else:
                raise ValueError("Invalid base")

        # Restore the caller's polarisation layout (or drop the added axis)
        out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:]
        return spectral_model.reshape(out_shape)

    return impl
SPECTRAL_MODEL_DOC = DocstringTemplate(
    r"""
    Compute a spectral model, per polarisation.

    .. math::
        :nowrap:

        \begin{eqnarray}
        I(\nu) & = & I_0 \prod_{i=1} (\nu / \nu_0)^{\alpha_{i}} \\
        \ln( I(\nu) ) & = & \sum_{i=0} \alpha_{i}
            \left[ \ln (\nu / \nu_0) \right]^i
            \, \textrm{where} \, \alpha_0 = \ln I_0 \\
        \log_{10}( I(\nu) ) & = & \sum_{i=0} \alpha_{i}
            \left[ \log_{10} (\nu / \nu_0) \right]^i
            \, \textrm{where} \, \alpha_0 = \log_{10} I_0
        \end{eqnarray}

    Here :math:`\nu` is the channel frequency, :math:`\nu_0` is the
    source's reference frequency, and :math:`\alpha_i` for :math:`i \geq 1`
    is the :math:`i`-th component along the ``spi-comps`` axis of ``spi``.

    Parameters
    ----------
    stokes : $(array_type)
        Stokes parameters of shape :code:`(source,)` or :code:`(source, pol)`.
        If a ``pol`` dimension is present, then it must also be present
        on ``spi``.
    spi : $(array_type)
        Spectral index of shape :code:`(source, spi-comps)`
        or :code:`(source, spi-comps, pol)`.
    ref_freq : $(array_type)
        Reference frequencies of shape :code:`(source,)`.
    frequency : $(array_type)
        Frequencies of shape :code:`(chan,)`.
    base : {"std", "log", "log10"} or {0, 1, 2} or list.
        String or corresponding enumeration specifying the polynomial base.
        Defaults to 0.

        If a list is provided, a polynomial base can be specified for each
        stokes parameter or polarisation in the ``pol`` dimension.
        A list with fewer entries than there are polarisations is padded
        with its last entry, and surplus entries are ignored.

        String specification of the base is only supported in Python 3,
        while the corresponding integer enumerations are supported
        on all Python versions.

    Returns
    -------
    spectral_model : $(array_type)
        Spectral Model of shape :code:`(source, chan)` or
        :code:`(source, chan, pol)`.
    """
)

# Attaching the docstring is best-effort: if the jitted dispatcher rejects
# assignment to ``__doc__`` with AttributeError, leave it unset rather than
# failing at import time.
try:
    spectral_model.__doc__ = SPECTRAL_MODEL_DOC.substitute(
        array_type=":class:`numpy.ndarray`"
    )
except AttributeError:
    pass