From f014721078af5b8118d326f0a3296d3dd41c08d2 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Mon, 5 Aug 2024 09:58:42 -0700 Subject: [PATCH] fix solve_2D_DPSS bug where basis is complex --- hera_cal/smooth_cal.py | 4 ++-- hera_cal/tests/test_smooth_cal.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/hera_cal/smooth_cal.py b/hera_cal/smooth_cal.py index 7885a60d4..8a423e822 100644 --- a/hera_cal/smooth_cal.py +++ b/hera_cal/smooth_cal.py @@ -183,12 +183,12 @@ def solve_2D_DPSS(gains, weights, time_filters, freq_filters, method="pinv", cac # einsum indices are (t -> time, f -> freq, i > time filter index, j -> freq filter index, m -> # time filter index, n -> freq filter index) XTX = jnp.einsum( - "ti,fj,tf,tm,fn->ijmn", time_filters, freq_filters, weights, time_filters, freq_filters, optimize=True + "ti,fj,tf,tm,fn->ijmn", time_filters.conj(), freq_filters.conj(), weights, time_filters, freq_filters, optimize=True ) XTX = np.reshape(XTX, (ncomps, ncomps)) # Calculate X^T W y using the property (A \otimes B) vec(y) = (A Y B) - XTWy = jnp.ravel(jnp.dot(jnp.dot(jnp.transpose(time_filters), (gains * weights)), freq_filters)) + XTWy = jnp.ravel(jnp.dot(jnp.dot(jnp.transpose(time_filters.conj()), (gains * weights)), freq_filters.conj())) # Compute beta and reshape into a 2D array beta, cached_output = _linear_fit(XTX, XTWy, solver=method, cached_input=cached_input) diff --git a/hera_cal/tests/test_smooth_cal.py b/hera_cal/tests/test_smooth_cal.py index 72c9d5d39..5f29bc479 100644 --- a/hera_cal/tests/test_smooth_cal.py +++ b/hera_cal/tests/test_smooth_cal.py @@ -115,6 +115,21 @@ def test_solve_2D_DPSS(self): fit_lsq = X @ np.linalg.pinv((X.T * weights.ravel()) @ X) @ (X.T * weights.ravel()) @ gains.ravel() np.testing.assert_array_almost_equal(fit_lsq, fit2.ravel()) + # Check that this works when the basis functions are complex + freqs = np.linspace(100e6, 150e6, 100) + x = np.linspace(0, 2 * np.pi, 50) + X = dspec.dpss_operator(freqs, [0], [20e-9], eigenval_cutoff=[1e-12])[0].real + Y = dspec.dft_operator(x, [0], [0.1]) + + ncomps = X.shape[-1] * Y.shape[-1] + values = np.random.normal(0, 1, ncomps) + 1j * np.random.normal(0, 1, ncomps) + beta = np.reshape(values, (X.shape[1], Y.shape[1])) + gains = np.dot(np.dot(X, beta), np.transpose(Y)) + weights = np.ones(gains.shape) + + fit1, cached_output = smooth_cal.solve_2D_DPSS(gains, weights, X, Y) + np.testing.assert_array_almost_equal(gains, fit1) + def test_time_filter(self): gains = np.ones((10, 10), dtype=complex) gains[3, 5] = 10.0