From 446bd35681fe4e3ebc4b75e6d157a6dded642847 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Sat, 29 Apr 2023 18:48:12 -0700 Subject: [PATCH 1/7] Add sky-based calibration and tests to abscal.py --- hera_cal/abscal.py | 250 +++++++++++++++++++++++++++++++++- hera_cal/tests/test_abscal.py | 56 +++++++- 2 files changed, 303 insertions(+), 3 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index 82ae131b7..525f087f3 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -47,7 +47,7 @@ from . import redcal from . import io from . import apply_cal -from .datacontainer import DataContainer +from .datacontainer import DataContainer, RedDataContainer from .utils import echo, polnum2str, polstr2num, reverse_bl, split_pol, split_bl, join_bl, join_pol PHASE_SLOPE_SOLVERS = ['linfit', 'dft', 'ndim_fft'] # list of valid solvers for global_phase_slope_logcal @@ -727,6 +727,253 @@ def delay_lincal(model, data, wgts=None, refant=None, df=9.765625e4, f0=0., solv return fit +@jax.jit +def _stefcal_optimizer(data_matrix, model_matrix, weights, tol=1e-10, maxiter=1000, stepsize=0.5): + """ + Function to run stefcal optimization + + Parameters: + ---------- + data_matrix: np.ndarray + Data matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + model_matrix: np.ndarray + Model matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. + weights: np.ndarray + Weights matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. Tsecond two axes are the time and frequency axes. 
+ tol: float, optional, default=1e-10 + Tolerance for convergence criterea of subsequent iterations of stefcal + maxiter: int, optional, default=1000 + Maximum number of iterations to run the calibration + stepsize: float, optional, default=0.5 + Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction + of the new gains, while a step size of 0 will take no step in the direction of the new gains. + + Returns: + ------- + gains: np.ndarray + Complex antenna gain solutions of shape (nants) + niters: int + Number of iterations performed by the optimizer + conv_crit: float + Convergence criterea of the final iteration of the optimizer + """ + def inner_function(args): + """ + Main optimization loop + """ + # Unpack arguments + gains, i, tau = args + + # Copy gains + g_old = jnp.copy(gains) + + # Compute the model gain product + zg = gains[:, None] * model_matrix + zgw = zg * weights + + # Compute gains + gains = jnp.sum(jnp.conj(data_matrix) * zgw, axis=(0)) / jnp.sum(jnp.conj(zgw) * zg, axis=(0)) + + # Set gains to 1 if they are nan + gains = jnp.where(jnp.isnan(gains), 1, gains) + gains = gains * stepsize + g_old * (1 - stepsize) + + # Compute convergence criterea + tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2))/ jnp.sqrt(jnp.sum(jnp.abs(gains)**2)) + return gains, i + 1, tau + + def conditional_function(args): + """ + Conditional function to check to convergence criterea + """ + _, i, tau = args + return (tau > tol) & (i < maxiter) + + nants = data_matrix.shape[0] + gains = jnp.ones(nants, dtype=complex) + + return jax.lax.while_loop(conditional_function, inner_function, (gains, 0, jnp.inf)) + +def _build_model_matrices(data, model, flags, baselines, ant_flags={}): + """ + Function to build data, model, and weights matrices for sky_calibration optimization. + + + Parameters: + ---------- + data: DataContainer + Visibility data of measurements. 
Keys are antenna pair + pol tuples (must match model), values are + complex ndarray visibilities matching shape of model + model: DataContainer or RedDataContainer + Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray + flags: DataContainer + Dictionary of flags. Keys are antenna pair + pol tuples (must match data), values are boolean ndarrays + baselines: list + List of baseline tuples of the form (ant1, ant2, pol) to include in the model matrices + ant_flags: dict, optional, default={} + Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged + in this dictionary, it will not be included in the model matrices. + + Returns: + ------- + data_matrix: np.ndarray + Data matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. The second two axes are the time and frequency axes. + model_matrix: np.ndarray + Model matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. The second two axes are the time and frequency axes. + wgts_matrix: np.ndarray + Weights matrix of shape (nants, nants, ntimes, nfreqs). The first two axes are the antenna indices of antennas that + are not flagged. The second two axes are the time and frequency axes. 
+ map_ants_to_index: dict + Dictionary mapping antennas to indices within the data, model, and wgts matrices + """ + # Get unique antennas + ants = sorted(list(set(sum([list(k[:2]) for k in data], [])))) + pols = sorted(list(set([k[2] for k in baselines]))) + + # Remove flagged antennas + ants = [ant for ant in ants if not ant_flags.get((ant, 'J' + pols[0]), False)] + nants = len(ants) + + # Map antennas to indices in the data, model, and wgts matrices + map_ants_to_index = { + ant: ki for ki, ant in enumerate(ants) + } + + # Number of times and frequencies + ntimes, nfreqs = data[baselines[0]].shape + + # Populate matrices + data_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') + wgts_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='float') + model_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') + + for bl in baselines: + if bl[0] in map_ants_to_index and bl[1] in map_ants_to_index: + m, n = map_ants_to_index[bl[0]], map_ants_to_index[bl[1]] + + # Data matrix + data_matrix[m, n] = data[bl] + data_matrix[n, m] = data[bl].conj() + + # Weights matrix + wgts_matrix[m, n] = (~flags[bl]).astype(float) + wgts_matrix[n, m] = wgts_matrix[m, n] + + # Model Matrix + model_matrix[m, n] = model[bl] + model_matrix[n, m] = model[bl].conj() + + return data_matrix, model_matrix, wgts_matrix, map_ants_to_index + +def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5): + """ + Solve for per-antenna gains using the Stefcal algorithm (Salvini et al. 2014). + + Parameters: + ---------- + data: DataContainer + Visibility data of measurements. Keys are antenna pair + pol tuples (must match model), values are + complex ndarray visibilities matching shape of model + model: DataContainer or RedDataContainer + Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray + flags: DataContainer + Dictionary of flags. 
Keys are antenna pair + pol tuples (must match data), values are boolean ndarrays + ant_flags: dict, optional, default={} + Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged + in this dictionary, it will not be included in the fit. + tol: float, optional, default=1e-10 + Tolerance for convergence criterea of subsequent iterations of stefcal + maxiter: int, optional, default=1000 + Maximum number of iterations to run the calibration + stepsize: float, optional, default=0.5 + Step size for the optimization. Must be between 0 and 1. A step size of 1 will take the full step in the direction + of the new gains, while a step size of 0 will take no step in the direction of the new gains. + + Returns: + ------- + gains: dict + Dictionary with all unflagged antenna-polarizations as keys and gain waterfall arrays as values + niters: np.ndarray + Number of iterations performed for each time, frequency, and polarization run of the calibration algorithm + conv_crits: np.ndarray + Convergence criterea for each time, frequency, and polarization run of the calibration. 
+ """ + # Get number of times and frequencies + ntimes, nfreqs = data[list(data.keys())[0]].shape + + # Get unique polarizations in the data + pols = sorted(list(set([k[2] for k in data.keys()]))) + + # Get antennas + ants = sorted(list(set(sum([list(k[:2]) for k in data.keys()], [])))) + + # get keys from model and data dictionary + if isinstance(model, RedDataContainer): + all_bls = sorted(set(data.keys())) + else: + all_bls = sorted(set(data.keys()) & set(model.keys())) + + # Store gains and metadata + gains = {} + niters = {} + conv_crits = {} + + for pol in pols: + # Initialize arrays for gains, niters, and convergence criterea + gain_array = [] + niter_array = [] + conv_crit_array = [] + + # Get data baselines + baselines = [k for k in all_bls if k[2] == pol] + + # Pack data and model into numpy arrays + data_matrix, model_matrix, wgts, map_ants_to_index = _build_model_matrices( + data, model, flags, baselines, ant_flags=ant_flags + ) + + for ti in range(ntimes): + _gains = [] + _niters = [] + _conv_crits= [] + for fi in range(nfreqs): + if wgts[..., ti, fi].sum() > 0: + gain, niter, conv_crit = _stefcal_optimizer( + data_matrix[..., ti, fi], model_matrix[..., ti, fi], wgts[..., ti, fi], + tol=tol, maxiter=maxiter, stepsize=stepsize + ) + _niters.append(niter) + _conv_crits.append(conv_crit) + _gains.append(gain) + + else: + gain = np.ones(data_matrix.shape[0], dtype='complex') + _niters.append(0) + _conv_crits.append(np.nan) + _gains.append(gain) + + gain_array.append(_gains) + niter_array.append(_niters) + conv_crit_array.append(_conv_crits) + + gain_array = np.array(gain_array) + + for k in ants: + if k in map_ants_to_index: + gains[(k, "J" + pol)] = gain_array[..., map_ants_to_index[k]] + else: + gains[(k, "J" + pol)] = np.ones((ntimes, nfreqs), dtype='complex') + + niters[pol] = np.array(niter_array) + conv_crits[pol] = np.array(conv_crit_array) + + return gains, niters, conv_crits + def delay_slope_lincal(model, data, antpos, wgts=None, refant=None, 
df=9.765625e4, f0=0.0, medfilt=True, kernel=(1, 5), assume_2D=True, four_pol=False, edge_cut=0, time_avg=False, @@ -2450,7 +2697,6 @@ def match_times(datafile, modelfiles, filetype='uvh5', atol=1e-5): return match - def cut_bls(datacontainer, bls=None, min_bl_cut=None, max_bl_cut=None, inplace=False): """ Cut visibility data based on min and max baseline length. diff --git a/hera_cal/tests/test_abscal.py b/hera_cal/tests/test_abscal.py index b0a582808..c080b2d50 100644 --- a/hera_cal/tests/test_abscal.py +++ b/hera_cal/tests/test_abscal.py @@ -13,10 +13,11 @@ from pyuvdata import UVCal, UVData import warnings from hera_sim.antpos import hex_array, linear_array +from hera_sim import vis from .. import io, abscal, redcal, utils, apply_cal from ..data import DATA_PATH -from ..datacontainer import DataContainer +from ..datacontainer import DataContainer, RedDataContainer from ..utils import split_pol, reverse_bl, split_bl from ..apply_cal import calibrate_in_place from ..flag_utils import synthesize_ant_flags @@ -751,6 +752,59 @@ def test_complex_phase_abscal(self): with pytest.raises(AssertionError): meta, delta_gains = abscal.complex_phase_abscal(data, model, reds, data_bls, model_bls) + def test_sky_calibration(self): + """ + """ + nfreqs = 10 + antpos = hex_array(3, split_core=True, outriggers=0) + reds = redcal.get_reds(antpos) + gains, model_vis, data_vis = vis.sim_red_data(reds, shape=(1, nfreqs)) + model_vis = RedDataContainer(model_vis, reds=reds) + data_vis = DataContainer(data_vis) + flags = DataContainer({k: np.zeros(data_vis[k].shape, dtype=bool) for k in data_vis}) + + # Test that the function runs + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs) + + # Apply calibration with solved gains + data_vis_copy = copy.deepcopy(data_vis) + apply_cal.calibrate_in_place(data_vis_copy, gains) + + # Test that the data is 
calibrated properly + for k in data_vis: + np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) + + # Test the function with antenna flags + ant_flags = {(0, 'Jnn'): True} + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, flags, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + + # Check that the flagged antenna has unity gain + np.testing.assert_allclose(gains[(0, 'Jnn')], np.ones_like(gains[(0, 'Jnn')])) + + # Test the function with DataContainer + model_vis_copy = DataContainer(copy.deepcopy(model_vis)) + for k in data_vis: + model_vis_copy[k] = model_vis[k] + # Run calibration + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis_copy, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + for k in data_vis: + np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) + + # Test the function flagging a frequency and time + for k in flags: + flags[k][0, 0] = True + + gains, niter, conv_crit = abscal.sky_calibration( + data_vis, model_vis, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + ) + assert np.isclose(gains[(0, 'Jnn')][0, 0], 1.0 + 0.0j) @pytest.mark.filterwarnings("ignore:The default for the `center` keyword has changed") @pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide") From 9a9b86efaaba3c833d9f9331a04d85704871798b Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Sun, 30 Apr 2023 14:49:39 -0700 Subject: [PATCH 2/7] Make getting antennas faster --- hera_cal/abscal.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index 525f087f3..be72644f2 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -782,7 +782,7 @@ def inner_function(args): gains = gains * stepsize + g_old * (1 - stepsize) # Compute convergence criterea - tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2))/ jnp.sqrt(jnp.sum(jnp.abs(gains)**2)) + tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2)) / 
jnp.sqrt(jnp.sum(jnp.abs(gains)**2)) return gains, i + 1, tau def conditional_function(args): @@ -831,9 +831,10 @@ def _build_model_matrices(data, model, flags, baselines, ant_flags={}): map_ants_to_index: dict Dictionary mapping antennas to indices within the data, model, and wgts matrices """ - # Get unique antennas - ants = sorted(list(set(sum([list(k[:2]) for k in data], [])))) - pols = sorted(list(set([k[2] for k in baselines]))) + # Get antennas and polarizations from data + keys = list(data.keys()) + pols = np.unique([k[2] for k in keys]) + ants = np.unique(np.concatenate([k[:2] for k in keys])) # Remove flagged antennas ants = [ant for ant in ants if not ant_flags.get((ant, 'J' + pols[0]), False)] @@ -903,14 +904,13 @@ def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, s conv_crits: np.ndarray Convergence criterea for each time, frequency, and polarization run of the calibration. """ - # Get number of times and frequencies - ntimes, nfreqs = data[list(data.keys())[0]].shape - - # Get unique polarizations in the data - pols = sorted(list(set([k[2] for k in data.keys()]))) + # Get antennas and polarizations from data + keys = list(data.keys()) + pols = np.unique([k[2] for k in keys]) + ants = np.unique(np.concatenate([k[:2] for k in keys])) - # Get antennas - ants = sorted(list(set(sum([list(k[:2]) for k in data.keys()], [])))) + # Get number of times and frequencies + ntimes, nfreqs = data[keys[0]].shape # get keys from model and data dictionary if isinstance(model, RedDataContainer): From 71133a351dda1c6f9968311b1996eae9003eeafb Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Sun, 30 Apr 2023 17:22:05 -0700 Subject: [PATCH 3/7] Fix gain updates for proper convergence --- hera_cal/abscal.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index be72644f2..3314889aa 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -765,12 +765,12 @@ def 
inner_function(args): Main optimization loop """ # Unpack arguments - gains, i, tau = args + gains, i, _ = args # Copy gains g_old = jnp.copy(gains) - # Compute the model gain product + # Compute the model-gain product zg = gains[:, None] * model_matrix zgw = zg * weights @@ -779,7 +779,10 @@ def inner_function(args): # Set gains to 1 if they are nan gains = jnp.where(jnp.isnan(gains), 1, gains) - gains = gains * stepsize + g_old * (1 - stepsize) + + # Average gains on even iterations and take a full step on odd iterations (ensures convergence) + step = ((i + 1) % 2) * stepsize + gains = gains * (1 - step) + g_old * step # Compute convergence criterea tau = jnp.sqrt(jnp.sum(jnp.abs(gains - g_old) ** 2)) / jnp.sqrt(jnp.sum(jnp.abs(gains)**2)) From d279841132898dbeb4a7c375a954df43bd9d953a Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Mon, 30 Oct 2023 14:47:36 -0700 Subject: [PATCH 4/7] Change flags param in sky_cal to weights --- hera_cal/abscal.py | 18 +++++++++--------- hera_cal/tests/test_abscal.py | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index 3314889aa..8ad04b3a0 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -800,7 +800,7 @@ def conditional_function(args): return jax.lax.while_loop(conditional_function, inner_function, (gains, 0, jnp.inf)) -def _build_model_matrices(data, model, flags, baselines, ant_flags={}): +def _build_model_matrices(data, model, weights, baselines, ant_flags={}): """ Function to build data, model, and weights matrices for sky_calibration optimization. @@ -812,8 +812,8 @@ def _build_model_matrices(data, model, flags, baselines, ant_flags={}): complex ndarray visibilities matching shape of model model: DataContainer or RedDataContainer Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray - flags: DataContainer - Dictionary of flags. 
Keys are antenna pair + pol tuples (must match data), values are boolean ndarrays + weights: DataContainer + Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). baselines: list List of baseline tuples of the form (ant1, ant2, pol) to include in the model matrices ant_flags: dict, optional, default={} @@ -865,7 +865,7 @@ def _build_model_matrices(data, model, flags, baselines, ant_flags={}): data_matrix[n, m] = data[bl].conj() # Weights matrix - wgts_matrix[m, n] = (~flags[bl]).astype(float) + wgts_matrix[m, n] = weights[bl] wgts_matrix[n, m] = wgts_matrix[m, n] # Model Matrix @@ -874,7 +874,7 @@ def _build_model_matrices(data, model, flags, baselines, ant_flags={}): return data_matrix, model_matrix, wgts_matrix, map_ants_to_index -def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5): +def sky_calibration(data, model, weights, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5): """ Solve for per-antenna gains using the Stefcal algorithm (Salvini et al. 2014). @@ -885,8 +885,8 @@ def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, s complex ndarray visibilities matching shape of model model: DataContainer or RedDataContainer Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray - flags: DataContainer - Dictionary of flags. Keys are antenna pair + pol tuples (must match data), values are boolean ndarrays + weights: DataContainer + Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). ant_flags: dict, optional, default={} Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged in this dictionary, it will not be included in the fit. 
@@ -937,7 +937,7 @@ def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, s # Pack data and model into numpy arrays data_matrix, model_matrix, wgts, map_ants_to_index = _build_model_matrices( - data, model, flags, baselines, ant_flags=ant_flags + data, model, weights, baselines, ant_flags=ant_flags ) for ti in range(ntimes): @@ -945,6 +945,7 @@ def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, s _niters = [] _conv_crits= [] for fi in range(nfreqs): + # If the weights are non-zero, run the optimizer. Otherwise, set the gains to 1. if wgts[..., ti, fi].sum() > 0: gain, niter, conv_crit = _stefcal_optimizer( data_matrix[..., ti, fi], model_matrix[..., ti, fi], wgts[..., ti, fi], @@ -953,7 +954,6 @@ def sky_calibration(data, model, flags, ant_flags={}, tol=1e-10, maxiter=1000, s _niters.append(niter) _conv_crits.append(conv_crit) _gains.append(gain) - else: gain = np.ones(data_matrix.shape[0], dtype='complex') _niters.append(0) diff --git a/hera_cal/tests/test_abscal.py b/hera_cal/tests/test_abscal.py index c080b2d50..7ed347696 100644 --- a/hera_cal/tests/test_abscal.py +++ b/hera_cal/tests/test_abscal.py @@ -761,11 +761,11 @@ def test_sky_calibration(self): gains, model_vis, data_vis = vis.sim_red_data(reds, shape=(1, nfreqs)) model_vis = RedDataContainer(model_vis, reds=reds) data_vis = DataContainer(data_vis) - flags = DataContainer({k: np.zeros(data_vis[k].shape, dtype=bool) for k in data_vis}) + weights = DataContainer({k: np.ones(data_vis[k].shape, dtype=bool) for k in data_vis}) # Test that the function runs gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 ) assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs) @@ -780,7 +780,7 @@ def test_sky_calibration(self): # Test the function with antenna flags ant_flags = {(0, 'Jnn'): True} gains, niter, conv_crit = 
abscal.sky_calibration( - data_vis, model_vis, flags, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 ) # Check that the flagged antenna has unity gain @@ -792,17 +792,17 @@ def test_sky_calibration(self): model_vis_copy[k] = model_vis[k] # Run calibration gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis_copy, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis_copy, weights, maxiter=1000, tol=1e-10, stepsize=0.5 ) for k in data_vis: np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) # Test the function flagging a frequency and time - for k in flags: - flags[k][0, 0] = True + for k in weights: + weights[k][0, 0] = 0.0 gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis, flags, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 ) assert np.isclose(gains[(0, 'Jnn')][0, 0], 1.0 + 0.0j) From c92c01da4f0a2504be1018bf5b85c28ca69d449b Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Fri, 3 Nov 2023 14:02:45 -0700 Subject: [PATCH 5/7] Build in matching model and data bls keys --- hera_cal/abscal.py | 64 +++++++++++++++++++++++++---------- hera_cal/tests/test_abscal.py | 8 ++--- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index cfde890a1..76852e944 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -800,7 +800,7 @@ def conditional_function(args): return jax.lax.while_loop(conditional_function, inner_function, (gains, 0, jnp.inf)) -def _build_model_matrices(data, model, weights, baselines, ant_flags={}): +def _build_model_matrices(data, model, weights, data_bls, data_to_model_bls_map, ant_flags={}): """ Function to build data, model, and weights matrices for sky_calibration optimization. 
@@ -814,6 +814,11 @@ def _build_model_matrices(data, model, weights, baselines, ant_flags={}): Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray weights: DataContainer Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). + data_bls: list + List of baseline tuples of the form (ant1, ant2, pol) to include in the data, model, and weights matrices + data_to_model_bls_map: dict + Dictionary mapping baseline tuples in the data to baseline tuples in the model. Keys are baseline tuples in the + data, values are baseline tuples in the model. baselines: list List of baseline tuples of the form (ant1, ant2, pol) to include in the model matrices ant_flags: dict, optional, default={} @@ -835,28 +840,25 @@ def _build_model_matrices(data, model, weights, baselines, ant_flags={}): Dictionary mapping antennas to indices within the data, model, and wgts matrices """ # Get antennas and polarizations from data - keys = list(data.keys()) - pols = np.unique([k[2] for k in keys]) - ants = np.unique(np.concatenate([k[:2] for k in keys])) + pols = np.unique([k[2] for k in data_bls]) + ants = np.unique(np.concatenate([k[:2] for k in data_bls])) # Remove flagged antennas ants = [ant for ant in ants if not ant_flags.get((ant, 'J' + pols[0]), False)] nants = len(ants) # Map antennas to indices in the data, model, and wgts matrices - map_ants_to_index = { - ant: ki for ki, ant in enumerate(ants) - } + map_ants_to_index = {ant: ki for ki, ant in enumerate(ants)} # Number of times and frequencies - ntimes, nfreqs = data[baselines[0]].shape + ntimes, nfreqs = data[data_bls[0]].shape # Populate matrices data_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') wgts_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='float') model_matrix = np.zeros((nants, nants, ntimes, nfreqs), dtype='complex') - for bl in baselines: + for bl in data_bls: if bl[0] in map_ants_to_index and bl[1] in 
map_ants_to_index: m, n = map_ants_to_index[bl[0]], map_ants_to_index[bl[1]] @@ -869,12 +871,15 @@ def _build_model_matrices(data, model, weights, baselines, ant_flags={}): wgts_matrix[n, m] = wgts_matrix[m, n] # Model Matrix - model_matrix[m, n] = model[bl] - model_matrix[n, m] = model[bl].conj() + model_matrix[m, n] = model[data_to_model_bls_map[bl]] + model_matrix[n, m] = model[data_to_model_bls_map[bl]].conj() return data_matrix, model_matrix, wgts_matrix, map_ants_to_index -def sky_calibration(data, model, weights, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5): +def sky_calibration( + data, model, weights, data_antpos=None, ant_flags={}, tol=1e-10, maxiter=1000, stepsize=0.5, + min_bl_cut=None, max_bl_cut=None, model_antpos=None, model_is_redundant=False, include_autos=False + ): """ Solve for per-antenna gains using the Stefcal algorithm (Salvini et al. 2014). @@ -887,6 +892,9 @@ def sky_calibration(data, model, weights, ant_flags={}, tol=1e-10, maxiter=1000, Model visibilities. Keys are antenna pair + pol tuples (must match data), values are complex ndarray weights: DataContainer Dictionary of real-valued data weights. Keys are antenna pair + pol tuples (must match data). + data_antpos: dict, default=None + Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided, + the antenna positions will be taken from the data dictionary if it has the attribute data_antpos. ant_flags: dict, optional, default={} Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged in this dictionary, it will not be included in the fit. @@ -897,6 +905,17 @@ def sky_calibration(data, model, weights, ant_flags={}, tol=1e-10, maxiter=1000, stepsize: float, optional, default=0.5 Step size for the optimization. Must be between 0 and 1. 
A step size of 1 will take the full step in the direction of the new gains, while a step size of 0 will take no step in the direction of the new gains. + min_bl_cut: float, optional, default=None + Minimum baseline length to include in the fit. If None, no minimum baseline length cut will be applied. + max_bl_cut: float, optional, default=None + Maximum baseline length to include in the fit. If None, no maximum baseline length cut will be applied. + model_antpos: dict, optional, default=None + Dictionary of antenna positions. Keys are antenna numbers, values are antenna position vectors. If not provided, + it is assumed that the model antpos matches the data antpos + model_is_redundant: bool, optional, default=False + If True, it is assumed that the model is redundant. + include_autos: bool, optional, default=False + If True, include auto-correlations in the model. If False, auto-correlations will be ignored. Returns: ------- @@ -915,11 +934,20 @@ def sky_calibration( # Get number of times and frequencies ntimes, nfreqs = data[keys[0]].shape - # get keys from model and data dictionary + # Check if the model is a RedDataContainer. 
If so, we can assume that the model is redundant if isinstance(model, RedDataContainer): - all_bls = sorted(set(data.keys())) - else: - all_bls = sorted(set(data.keys()) & set(model.keys())) + model_is_redundant = True + + # User must provide data_antpos if not in data for baseline matching + assert data_antpos or has_attr(data, "data_antpos"), "data_antpos must be provided if not in data" + if data_antpos is None: + data_antpos = data.data_antpos + + # get keys from model and data dictionary + data_bls, _, data_to_model_bl_map = match_baselines( + list(data.keys()), list(model.keys()), data_antpos, model_is_redundant=model_is_redundant, + min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, include_autos=include_autos + ) # Store gains and metadata gains = {} @@ -933,11 +961,11 @@ def sky_calibration(data, model, weights, ant_flags={}, tol=1e-10, maxiter=1000, conv_crit_array = [] # Get data baselines - baselines = [k for k in all_bls if k[2] == pol] + _data_bls = [k for k in data_bls if k[2] == pol] # Pack data and model into numpy arrays data_matrix, model_matrix, wgts, map_ants_to_index = _build_model_matrices( - data, model, weights, baselines, ant_flags=ant_flags + data, model, weights, _data_bls, data_to_model_bl_map, ant_flags=ant_flags ) for ti in range(ntimes): diff --git a/hera_cal/tests/test_abscal.py b/hera_cal/tests/test_abscal.py index 7ed347696..0574b9622 100644 --- a/hera_cal/tests/test_abscal.py +++ b/hera_cal/tests/test_abscal.py @@ -765,7 +765,7 @@ def test_sky_calibration(self): # Test that the function runs gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 ) assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs) @@ -780,7 +780,7 @@ def test_sky_calibration(self): # Test the function with antenna flags ant_flags = {(0, 'Jnn'): True} gains, niter, conv_crit = abscal.sky_calibration( - data_vis, 
model_vis, weights, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, antpos, ant_flags=ant_flags, maxiter=1000, tol=1e-10, stepsize=0.5 ) # Check that the flagged antenna has unity gain @@ -792,7 +792,7 @@ def test_sky_calibration(self): model_vis_copy[k] = model_vis[k] # Run calibration gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis_copy, weights, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis_copy, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 ) for k in data_vis: np.testing.assert_array_almost_equal(data_vis_copy[k], model_vis[k]) @@ -802,7 +802,7 @@ def test_sky_calibration(self): weights[k][0, 0] = 0.0 gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 ) assert np.isclose(gains[(0, 'Jnn')][0, 0], 1.0 + 0.0j) From ab899b3cde34dd38de39931b1c1990877c7fbfa3 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Fri, 3 Nov 2023 17:07:20 -0700 Subject: [PATCH 6/7] Add additional tests to cover missing lines --- hera_cal/abscal.py | 2 +- hera_cal/tests/test_abscal.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index 76852e944..bd80399cd 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -939,7 +939,7 @@ def sky_calibration( model_is_redundant = True # User must provide data_antpos if not in data for baseline matching - assert data_antpos or has_attr(data, "data_antpos"), "data_antpos must be provided if not in data" + assert data_antpos or hasattr(data, "data_antpos"), "data_antpos must be provided if not in data" if data_antpos is None: data_antpos = data.data_antpos diff --git a/hera_cal/tests/test_abscal.py b/hera_cal/tests/test_abscal.py index 0574b9622..63895e637 100644 --- a/hera_cal/tests/test_abscal.py +++ b/hera_cal/tests/test_abscal.py @@ -761,11 +761,12 @@ 
def test_sky_calibration(self): gains, model_vis, data_vis = vis.sim_red_data(reds, shape=(1, nfreqs)) model_vis = RedDataContainer(model_vis, reds=reds) data_vis = DataContainer(data_vis) + data_vis.data_antpos = antpos weights = DataContainer({k: np.ones(data_vis[k].shape, dtype=bool) for k in data_vis}) # Test that the function runs gains, niter, conv_crit = abscal.sky_calibration( - data_vis, model_vis, weights, antpos, maxiter=1000, tol=1e-10, stepsize=0.5 + data_vis, model_vis, weights, maxiter=1000, tol=1e-10, stepsize=0.5 ) assert niter['nn'].shape == conv_crit['nn'].shape == (1, nfreqs) From 46b738c56a4d063c0ce806619fd99aab84961991 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Mon, 6 Nov 2023 11:24:34 -0800 Subject: [PATCH 7/7] Include unused parameter in skycal function --- hera_cal/abscal.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index bd80399cd..6c49aef37 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -819,8 +819,6 @@ def _build_model_matrices(data, model, weights, data_bls, data_to_model_bls_map, data_to_model_bls_map: dict Dictionary mapping baseline tuples in the data to baseline tuples in the model. Keys are baseline tuples in the data, values are baseline tuples in the model. - baselines: list - List of baseline tuples of the form (ant1, ant2, pol) to include in the model matrices ant_flags: dict, optional, default={} Dictionary of antenna flags. Keys are antenna + pol tuples, values are boolean. If an antenna is flagged in this dictionary, it will not be included in the model matrices. 
@@ -946,7 +944,7 @@ def sky_calibration( # get keys from model and data dictionary data_bls, _, data_to_model_bl_map = match_baselines( list(data.keys()), list(model.keys()), data_antpos, model_is_redundant=model_is_redundant, - min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, include_autos=include_autos + min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, include_autos=include_autos, model_antpos=model_antpos ) # Store gains and metadata