# Source code for africanus.experimental.rime.fused.terms.cube_dde

from collections import namedtuple

from numba.core import cgutils, types
from numba.extending import intrinsic
from numba.cpython.unsafe.tuple import tuple_setitem
import numpy as np

from africanus.experimental.rime.fused.terms.core import Term


def zero_vis_factory(ncorr):
    """Build a numba intrinsic that creates an ``ncorr``-element tuple.

    The returned intrinsic accepts a single ``value`` and produces a tuple
    of length ``ncorr`` whose every element is that ``value``. The tuple
    length is fixed when the factory is called.
    """

    @intrinsic
    def zero_vis(typingctx, value):
        # Return type: ncorr copies of the argument type; one argument.
        return_type = types.Tuple([value for _ in range(ncorr)])
        signature_ = return_type(value)

        def codegen(context, builder, signature, args):
            fill = args[0]
            llvm_tuple_type = context.get_value_type(signature.return_type)

            # Start from a null aggregate and write the fill value into
            # each of the ncorr slots in turn.
            result = cgutils.get_null_value(llvm_tuple_type)
            for slot in range(ncorr):
                result = builder.insert_value(result, fill, slot)

            return result

        return signature_, codegen

    return zero_vis


# Per-cube constants computed once in BeamCubeDDE.init_fields:
#   lscale, mscale -- factors turning an (l, m) offset from the lower beam
#                     extent into a fractional grid coordinate
#   lmaxi, mmaxi   -- last grid index along l and m, as integers
#   lmaxf, mmaxf   -- the same last grid indices, as floats
BeamInfo = namedtuple("BeamInfo", "lscale mscale lmaxi mmaxi lmaxf mmaxf")


class BeamCubeDDE(Term):
    """Voxel Beam Cube Term

    Samples a beam cube indexed as ``beam[l, m, frequency, correlation]``
    at a source's (l, m) position and a channel's frequency, interpolating
    between the eight surrounding voxels.
    """

    def __init__(self, configuration, corrs):
        """Create the term.

        Parameters
        ----------
        configuration : str
            Either ``"left"`` or ``"right"``. ``"left"`` makes the sampler
            use the first antenna/feed index of a row, ``"right"`` the
            second.
        corrs : sequence
            Correlations; only its length is used by this class.

        Raises
        ------
        ValueError
            If ``configuration`` is neither ``"left"`` nor ``"right"``.
        """
        if configuration not in {"left", "right"}:
            raise ValueError(
                f"BeamCubeDDE configuration must be "
                f"either 'left' or 'right'. "
                f"Got {configuration}"
            )

        super().__init__(configuration)
        self.corrs = corrs

    def dask_schema(
        self,
        beam,
        beam_lm_extents,
        beam_freq_map,
        lm,
        beam_parangle,
        chan_freq,
        beam_point_errors=None,
        beam_antenna_scaling=None,
    ):
        """Return a mapping of argument name to its tuple of dimension names.

        NOTE(review): ``beam_parangle`` and the optional arguments have no
        entry in the returned mapping -- confirm this is intentional.
        """
        return {
            "beam": ("beam_lw", "beam_mh", "beam_nud", "corr"),
            "beam_lm_extents": ("lm_ext", "lm_ext_comp"),
            "beam_freq_map": ("beam_nud",),
            "lm": ("source", "lm"),
            "chan_freq": ("chan",),
        }

    def init_fields(
        self,
        typingctx,
        beam,
        beam_lm_extents,
        beam_freq_map,
        lm,
        beam_parangle,
        chan_freq,
        beam_point_errors=None,
        beam_antenna_scaling=None,
    ):
        """Describe and compute the derived ``beam_freq_data``/``beam_info``.

        The arguments are presumably numba types here (``.dtype`` and
        ``.copy(ndim=2)`` are called on them) -- verify against the caller.

        Returns
        -------
        fields : list of (str, type)
            Names and types of the two derived fields.
        beam : callable
            Function taking the same arguments (minus ``typingctx``) and
            returning ``(freq_data, beam_info)``.
        """
        ncorr = len(self.corrs)
        ex_dtype = beam_lm_extents.dtype
        # NOTE(review): lmaxf/mmaxf are declared float64 here but are
        # produced with ex_dtype below -- confirm the intended precision
        # when the extents are not float64.
        beam_info_types = [ex_dtype] * 2 + [types.int64] * 2 + [types.float64] * 2
        beam_info_type = types.NamedTuple(beam_info_types, BeamInfo)

        fields = [
            ("beam_freq_data", chan_freq.copy(ndim=2)),
            ("beam_info", beam_info_type),
        ]

        def beam(
            beam,
            beam_lm_extents,
            beam_freq_map,
            lm,
            beam_parangle,
            chan_freq,
            beam_point_errors=None,
            beam_antenna_scaling=None,
        ):
            # Computes, per channel, a row of
            #   [lm frequency scale, lower-plane weight, lower-plane index]
            # plus the BeamInfo grid constants.
            if beam.shape[3] != ncorr:
                raise ValueError("Beam correlations don't match specification corrs")

            freq_data = np.empty((chan_freq.shape[0], 3), chan_freq.dtype)
            beam_lw, beam_mh, beam_nud = beam.shape[:3]

            # The search below indexes beam_freq_map with positions taken
            # from the beam's frequency axis, so the two must agree.
            if beam_freq_map.shape[0] != beam_nud:
                raise ValueError(
                    "beam_freq_map length must match the beam_nud dimension of beam"
                )

            if beam_lw < 2 or beam_mh < 2 or beam_nud < 2:
                raise ValueError("beam_lw, beam_mh and beam_nud must be >= 2")

            for f in range(chan_freq.shape[0]):
                freq = chan_freq[f]
                lower = 0
                upper = beam_nud - 1

                # Binary search for freq in beam_freq_map
                # (assumes beam_freq_map is sorted ascending -- TODO confirm)
                while lower <= upper:
                    mid = lower + (upper - lower) // 2
                    beam_freq = beam_freq_map[mid]

                    if beam_freq < freq:
                        lower = mid + 1
                    elif beam_freq > freq:
                        upper = mid - 1
                    else:
                        lower = mid
                        break

                # This handles the lower <= upper in the while loop
                lower = min(lower, upper)
                upper = lower + 1

                # Set up scaling, lower weight, lower grid pos
                if lower == -1:
                    # Below the first map entry: scale lm by the frequency
                    # ratio and take everything from plane 0
                    freq_data[f, 0] = freq / beam_freq_map[0]
                    freq_data[f, 1] = 1.0
                    freq_data[f, 2] = 0
                elif upper == beam_nud:
                    # At or above the last map entry: scale lm and take
                    # everything from the last plane (upper of the pair)
                    freq_data[f, 0] = freq / beam_freq_map[beam_nud - 1]
                    freq_data[f, 1] = 0.0
                    freq_data[f, 2] = beam_nud - 2
                else:
                    # Between two map entries: linear weight, no scaling
                    freq_data[f, 0] = 1.0
                    freq_low = beam_freq_map[lower]
                    freq_high = beam_freq_map[upper]
                    freq_diff = freq_high - freq_low
                    freq_data[f, 1] = (freq_high - freq) / freq_diff
                    freq_data[f, 2] = lower

            # Beam Extents
            lower_l, upper_l = beam_lm_extents[0]
            lower_m, upper_m = beam_lm_extents[1]

            # Maximum l and m indices in float and int
            lmaxf = ex_dtype(beam_lw - types.int64(1))
            mmaxf = ex_dtype(beam_mh - types.int64(1))
            lmaxi = beam_lw - types.int64(1)
            mmaxi = beam_mh - types.int64(1)

            # Grid cells per unit of l and m
            lscale = lmaxf / (upper_l - lower_l)
            mscale = mmaxf / (upper_m - lower_m)

            beam_info = BeamInfo(lscale, mscale, lmaxi, mmaxi, lmaxf, mmaxf)
            return freq_data, beam_info

        return fields, beam

    def sampler(self):
        """Return the function that samples the beam cube.

        The returned ``cube_dde(state, s, r, t, f1, f2, a1, a2, c)`` reads
        its data from ``state`` and returns a tuple with one interpolated
        beam value per correlation.
        """
        left = self.configuration == "left"
        ncorr = len(self.corrs)
        zero_vis = zero_vis_factory(ncorr)

        def cube_dde(state, s, r, t, f1, f2, a1, a2, c):
            a = state.antenna1_index[r] if left else state.antenna2_index[r]
            feed = state.feed1_index[r] if left else state.feed2_index[r]
            sin_pa = state.beam_parangle[t, feed, a, 0]
            cos_pa = state.beam_parangle[t, feed, a, 1]

            l = state.lm[s, 0]  # noqa
            m = state.lm[s, 1]

            # Unpack frequency data
            freq_scale = state.beam_freq_data[c, 0]
            # lower and upper frequency weights
            nud = state.beam_freq_data[c, 1]
            inv_nud = state.beam_freq_data.dtype.type(1.0) - nud
            # lower and upper frequency grid position
            gc0 = np.int32(state.beam_freq_data[c, 2])
            gc1 = gc0 + np.int32(1)

            # Apply any frequency scaling
            sl = l * freq_scale
            sm = m * freq_scale

            # Add pointing errors (not applied yet)
            # tl = sl + point_errors[t, a, c, 0]
            # tm = sm + point_errors[t, a, c, 1]
            tl = sl
            tm = sm

            # Rotate lm coordinate angle
            vl = tl * cos_pa - tm * sin_pa
            vm = tl * sin_pa + tm * cos_pa

            # Scale by antenna scaling (not applied yet)
            # vl *= antenna_scaling[a, f, 0]
            # vm *= antenna_scaling[a, f, 1]

            # Beam Extents
            lower_l, upper_l = state.beam_lm_extents[0]
            lower_m, upper_m = state.beam_lm_extents[1]

            # Shift into the cube coordinate system
            vl = state.beam_info.lscale * (vl - lower_l)
            vm = state.beam_info.mscale * (vm - lower_m)

            # Clamp the coordinates to the edges of the cube
            vl = max(0.0, min(vl, state.beam_info.lmaxf))
            vm = max(0.0, min(vm, state.beam_info.mmaxf))

            # Snap to the lower grid coordinates
            gl0 = np.int32(np.floor(vl))
            gm0 = np.int32(np.floor(vm))

            # Snap to the upper grid coordinates
            gl1 = min(gl0 + np.int32(1), state.beam_info.lmaxi)
            gm1 = min(gm0 + np.int32(1), state.beam_info.mmaxi)

            # Difference between grid and offset coordinates
            ld = vl - gl0
            md = vm - gm0

            # corr_sum accumulates the weighted beam values and absc_sum
            # the weighted magnitudes, one entry per correlation
            corr_sum = zero_vis(state.beam.dtype.type(0))
            absc_sum = zero_vis(state.beam.real.dtype.type(0))

            # Lower cube
            weight = (1.0 - ld) * (1.0 - md) * nud

            for co in range(ncorr):
                value = state.beam[gl0, gm0, gc0, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = ld * (1.0 - md) * nud

            for co in range(ncorr):
                value = state.beam[gl1, gm0, gc0, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = (1.0 - ld) * md * nud

            for co in range(ncorr):
                value = state.beam[gl0, gm1, gc0, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = ld * md * nud

            for co in range(ncorr):
                value = state.beam[gl1, gm1, gc0, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            # Upper cube
            weight = (1.0 - ld) * (1.0 - md) * inv_nud

            for co in range(ncorr):
                value = state.beam[gl0, gm0, gc1, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = ld * (1.0 - md) * inv_nud

            for co in range(ncorr):
                value = state.beam[gl1, gm0, gc1, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = (1.0 - ld) * md * inv_nud

            for co in range(ncorr):
                value = state.beam[gl0, gm1, gc1, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            weight = ld * md * inv_nud

            for co in range(ncorr):
                value = state.beam[gl1, gm1, gc1, co]
                absc_sum = tuple_setitem(
                    absc_sum, co, weight * np.abs(value) + absc_sum[co]
                )
                corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co])

            # Rescale each interpolated value so its magnitude equals the
            # interpolated magnitude; a zero sum is left undivided
            for co in range(ncorr):
                div = np.abs(corr_sum[co])
                value = corr_sum[co] * absc_sum[co]

                if div != 0.0:
                    value /= div

                corr_sum = tuple_setitem(corr_sum, co, value)

            return corr_sum

        return cube_dde