Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix build_fringe_rate_profiles bug when number of pols = 1 #903

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 43 additions & 46 deletions hera_cal/frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
SPEED_OF_LIGHT = const.c.si.value
SDAY_SEC = units.sday.to("s")


def deinterleave_data_in_time(times, data: np.ndarray, ninterleave=1):
"""
Helper function for deinterleaving *time-ordered* data along time axis.
Expand All @@ -51,7 +52,7 @@ def deinterleave_data_in_time(times, data: np.ndarray, ninterleave=1):
tsets: list of np.ndarray
list of observation times sorted into the different interleaves
length equal to interleave

dsets: list of np.ndarray
list of data arrays sorted into ninterleave different interleaves.
"""
Expand All @@ -61,7 +62,7 @@ def deinterleave_data_in_time(times, data: np.ndarray, ninterleave=1):
for i in range(ninterleave):
tsets.append(times[i::ninterleave])
dsets.append(data[i::ninterleave])

return tsets, dsets


Expand Down Expand Up @@ -461,7 +462,7 @@ def build_fringe_rate_profiles(uvd, uvb, keys=None, normed=True, combine_pols=Tr
binned_power_conj = np.zeros_like(fr_grid)
# iterate over each frequency and ftaper weighting.
# use linspace to make sure we get first and last frequencies.
unflagged_chans = ~np.all(np.all(uvd.flag_array[:, :, :].squeeze(), axis=0), axis=-1)
unflagged_chans = ~np.all(np.all(uvd.flag_array, axis=0), axis=-1).squeeze()
chans_to_use = np.arange(uvd.Nfreqs).astype(int)[unflagged_chans][::fr_freq_skip]
frate_coeff = 2 * np.pi / SPEED_OF_LIGHT / (1e-3 * SDAY_SEC)
frate_over_freq = np.dot(np.cross(np.array([0, 0, 1.]), blvec), eq_xyz) * frate_coeff
Expand Down Expand Up @@ -921,15 +922,15 @@ class FRFilter(VisClean):
"""
FRFilter object. See hera_cal.vis_clean.VisClean.__init__ for instantiation options.
"""

def _deinterleave_data_in_time(self, container_name, ninterleave=1, keys=None, set_time_sets=True):
"""
Helper function to convert attached data and weights to time interleaved data and weights.

This method splits all attached data into multiple sets interleaved in time
and converts all keys from (ant1, ant2, pol) -> (ant1, ant2, pol, iset)
where iset indexes the interleaves.
where iset indexes the interleaves.

Parameters
----------
container_name: str
Expand All @@ -950,7 +951,7 @@ def _deinterleave_data_in_time(self, container_name, ninterleave=1, keys=None, s
will create two new DataContainers labeled 'data_interleave_0' and 'data_interleave_1'
It will also attach two lists called 'lst_sets' and 'time_sets' which have lists of the lsts
and times in each interleaved DataContainer.

"""
container = getattr(self, container_name)
if keys is None:
Expand All @@ -960,14 +961,14 @@ def _deinterleave_data_in_time(self, container_name, ninterleave=1, keys=None, s
new_container_name = container_name + f'_interleave_{inum}'
new_container = DataContainer({})
setattr(self, new_container_name, new_container)

for k in keys:
tsets, dsets = deinterleave_data_in_time(self.times, container[k], ninterleave=ninterleave)
for inum in range(ninterleave):
new_container_name = container_name + f'_interleave_{inum}'
new_container = getattr(self, new_container_name)
new_container[k] = dsets[inum]

if set_time_sets:
self.time_sets = tsets
self.lst_sets = [[] for i in range(ninterleave)]
Expand All @@ -977,12 +978,11 @@ def _deinterleave_data_in_time(self, container_name, ninterleave=1, keys=None, s
self.lst_sets[iset].append(lst)
iset = (iset + 1) % ninterleave
self.lst_sets = [np.asarray(lst_set) for lst_set in self.lst_sets]



def _interleave_data_in_time(self, deinterleaved_container_names, interleaved_container_name, keys=None):
"""
Helper function to restore deinterleaved data back to interleaved data.

Parameters
----------
deinterleaved_container_names: list of strings
Expand All @@ -1002,7 +1002,6 @@ def _interleave_data_in_time(self, deinterleaved_container_names, interleaved_co

for k in keys:
getattr(self, interleaved_container_name)[k] = interleave_data_in_time([getattr(self, cname)[k] for cname in deinterleaved_container_names])


def timeavg_data(self, data, times, lsts, t_avg, flags=None, nsamples=None,
wgt_by_nsample=True, wgt_by_favg_nsample=False, rephase=False,
Expand Down Expand Up @@ -1076,12 +1075,12 @@ def timeavg_data(self, data, times, lsts, t_avg, flags=None, nsamples=None,

# setup containers
for n in ['data', 'flags', 'nsamples']:

if output_postfix != '':
name = "{}_{}_{}".format(output_prefix, n, output_postfix)
else:
name = "{}_{}".format(output_prefix, n)

if not hasattr(self, name):
setattr(self, name, DataContainer({}))
if n == 'data':
Expand Down Expand Up @@ -1268,7 +1267,7 @@ def tophat_frfilter(self, frate_centers, frate_half_widths, keys=None, wgts=None
keys_before = list(filter_cache.keys())
else:
filter_cache = None

if wgts is None:
wgts = io.DataContainer({k: (~self.flags[k]).astype(float) for k in self.flags})
if pre_filter_modes_between_lobe_minimum_and_zero:
Expand All @@ -1283,17 +1282,17 @@ def tophat_frfilter(self, frate_centers, frate_half_widths, keys=None, wgts=None
filtered_name = 'clean_data'
model_name = 'clean_model'
resid_name = 'clean_resid'

if 'output_postfix' in filter_kwargs:
filtered_name = filtered_name + '_' + filter_kwargs['output_postfix']
model_name = model_name + '_' + filter_kwargs['output_postfix']
resid_name = resid_name + '_' + filter_kwargs['output_postfix']

if 'data' in filter_kwargs:
input_data = filter_kwargs.pop('data')
else:
input_data = self.data

if 'flags' in filter_kwargs:
input_flags = filter_kwargs.pop('flags')
else:
Expand Down Expand Up @@ -1356,7 +1355,7 @@ def tophat_frfilter(self, frate_centers, frate_half_widths, keys=None, wgts=None
if not mode == 'clean':
if write_cache:
filter_cache = io.write_filter_cache_scratch(filter_cache, cache_dir, skip_keys=keys_before)


def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=None,
wgt_by_nsample=True, wgt_by_favg_nsample=False, rephase=False,
Expand Down Expand Up @@ -1398,7 +1397,7 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
ninterleave: int, optional
number of subsets to break data into for interleaving.
this will produce ninterleave different output files
with names set equal to <output_name\ext>.interleave_<inum>.<ext>
with names set equal to <output_name/ext>.interleave_<inum>.<ext>
for example, if ninterleave = 2, outputname='averaged_data.uvh5'
then this method will produce two files named
'averaged_data.interleave_0.uvh5' and 'averaged_data.interleave_1.uvh5'
Expand All @@ -1415,14 +1414,14 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
default is False.
read_kwargs: kwargs dict
additional kwargs for for io.HERAData.read()

Returns
-------
None
"""
if ninterleave > 1 and filetype.lower() != 'uvh5':
raise ValueError(f"Interleaved data only supported for 'uvh5' filetype! User provided '{filetype}'.")

if baseline_list is not None and len(baseline_list) == 0:
warnings.warn("Length of baseline list is zero."
"This can happen under normal circumstances when there are more files in datafile_list then baselines."
Expand All @@ -1440,15 +1439,15 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
data = getattr(fr, f'data_interleave_{inum}')
flags = getattr(fr, f'flags_interleave_{inum}')
nsamples = getattr(fr, f'nsamples_interleave_{inum}')

fr.timeavg_data(data=data, flags=flags, nsamples=nsamples, times=fr.time_sets[inum],
lsts=fr.lst_sets[inum], t_avg=t_avg, wgt_by_nsample=wgt_by_nsample,
wgt_by_favg_nsample=wgt_by_favg_nsample, output_postfix=f'interleave_{inum}',
rephase=rephase)

timesets = [getattr(fr, f'avg_times_interleave_{inum}') for inum in range(ninterleave)]
ntimes = np.min([len(tset) for tset in timesets])

if equalize_interleave_times:
avg_times = np.mean([tset[:ntimes] for tset in timesets], axis=0)
avg_lsts = np.mean([getattr(fr, f'avg_lsts_interleave_{inum}')[:ntimes] for inum in range(ninterleave)], axis=0)
Expand All @@ -1457,14 +1456,13 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
avg_data = getattr(fr, f'avg_data_interleave_{inum}')
avg_nsamples = getattr(fr, f'avg_nsamples_interleave_{inum}')
avg_flags = getattr(fr, f'avg_flags_interleave_{inum}')

if equalize_interleave_times or equalize_interleave_ntimes:
for blk in avg_data:
avg_data[blk] = avg_data[blk][:ntimes]
avg_flags[blk] = avg_flags[blk][:ntimes]
avg_nsamples[blk] = avg_nsamples[blk][:ntimes]


if not equalize_interleave_times:
avg_times = getattr(fr, f'avg_times_interleave_{inum}')
avg_lsts = getattr(fr, f'avg_lsts_interleave_{inum}')
Expand Down Expand Up @@ -1551,11 +1549,11 @@ def tophat_frfilter_argparser(mode='clean'):
"If case == 'sky': then use fringe-rates corresponding to range of ",
"instantanous fringe-rates that include sky emission.")
filt_options.add_argument("--case", default="sky", help=' '.join(desc), type=str)
desc = ("Number interleaved time subsets to split the data into ",
desc = ("Number interleaved time subsets to split the data into ",
"and apply independent fringe-rate filters. Default is 1 (no interleaved filters).",
"This does not change the format of the output files but it does change the nature of their content.")
filt_options.add_argument("--ninterleave", default=1, type=int, help=desc)

return ap


Expand Down Expand Up @@ -1732,19 +1730,19 @@ def load_tophat_frfilter_and_write(datafile_list, case, baseline_list=None, calf
verbose=verbose, nfr=nfr)
# Lists of names of datacontainers that will hold each interleaved data set until they are
# recombined.
filtered_data_names = [ f'clean_data_interleave_{inum}' for inum in range(ninterleave) ]
filtered_flag_names = [ fstr.replace('data', 'flags') for fstr in filtered_data_names ]
filtered_resid_names = [ fstr.replace('data', 'resid') for fstr in filtered_data_names ]
filtered_model_names = [ fstr.replace('data', 'model') for fstr in filtered_data_names ]
filtered_resid_flag_names = [ fstr.replace('data', 'resid_flags') for fstr in filtered_data_names ]
filtered_data_names = [f'clean_data_interleave_{inum}' for inum in range(ninterleave)]
filtered_flag_names = [fstr.replace('data', 'flags') for fstr in filtered_data_names]
filtered_resid_names = [fstr.replace('data', 'resid') for fstr in filtered_data_names]
filtered_model_names = [fstr.replace('data', 'model') for fstr in filtered_data_names]
filtered_resid_flag_names = [fstr.replace('data', 'resid_flags') for fstr in filtered_data_names]

for inum in range(ninterleave):

# Build weights using flags, nsamples, and excluded lsts
flags = getattr(frfil, f'flags_interleave_{inum}')
nsamples = getattr(frfil, f'nsamples_interleave_{inum}')
wgts = io.DataContainer({k: (~flags[k]).astype(float) for k in flags})

lsts = frfil.lst_sets[inum]
for k in wgts:
if wgt_by_nsample:
Expand All @@ -1760,10 +1758,10 @@ def load_tophat_frfilter_and_write(datafile_list, case, baseline_list=None, calf
# run tophat filter
frfil.tophat_frfilter(frate_centers=frate_centers, frate_half_widths=frate_half_widths,
keys=keys, verbose=verbose, wgts=wgts, flags=getattr(frfil, f'flags_interleave_{inum}'),
data= getattr(frfil, f'data_interleave_{inum}'), output_postfix=f'interleave_{inum}',
data=getattr(frfil, f'data_interleave_{inum}'), output_postfix=f'interleave_{inum}',
times=frfil.time_sets[inum] * SDAY_SEC * 1e-3,
**filter_kwargs)

frfil._interleave_data_in_time(filtered_data_names, 'clean_data')
frfil._interleave_data_in_time(filtered_flag_names, 'clean_flags')
frfil._interleave_data_in_time(filtered_resid_names, 'clean_resid')
Expand All @@ -1786,11 +1784,11 @@ def load_tophat_frfilter_and_write(datafile_list, case, baseline_list=None, calf
frfil.clean_resid[bl] = frfil.data[bl]
frfil.clean_model[bl] = np.zeros_like(frfil.data[bl])
frfil.clean_resid_flags[bl] = frfil.flags[bl]

frfil.write_filtered_data(res_outfilename=res_outfilename, CLEAN_outfilename=CLEAN_outfilename,
filled_outfilename=filled_outfilename, partial_write=Nbls_per_load < len(baseline_list),
clobber=clobber, add_to_history=add_to_history,
extra_attrs={'Nfreqs': frfil.hd.Nfreqs, 'freq_array': frfil.hd.freq_array, 'channel_width': frfil.hd.channel_width, 'flex_spw_id_array': frfil.hd.flex_spw_id_array})
extra_attrs={'Nfreqs': frfil.hd.Nfreqs, 'freq_array': frfil.hd.freq_array, 'channel_width': frfil.hd.channel_width, 'flex_spw_id_array': frfil.hd.flex_spw_id_array})
frfil.hd.data_array = None # this forces a reload in the next loop


Expand Down Expand Up @@ -1820,9 +1818,9 @@ def time_average_argparser():
ap.add_argument("--verbose", default=False, action="store_true", help="verbose output.")
ap.add_argument("--flag_output", default=None, type=str, help="optional filename to save a separate copy of the time-averaged flags as a uvflag object.")
ap.add_argument("--filetype", default="uvh5", type=str, help="optional filetype specifier. Default is 'uvh5'. Set to 'miriad' if reading miriad files etc...")
desc = ("Number interleaved time subsets to split the data into ",
desc = ("Number interleaved time subsets to split the data into ",
"before averaging. Setting this greater than 1 will result in ninterleave different files ",
"with names equal to <output_data\ext>.interleave_<inum>.<ext>. ",
"with names equal to <output_data/ext>.interleave_<inum>.<ext>. ",
"For example, output_data = 'averaged_data.uvh5' and ninterleave=2' ",
"will result in two output files named 'averaged_data.interleave_0.uvh5 ",
"and 'averaged_data.interleave_1.uvh5'")
Expand All @@ -1831,6 +1829,5 @@ def time_average_argparser():
ap.add_argument("--equalize_interleave_times", action="store_true", default=False, help=desc)
desc = ("If set to True, truncate files with more excess interleaved times so all files have the same number of times.")
ap.add_argument("--equalize_interleave_ntimes", action="store_true", default=False, help=desc)



return ap
Loading