Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional fast sky-based calibration to abscal module #892

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
278 changes: 277 additions & 1 deletion hera_cal/abscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from . import redcal
from . import io
from . import apply_cal
from .datacontainer import DataContainer
from .datacontainer import DataContainer, RedDataContainer
from .utils import echo, polnum2str, polstr2num, reverse_bl, split_pol, split_bl, join_bl, join_pol

PHASE_SLOPE_SOLVERS = ['linfit', 'dft', 'ndim_fft'] # list of valid solvers for global_phase_slope_logcal
Expand Down Expand Up @@ -727,6 +727,282 @@ def delay_lincal(model, data, wgts=None, refant=None, df=9.765625e4, f0=0., solv

return fit

@jax.jit
def _stefcal_optimizer(data_matrix, model_matrix, weights, tol=1e-10, maxiter=1000, stepsize=0.5):
    """
    Function to run stefcal optimization (StEFCal, Salvini & Wijnholds 2014) for per-antenna
    complex gains, jit-compiled with jax.

    Parameters:
    ----------
    data_matrix: np.ndarray
        Data matrix of shape (nants, nants) for a single time and frequency. The two axes are the
        antenna indices of antennas that are not flagged; entry [i, j] holds the visibility for
        antenna pair (i, j), with [j, i] its complex conjugate.
    model_matrix: np.ndarray
        Model matrix of shape (nants, nants), arranged identically to data_matrix.
    weights: np.ndarray
        Real-valued weights matrix of shape (nants, nants), arranged identically to data_matrix.
        Zero-weight entries are excluded from the fit.
    tol: float, optional, default=1e-10
        Tolerance for the convergence criterion between subsequent iterations of stefcal
    maxiter: int, optional, default=1000
        Maximum number of iterations to run the calibration
    stepsize: float, optional, default=0.5
        Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction
        of the new gains, while a step size of 0 will take no step in the direction of the new gains.

    Returns:
    -------
    gains: jnp.ndarray
        Complex antenna gain solutions of shape (nants,)
    niters: int
        Number of iterations performed by the optimizer
    conv_crit: float
        Convergence criterion of the final iteration of the optimizer
    """
    def inner_function(args):
        """
        Single stefcal iteration: re-solve all gains linearly, damping every other step.
        """
        # Unpack the while_loop carry: (gains, iteration counter, convergence criterion)
        gains, i, _ = args

        # Copy gains so the update can be compared against the previous iterate
        g_old = jnp.copy(gains)

        # Compute the model-gain product
        zg = gains[:, None] * model_matrix
        zgw = zg * weights

        # Compute gains: weighted linear least-squares update, reducing over the first antenna axis
        gains = jnp.sum(jnp.conj(data_matrix) * zgw, axis=(0)) / jnp.sum(jnp.conj(zgw) * zg, axis=(0))

        # Set gains to 1 if they are nan (0/0 for antennas with no unflagged data)
        gains = jnp.where(jnp.isnan(gains), 1, gains)

        # Average gains on even iterations and take a full step on odd iterations (ensures convergence)
        step = ((i + 1) % 2) * stepsize
        gains = gains * (1 - step) + g_old * step

        # Compute convergence criterion: relative change of the gain vector between iterations
        tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2)) / jnp.sqrt(jnp.sum(jnp.abs(gains)**2))
        return gains, i + 1, tau

    def conditional_function(args):
        """
        Loop condition: continue while not converged and under the iteration cap.
        """
        _, i, tau = args
        return (tau > tol) & (i < maxiter)

    nants = data_matrix.shape[0]
    # Start from unity gains; initial convergence criterion is inf so the loop runs at least once
    gains = jnp.ones(nants, dtype=complex)

    return jax.lax.while_loop(conditional_function, inner_function, (gains, 0, jnp.inf))

def _build_model_matrices(data, model, weights, data_bls, data_to_model_bls_map, ant_flags={}):
"""
Function to build data, model, and weights matrices for sky_calibration optimization.


Parameters:
----------
data: DataContainer
Visibility data of measurements. Keys are antenna pair + pol tuples (must match model), values are
complex ndarray visibilities matching shape of model
model: DataContainer or RedDataContainer
Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray
weights: DataContainer
Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data).
data_bls: list
List of baseline tuples of the form (ant1, ant2, pol) to include in the data, model, and weights matrices
data_to_model_bls_map: dict
Dictionary mapping baseline tuples in the data to baseline tuples in the model. Keys are baseline tuples in the
data, values are baseline tuples in the model.
ant_flags: dict, optional, default={}
Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged
in this dictionary, it will not be included in the model matrices.

Returns:
-------
data_matrix: np.ndarray
Data matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that
are not flagged. Tsecond two axes are the time and frequency axes.
model_matrix: np.ndarray
Model matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that
are not flagged. The second two axes are the time and frequency axes.
wgts_matrix: np.ndarray
Weights matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that
are not flagged. The second two axes are the time and frequency axes.
map_ants_to_index: dict
Dictionary mapping antennas to indices within the data, model, and wgts matrices
"""
# Get antennas and polarizations from data
pols = np.unique([k[2] for k in data_bls])
ants = np.unique(np.concatenate([k[:2] for k in data_bls]))

# Remove flagged antennas
ants = [ant for ant in ants if not ant_flags.get((ant, 'J' + pols[0]), False)]
nants = len(ants)

# Map antennas to indices in the data, model, and wgts matrices
map_ants_to_index = {ant: ki for ki, ant in enumerate(ants)}

# Number of times and frequencies
ntimes, nfreqs = data[data_bls[0]].shape

# Populate matrices
data_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex')
wgts_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='float')
model_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex')

for bl in data_bls:
if bl[0] in map_ants_to_index and bl[1] in map_ants_to_index:
m, n = map_ants_to_index[bl[0]], map_ants_to_index[bl[1]]

# Data matrix
data_matrix[m, n] = data[bl]
data_matrix[n, m] = data[bl].conj()

# Weights matrix
wgts_matrix[m, n] = weights[bl]
wgts_matrix[n, m] = wgts_matrix[m, n]

# Model Matrix
model_matrix[m, n] = model[data_to_model_bls_map[bl]]
model_matrix[n, m] = model[data_to_model_bls_map[bl]].conj()

return data_matrix, model_matrix, wgts_matrix, map_ants_to_index

def sky_calibration(
    data, model, weights, data_antpos=None, ant_flags=None, tol=1e-10, maxiter=1000, stepsize=0.5,
    min_bl_cut=None, max_bl_cut=None, model_antpos=None, model_is_redundant=False, include_autos=False
):
    """
    Solve for per-antenna gains using the Stefcal algorithm (Salvini et al. 2014).

    Parameters:
    ----------
    data: DataContainer
        Visibility data of measurements. Keys are antenna pair + pol tuples (must match model), values are
        complex ndarray visibilities matching shape of model
    model: DataContainer or RedDataContainer
        Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray
    weights: DataContainer
        Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data).
    data_antpos: dict, default=None
        Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided,
        the antenna positions will be taken from the data dictionary if it has the attribute data_antpos.
    ant_flags: dict, optional, default=None
        Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged
        in this dictionary, it will not be included in the fit and its gains are set to 1.
        None is treated as no flags.
    tol: float, optional, default=1e-10
        Tolerance for the convergence criterion between subsequent iterations of stefcal
    maxiter: int, optional, default=1000
        Maximum number of iterations to run the calibration
    stepsize: float, optional, default=0.5
        Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction
        of the new gains, while a step size of 0 will take no step in the direction of the new gains.
    min_bl_cut: float, optional, default=None
        Minimum baseline length to include in the fit. If None, no minimum baseline length cut will be applied.
    max_bl_cut: float, optional, default=None
        Maximum baseline length to include in the fit. If None, no maximum baseline length cut will be applied.
    model_antpos: dict, optional, default=None
        Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided,
        it is assumed that the model antpos matches the data antpos
    model_is_redundant: bool, optional, default=False
        If True, it is assumed that the model is redundant. Automatically set to True when model is a
        RedDataContainer.
    include_autos: bool, optional, default=False
        If True, include auto-correlations in the model. If False, auto-correlations will be ignored.

    Returns:
    -------
    gains: dict
        Dictionary with all antenna-polarization tuples as keys and complex gain waterfall arrays of
        shape (ntimes, nfreqs) as values. Flagged antennas get unity gains.
    niters: dict
        Dictionary mapping each polarization to an (ntimes, nfreqs) integer array of the number of
        iterations performed by the optimizer. Fully-flagged pixels report 0.
    conv_crits: dict
        Dictionary mapping each polarization to an (ntimes, nfreqs) array of the final convergence
        criterion of the optimizer. Fully-flagged pixels report NaN.
    """
    # Avoid the shared-mutable-default-argument pitfall
    if ant_flags is None:
        ant_flags = {}

    # Get antennas and polarizations from data
    keys = list(data.keys())
    pols = np.unique([k[2] for k in keys])
    ants = np.unique(np.concatenate([k[:2] for k in keys]))

    # Get number of times and frequencies
    ntimes, nfreqs = data[keys[0]].shape

    # Check if the model is a RedDataContainer. If so, we can assume that the model is redundant
    if isinstance(model, RedDataContainer):
        model_is_redundant = True

    # User must provide data_antpos if not in data for baseline matching
    assert data_antpos or hasattr(data, "data_antpos"), "data_antpos must be provided if not in data"
    if data_antpos is None:
        data_antpos = data.data_antpos

    # Match baselines between data and model, applying length cuts and the redundancy mapping
    data_bls, _, data_to_model_bl_map = match_baselines(
        list(data.keys()), list(model.keys()), data_antpos, model_is_redundant=model_is_redundant,
        min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, include_autos=include_autos, model_antpos=model_antpos
    )

    # Store gains and metadata
    gains = {}
    niters = {}
    conv_crits = {}

    for pol in pols:
        # Initialize arrays for gains, niters, and convergence criteria
        gain_array = []
        niter_array = []
        conv_crit_array = []

        # Get data baselines for this polarization
        _data_bls = [k for k in data_bls if k[2] == pol]

        # Pack data and model into numpy arrays
        data_matrix, model_matrix, wgts, map_ants_to_index = _build_model_matrices(
            data, model, weights, _data_bls, data_to_model_bl_map, ant_flags=ant_flags
        )

        # Solve each (time, frequency) pixel independently
        for ti in range(ntimes):
            _gains = []
            _niters = []
            _conv_crits = []
            for fi in range(nfreqs):
                # If the weights are non-zero, run the optimizer. Otherwise, set the gains to 1.
                if wgts[..., ti, fi].sum() > 0:
                    gain, niter, conv_crit = _stefcal_optimizer(
                        data_matrix[..., ti, fi], model_matrix[..., ti, fi], wgts[..., ti, fi],
                        tol=tol, maxiter=maxiter, stepsize=stepsize
                    )
                    _niters.append(niter)
                    _conv_crits.append(conv_crit)
                    _gains.append(gain)
                else:
                    # Fully-flagged pixel: unity gains, zero iterations, undefined convergence
                    _gains.append(np.ones(data_matrix.shape[0], dtype='complex'))
                    _niters.append(0)
                    _conv_crits.append(np.nan)

            gain_array.append(_gains)
            niter_array.append(_niters)
            conv_crit_array.append(_conv_crits)

        # Shape (ntimes, nfreqs, nants)
        gain_array = np.array(gain_array)

        for k in ants:
            if k in map_ants_to_index:
                gains[(k, "J" + pol)] = gain_array[..., map_ants_to_index[k]]
            else:
                # Flagged antennas are not in the fit; report unity gains for them
                gains[(k, "J" + pol)] = np.ones((ntimes, nfreqs), dtype='complex')

        niters[pol] = np.array(niter_array)
        conv_crits[pol] = np.array(conv_crit_array)

    return gains, niters, conv_crits


def delay_slope_lincal(model, data, antpos, wgts=None, refant=None, df=9.765625e4, f0=0.0, medfilt=True,
kernel=(1, 5), assume_2D=True, four_pol=False, edge_cut=0, time_avg=False,
Expand Down
57 changes: 56 additions & 1 deletion hera_cal/tests/test_abscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from pyuvdata import UVCal, UVData
import warnings
from hera_sim.antpos import hex_array, linear_array
from hera_sim import vis

from .. import io, abscal, redcal, utils, apply_cal
from ..data import DATA_PATH
from ..datacontainer import DataContainer
from ..datacontainer import DataContainer, RedDataContainer
from ..utils import split_pol, reverse_bl, split_bl
from ..apply_cal import calibrate_in_place
from ..flag_utils import synthesize_ant_flags
Expand Down Expand Up @@ -751,6 +752,60 @@ def test_complex_phase_abscal(self):
with pytest.raises(AssertionError):
meta, delta_gains = abscal.complex_phase_abscal(data, model, reds, data_bls, model_bls)

def test_sky_calibration(self):
    """Test abscal.sky_calibration on redundantly simulated data: solved gains should
    calibrate the data onto the model, and antenna/pixel flags should be handled gracefully."""
    nfreqs = 10
    antpos = hex_array(3, split_core=True, outriggers=0)
    reds = redcal.get_reds(antpos)
    gains, model_vis, data_vis = vis.sim_red_data(reds, shape=(1, nfreqs))
    model_vis = RedDataContainer(model_vis, reds=reds)
    data_vis = DataContainer(data_vis)
    data_vis.data_antpos = antpos
    weights = DataContainer({k: np.ones(data_vis[k].shape, dtype=bool) for k in data_vis})

    # Test that the function runs and returns per-pol metadata of the right shape
    gains, niter, conv_crit = abscal.sky_calibration(
        data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5
    )
    assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs)

    # Apply calibration with solved gains
    data_vis_copy = copy.deepcopy(data_vis)
    apply_cal.calibrate_in_place(data_vis_copy, gains)

    # Test that the data is calibrated properly
    for k in data_vis:
        np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k])

    # Test the function with antenna flags
    ant_flags = {(0, 'Jnn'): True}
    gains, niter, conv_crit = abscal.sky_calibration(
        data_vis, model_vis, weights, antpos, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5
    )

    # Check that the flagged antenna has unity gain
    np.testing.assert_allclose(gains[(0, 'Jnn')], np.ones_like(gains[(0, 'Jnn')]))

    # Test the function with a plain (non-redundant) DataContainer model
    model_vis_copy = DataContainer(copy.deepcopy(model_vis))
    for k in data_vis:
        model_vis_copy[k] = model_vis[k]
    # Run calibration
    gains, niter, conv_crit = abscal.sky_calibration(
        data_vis, model_vis_copy, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5
    )
    # BUGFIX(review): the previous assertion re-checked data_vis_copy, which had already been
    # calibrated with the first run's gains and so passed regardless of this run's output.
    # Instead, apply the newly solved gains to a fresh copy and compare to the model.
    calibrated = copy.deepcopy(data_vis)
    apply_cal.calibrate_in_place(calibrated, gains)
    for k in data_vis:
        np.testing.assert_array_almost_equal(calibrated[k], model_vis[k])

    # Test the function after flagging one (time, frequency) pixel in the weights
    for k in weights:
        weights[k][0, 0] = 0.0

    gains, niter, conv_crit = abscal.sky_calibration(
        data_vis, model_vis, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5
    )
    # Fully-flagged pixels get unity gains
    assert np.isclose(gains[(0, 'Jnn')][0, 0], 1.0 + 0.0j)

@pytest.mark.filterwarnings("ignore:The default for the `center` keyword has changed")
@pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide")
Expand Down
Loading