Merge pull request #901 from HERA-Team/fix-bigmem-bug
fix: future array shapes in chunk test and deprecation in numpy equals
steven-murray committed Jun 22, 2023
2 parents fa684dc + 829c2de commit c7fdf56
Showing 8 changed files with 12,341 additions and 19 deletions.

3 changed files with large diffs are not rendered by default.

1 change: 0 additions & 1 deletion .github/workflows/ci.yml
@@ -34,7 +34,6 @@ jobs:
       - name: Install
         run: |
           pip install .[dev]
-          pip install git+https://github.com/hera-team/hera_filters # temporary until hera-filters is released
       - name: Run Tests
         run: |
9 changes: 6 additions & 3 deletions hera_cal/chunker.py
@@ -60,7 +60,7 @@ def chunk_files(filenames, inputfile, outputfile, chunk_size, type="data",
         chunked_files = io.HERACal(filenames[start:end])
     else:
         raise ValueError("Invalid type provided. Must be in ['data', 'gains']")
-    read_args = {}
+
     if type == 'data':
         if polarizations is None:
             if len(chunked_files.filepaths) > 1:
@@ -69,8 +69,11 @@
                 polarizations = chunked_files.pols
         if spw_range is None:
             spw_range = (0, chunked_files.Nfreqs)
-        data, flags, nsamples = chunked_files.read(polarizations=polarizations,
-                                                   freq_chans=range(spw_range[0], spw_range[1]), **read_kwargs)
+        chunked_files.read(
+            polarizations=polarizations,
+            freq_chans=range(spw_range[0], spw_range[1]),
+            **read_kwargs
+        )
     elif type == 'gains':
         chunked_files.read()
         if polarizations is None:
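A note on the chunker change above: the dropped read_args = {} was dead code (the call passes **read_kwargs, a function argument), and the returned data, flags, nsamples were unused, since the HERAData/HERACal object also keeps the loaded arrays on itself and is what gets written out later. A minimal sketch of that pattern, assuming the hera_cal.io.HERAData API and using hypothetical file names (not taken from this diff):

from hera_cal import io

# Hypothetical input files, for illustration only.
hd = io.HERAData(["zen.2458044.1.uvh5", "zen.2458044.2.uvh5"])

# read() loads the selected data onto hd itself; capturing the returned
# DataContainers is only needed when the caller manipulates them directly.
hd.read(polarizations=["ee"], freq_chans=range(0, 32))

# Downstream steps (e.g. writing the chunk to disk) operate on hd, so the
# chunker can call read() purely for its side effect of loading the data.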
15 changes: 9 additions & 6 deletions hera_cal/smooth_cal.py
@@ -888,17 +888,20 @@ def check_consistency(self):
         '''
         all_time_indices = np.array([i for indices in self.time_indices.values() for i in indices])
         assert len(all_time_indices) == len(np.unique(all_time_indices)), \
-            'Multiple calibration integrations map to the same time index.'
+            'Multiple calibration integrations map to the same time index.'
         for cal in self.cals:
-            assert np.all(np.abs(self.cal_freqs[cal] - self.freqs) < 1e-4), \
-                '{} and {} have different frequencies.'.format(cal, self.cals[0])
+            assert np.all(
+                np.abs(self.cal_freqs[cal] - self.freqs) < 1e-4
+            ), f'{cal} and {self.cals[0]} have different frequencies.'
         if len(self.flag_files) > 0:
             all_flag_time_indices = np.array([i for indices in self.flag_time_indices.values() for i in indices])
-            assert np.all(np.unique(all_flag_time_indices) == np.unique(all_time_indices)), \
-                'The number of unique indices for the flag files does not match the calibration files.'
+            unq_flag = np.unique(all_flag_time_indices)
+            unq_time = np.unique(all_time_indices)
+            assert len(unq_flag) == len(unq_time) and np.all(unq_flag == unq_time), \
+                'The number of unique indices for the flag files does not match the calibration files.'
             for ff in self.flag_files:
                 assert np.all(np.abs(self.flag_freqs[ff] - self.freqs) < 1e-4), \
-                    '{} and {} have different frequencies.'.format(ff, self.cals[0])
+                    '{} and {} have different frequencies.'.format(ff, self.cals[0])
 
     def rephase_to_refant(self, warn=True, propagate_refant_flags=False):
         '''If the CalibrationSmoother object has a refant attribute, this function rephases the
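The smooth_cal change above is the "deprecation in numpy equals" from the commit message: when the two np.unique results have different lengths, the elementwise == cannot broadcast, and numpy has historically emitted a deprecation warning and evaluated to a plain False (recent releases treat it as an error) instead of returning a boolean array. A minimal standalone sketch of the length-first comparison used in the new assert, with made-up index arrays:

import numpy as np

# Made-up index arrays of different lengths.
all_flag_time_indices = np.array([0, 1, 2, 3])
all_time_indices = np.array([0, 1, 2])

unq_flag = np.unique(all_flag_time_indices)
unq_time = np.unique(all_time_indices)

# np.all(unq_flag == unq_time) alone would hit the shape-mismatched
# comparison path; checking the lengths first short-circuits safely.
consistent = len(unq_flag) == len(unq_time) and np.all(unq_flag == unq_time)
print(consistent)  # False: the flag files cover an extra time index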
20 changes: 12 additions & 8 deletions hera_cal/tests/test_chunker.py
@@ -11,30 +11,34 @@
 from hera_qm.utils import apply_yaml_flags
 import numpy as np
 import sys
+from pyuvdata.uvdata import FastUVH5Meta
 
 
 def test_chunk_data_files(tmpdir):
     # list of data files:
     tmp_path = tmpdir.strpath
-    data_files = sorted(glob.glob(DATA_PATH + '/zen.2458044.*.uvh5'))
+    data_files = sorted(glob.glob(f'{DATA_PATH}/zen.2458044.*.uvh5'))
     nfiles = len(data_files)
     # form chunks with three samples.
     for chunk in range(0, nfiles, 2):
         output = tmp_path + f'/chunk.{chunk}.uvh5'
         chunker.chunk_files(data_files, data_files[chunk], output, 2,
                             polarizations=['ee'], spw_range=[0, 32],
-                            throw_away_flagged_ants=True, ant_flag_yaml=DATA_PATH + '/test_input/a_priori_flags_sample_noflags.yaml')
+                            throw_away_flagged_ants=True,
+                            ant_flag_yaml=f'{DATA_PATH}/test_input/a_priori_flags_sample_noflags.yaml')
 
     # test that chunked files contain identical data (when combined)
     # to original combined list of files.
     # load in chunks
-    chunks = sorted(glob.glob(tmp_path + '/chunk.*.uvh5'))
+    chunks = sorted(glob.glob(f'{tmp_path}/chunk.*.uvh5'))
     uvd = UVData()
-    uvd.read(chunks)
+    uvd.read(chunks, use_future_array_shapes=True)
     # load in original file
     uvdo = UVData()
-    uvdo.read(data_files, freq_chans=range(32))
-    apply_yaml_flags(uvdo, DATA_PATH + '/test_input/a_priori_flags_sample_noflags.yaml', throw_away_flagged_ants=True,
+    uvdo.read(data_files, freq_chans=range(32), use_future_array_shapes=True)
+    # apply_yaml_flags always makes the uvdo object use future_array_shapes!
+    apply_yaml_flags(uvdo, f'{DATA_PATH}/test_input/a_priori_flags_sample_noflags.yaml',
+                     throw_away_flagged_ants=True,
                      flag_freqs=False, flag_times=False, ant_indices_only=True)
     assert np.all(np.isclose(uvdo.data_array, uvd.data_array))
     assert np.all(np.isclose(uvdo.flag_array, uvd.flag_array))
@@ -52,10 +56,10 @@ def test_chunk_data_files(tmpdir):
     # load in chunks
     chunks = sorted(glob.glob(tmp_path + '/chunk.*.uvh5'))
     uvd = UVData()
-    uvd.read(chunks)
+    uvd.read(chunks, use_future_array_shapes=True)
     # load in original file
     uvdo = UVData()
-    uvdo.read(data_files)
+    uvdo.read(data_files, use_future_array_shapes=True)
     apply_yaml_flags(uvdo, DATA_PATH + '/test_input/a_priori_flags_sample_noflags.yaml', throw_away_flagged_ants=True,
                      flag_freqs=False, flag_times=False, ant_indices_only=True)
     assert np.all(np.isclose(uvdo.data_array, uvd.data_array))
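For context on the use_future_array_shapes flag added throughout this test: with it, pyuvdata drops the length-1 spectral-window axis from the per-baseline-time arrays, and apply_yaml_flags converts uvdo to that convention anyway (per the comment in the diff), so both objects must be read the same way for the elementwise comparisons to hold. A minimal sketch, with a hypothetical file name:

from pyuvdata import UVData

uvd = UVData()
uvd.read("zen.2458044.40141.uvh5", use_future_array_shapes=True)  # hypothetical file

# With future array shapes, data_array / flag_array / nsample_array have
# shape (Nblts, Nfreqs, Npols) instead of (Nblts, 1, Nfreqs, Npols), and
# freq_array is one-dimensional. Comparing a converted object against an
# unconverted one fails on shape, hence reading uvd and uvdo identically.
print(uvd.data_array.shape)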
2 changes: 1 addition & 1 deletion setup.py
@@ -57,7 +57,7 @@ def package_files(package_dir, subdirectory):
         'linsolve',
         'hera_qm',
         'scikit-learn',
-        'hera_filters',
+        'hera-filters',
         "line_profiler",
         'aipy',
         "rich",
