diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index 6700314d4..6c49aef37 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -47,7 +47,7 @@ from . import redcal from . import io from . import apply_cal -from .datacontainer import DataContainer +from .datacontainer import DataContainer, RedDataContainer from .utils import echo, polnum2str, polstr2num, reverse_bl, split_pol, split_bl, join_bl, join_pol PHASE_SLOPE_SOLVERS = ['linfit', 'dft', 'ndim_fft'] # list of valid solvers for global_phase_slope_logcal @@ -727,6 +727,282 @@ def delay_lincal(model, data, wgts=None, refant=None, df=9.765625e4, f0=0., solv return fit +@jax.jit +def _stefcal_optimizer(data_matrix, model_matrix, weights, tol=1e-10, maxiter=1000, stepsize=0.5): + """ + Function to run stefcal optimization + + Parameters: + ---------- + data_matrix: np.ndarray + Data matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + model_matrix: np.ndarray + Model matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + weights: np.ndarray + Weights matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + tol: float, optional, default=1e-10 + Tolerance for convergence criterea of subsequent iterations of stefcal + maxiter: int, optional, default=1000 + Maximum number of iterations to run the calibration + stepsize: float, optional, default=0.5 + Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction + of the new gains, while a step size of 0 will take no step in the direction of the new gains. + + Returns: + ------- + gains: np.ndarray + Complex antenna gain solutions of shape (nants) + niters: int + Number of iterations performed by the optimizer + conv_crit: float + Convergence criterea of the final iteration of the optimizer + """ + def inner_function(args): + """ + Main optimization loop + """ + # Unpack arguments + gains, i, _ = args + + # Copy gains + g_old = jnp.copy(gains) + + # Compute the model-gain product + zg = gains[:, None] * model_matrix + zgw = zg * weights + + # Compute gains + gains = jnp.sum(jnp.conj(data_matrix) * zgw, axis=(0)) / jnp.sum(jnp.conj(zgw) * zg, axis=(0)) + + # Set gains to 1 if they are nan + gains = jnp.where(jnp.isnan(gains), 1, gains) + + # Average gains on even iterations and take a full step on odd iterations (ensures convergence) + step = ((i + 1) % 2) * stepsize + gains = gains * (1 - step) + g_old * step + + # Compute convergence criterea + tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2)) / jnp.sqrt(jnp.sum(jnp.abs(gains)**2)) + return gains, i + 1, tau + + def conditional_function(args): + """ + Conditional function to check to convergence criterea + """ + _, i, tau = args + return (tau > tol) & (i < maxiter) + + nants = data_matrix.shape[0] + gains = jnp.ones(nants, dtype=complex) + + return jax.lax.while_loop(conditional_function, inner_function, (gains, 0, jnp.inf)) + +def _build_model_matrices(data, model, weights, data_bls, data_to_model_bls_map, ant_flags={}): + """ + Function to build data, model, and weights matrices for sky_calibration optimization. + + + Parameters: + ---------- + data: DataContainer + Visibility data of measurements. Keys are antenna pair + pol tuples (must match model), values are + complex ndarray visibilities matching shape of model + model: DataContainer or RedDataContainer + Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray + weights: DataContainer + Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). + data_bls: list + List of baseline tuples of the form (ant1, ant2, pol) to include in the data, model, and weights matrices + data_to_model_bls_map: dict + Dictionary mapping baseline tuples in the data to baseline tuples in the model. Keys are baseline tuples in the + data, values are baseline tuples in the model. + ant_flags: dict, optional, default={} + Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged + in this dictionary, it will not be included in the model matrices. + + Returns: + ------- + data_matrix: np.ndarray + Data matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + model_matrix: np.ndarray + Model matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. The second two axes are the time and frequency axes. + wgts_matrix: np.ndarray + Weights matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. The second two axes are the time and frequency axes. + map_ants_to_index: dict + Dictionary mapping antennas to indices within the data, model, and wgts matrices + """ + # Get antennas and polarizations from data + pols = np.unique([k[2] for k in data_bls]) + ants = np.unique(np.concatenate([k[:2] for k in data_bls])) + + # Remove flagged antennas + ants = [ant for ant in ants if not ant_flags.get((ant, 'J' + pols[0]), False)] + nants = len(ants) + + # Map antennas to indices in the data, model, and wgts matrices + map_ants_to_index = {ant: ki for ki, ant in enumerate(ants)} + + # Number of times and frequencies + ntimes, nfreqs = data[data_bls[0]].shape + + # Populate matrices + data_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') + wgts_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='float') + model_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') + + for bl in data_bls: + if bl[0] in map_ants_to_index and bl[1] in map_ants_to_index: + m, n = map_ants_to_index[bl[0]], map_ants_to_index[bl[1]] + + # Data matrix + data_matrix[m, n] = data[bl] + data_matrix[n, m] = data[bl].conj() + + # Weights matrix + wgts_matrix[m, n] = weights[bl] + wgts_matrix[n, m] = wgts_matrix[m, n] + + # Model Matrix + model_matrix[m, n] = model[data_to_model_bls_map[bl]] + model_matrix[n, m] = model[data_to_model_bls_map[bl]].conj() + + return data_matrix, model_matrix, wgts_matrix, map_ants_to_index + +def sky_calibration( + data, model, weights, data_antpos=None, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5, + min_bl_cut=None, max_bl_cut=None, model_antpos=None, model_is_redundant=False, include_autos=False + ): + """ + Solve for per-antenna gains using the Stefcal algorithm (Salvini et al. 2014). + + Parameters: + ---------- + data: DataContainer + Visibility data of measurements. Keys are antenna pair + pol tuples (must match model), values are + complex ndarray visibilities matching shape of model + model: DataContainer or RedDataContainer + Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray + weights: DataContainer + Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). + data_antpos: dict, default=None + Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided, + the antenna positions will be taken from the data dictionary if it has the attribute data_antpos. + ant_flags: dict, optional, default={} + Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged + in this dictionary, it will not be included in the fit. + tol: float, optional, default=1e-10 + Tolerance for convergence criterea of subsequent iterations of stefcal + maxiter: int, optional, default=1000 + Maximum number of iterations to run the calibration + stepsize: float, optional, default=0.5 + Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction + of the new gains, while a step size of 0 will take no step in the direction of the new gains. + min_bl_cut: float, optional, default=None + Minimum baseline length to include in the fit. If None, no minimum baseline length cut will be applied. + max_bl_cut: float, optional, default=None + Maximum baseline length to include in the fit. If None, no minimum baseline length cut will be applied. + model_antpos: dict, optional, default=None + Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided, + it is assumed that the model antpos matches the data antpos + model_is_redundant: bool, optional, default=False + If True, it is assumed that the model is redundant. + include_autos: bool, optional, default=False + If True, include auto-correlations in the model. If False, auto-correlations will be ignored. + + Returns: + ------- + gains: dict + Dictionary with all unflagged antenna-polarizations as keys and gain waterfall arrays as values + niters: np.ndarray + Number of iterations performed for each time, frequency, and polarization run of the calibration algorithm + conv_crits: np.ndarray + Convergence criterea for each time, frequency, and polarization run of the calibration. + """ + # Get antennas and polarizations from data + keys = list(data.keys()) + pols = np.unique([k[2] for k in keys]) + ants = np.unique(np.concatenate([k[:2] for k in keys])) + + # Get number of times and frequencies + ntimes, nfreqs = data[keys[0]].shape + + # Check if the model is a RedDataContainer. If so, we can assume that the model is redundant + if isinstance(model, RedDataContainer): + model_is_redundant = True + + # User must provide data_antpos if not in data for baseline matching + assert data_antpos or hasattr(data, "data_antpos"), "data_antpos must be provided if not in data" + if data_antpos is None: + data_antpos = data.data_antpos + + # get keys from model and data dictionary + data_bls, _, data_to_model_bl_map = match_baselines( + list(data.keys()), list(model.keys()), data_antpos, model_is_redundant=model_is_redundant, + min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, include_autos=include_autos, model_antpos=model_antpos + ) + + # Store gains and metadata + gains = {} + niters = {} + conv_crits = {} + + for pol in pols: + # Initialize arrays for gains, niters, and convergence criterea + gain_array = [] + niter_array = [] + conv_crit_array = [] + + # Get data baselines + _data_bls = [k for k in data_bls if k[2] == pol] + + # Pack data and model into numpy arrays + data_matrix, model_matrix, wgts, map_ants_to_index = _build_model_matrices( + data, model, weights, _data_bls, data_to_model_bl_map, ant_flags=ant_flags + ) + + for ti in range(ntimes): + _gains = [] + _niters = [] + _conv_crits= [] + for fi in range(nfreqs): + # If the weights are non-zero, run the optimizer. Otherwise, set the gains to 1. + if wgts[..., ti, fi].sum() > 0: + gain, niter, conv_crit = _stefcal_optimizer( + data_matrix[..., ti, fi], model_matrix[..., ti, fi], wgts[..., ti, fi], + tol=tol, maxiter=maxiter, stepsize=stepsize + ) + _niters.append(niter) + _conv_crits.append(conv_crit) + _gains.append(gain) + else: + gain = np.ones(data_matrix.shape[0], dtype='complex') + _niters.append(0) + _conv_crits.append(np.nan) + _gains.append(gain) + + gain_array.append(_gains) + niter_array.append(_niters) + conv_crit_array.append(_conv_crits) + + gain_array = np.array(gain_array) + + for k in ants: + if k in map_ants_to_index: + gains[(k, "J" + pol)] = gain_array[..., map_ants_to_index[k]] + else: + gains[(k, "J" + pol)] = np.ones((ntimes, nfreqs), dtype='complex') + + niters[pol] = np.array(niter_array) + conv_crits[pol] = np.array(conv_crit_array) + + return gains, niters, conv_crits + def delay_slope_lincal(model, data, antpos, wgts=None, refant=None, df=9.765625e4, f0=0.0, medfilt=True, kernel=(1, 5), assume_2D=True, four_pol=False, edge_cut=0, time_avg=False, diff --git a/hera_cal/tests/test_abscal.py b/hera_cal/tests/test_abscal.py index b0a582808..63895e637 100644 --- a/hera_cal/tests/test_abscal.py +++ b/hera_cal/tests/test_abscal.py @@ -13,10 +13,11 @@ from pyuvdata import UVCal, UVData import warnings from hera_sim.antpos import hex_array, linear_array +from hera_sim import vis from .. import io, abscal, redcal, utils, apply_cal from ..data import DATA_PATH -from ..datacontainer import DataContainer +from ..datacontainer import DataContainer, RedDataContainer from ..utils import split_pol, reverse_bl, split_bl from ..apply_cal import calibrate_in_place from ..flag_utils import synthesize_ant_flags @@ -751,6 +752,60 @@ def test_complex_phase_abscal(self): with pytest.raises(AssertionError): meta, delta_gains = abscal.complex_phase_abscal(data, model, reds, data_bls, model_bls) + def test_sky_calibration(self): + """ + """ + nfreqs = 10 + antpos = hex_array(3, split_core=True, outriggers=0) + reds = redcal.get_reds(antpos) + gains, model_vis, data_vis = vis.sim_red_data(reds, shape=(1, nfreqs)) + model_vis = RedDataContainer(model_vis, reds=reds) + data_vis = DataContainer(data_vis) + data_vis.data_antpos = antpos + weights = DataContainer({k: np.ones(data_vis[k].shape, dtype=bool) for k in data_vis}) + + # Test that the function runs + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs) + + # Apply calibration with solved gains + data_vis_copy = copy.deepcopy(data_vis) + apply_cal.calibrate_in_place(data_vis_copy, gains) + + # Test that the data is calibrated properly + for k in data_vis: + np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) + + # Test the function with antenna flags + ant_flags = {(0, 'Jnn'): True} + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, weights, antpos, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + + # Check that the flagged antenna has unity gain + np.testing.assert_allclose(gains[(0, 'Jnn')], np.ones_like(gains[(0, 'Jnn')])) + + # Test the function with DataContainer + model_vis_copy = DataContainer(copy.deepcopy(model_vis)) + for k in data_vis: + model_vis_copy[k] = model_vis[k] + # Run calibration + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis_copy, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + for k in data_vis: + np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) + + # Test the function flagging a frequency and time + for k in weights: + weights[k][0, 0] = 0.0 + + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + assert np.isclose(gains[(0, 'Jnn')][0, 0], 1.0 + 0.0j) @pytest.mark.filterwarnings("ignore:The default for the `center` keyword has changed") @pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide")