Skip to content

Commit

Permalink
Merge pull request #5 from christinahedges/mask
Browse files Browse the repository at this point in the history
Fixed contaminante to force a contiguous mask
  • Loading branch information
christinahedges committed Apr 8, 2021
2 parents 475b02f + e2641c6 commit 3c3e5bb
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 848 deletions.
88 changes: 78 additions & 10 deletions contaminante/contaminante.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
from tqdm import tqdm
import warnings
from scipy.ndimage import label

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
Expand Down Expand Up @@ -55,6 +56,10 @@ def calculate_contamination(
Transit midpoint of transiting object in days
duration : float
Duration of transit in days
sigma : float
The significance level at which to create an aperture for the contaminanting source.
If the apertures are large, try increasing sigma. If the apertures are small,
or contaminante fails, you could try (slightly) lowering sigma.
plot: bool
If True, will generate a figure
cbvs : bool
Expand Down Expand Up @@ -167,8 +172,15 @@ def calculate_contamination(
model = np.zeros(tpf.flux.shape)
model_err = np.zeros(tpf.flux.shape)

# Hard coded saturation limit. Probably not ideal.
saturated = np.max(np.nan_to_num(tpf.flux.value), axis=0) > 1.4e5
saturated |= np.abs(np.gradient(saturated.astype(float), axis=0)) != 0
if saturated.any():
dsat = np.gradient(saturated.astype(float), axis=0)
if (~np.any(dsat == 0.5)) | (~np.any(dsat == -0.5)):
raise ValueError(
"Too close to a saturation column that isn't fully captured."
)
saturated |= np.abs(dsat) != 0
pixels = tpf.flux.value.copy()
pixels_err = tpf.flux_err.value.copy()

Expand Down Expand Up @@ -208,7 +220,9 @@ def calculate_contamination(

with warnings.catch_warnings():
warnings.simplefilter("ignore")
contaminant_aper = (transit_pixels / transit_pixels_err) > sigma
contaminant_aper = create_threshold_mask(
transit_pixels / transit_pixels_err, sigma
)
contaminated_lc = tpf.to_lightcurve(aperture_mask=contaminant_aper).normalize()
r.lc = contaminated_lc
contaminator = r.correct(dm1, cadence_mask=~t_mask)
Expand Down Expand Up @@ -262,7 +276,7 @@ def get_coords(thumb, err, aper=None):
aper = np.ones(tpf.flux.shape[1:], bool)
with np.errstate(divide="ignore"):
Y, X = np.mgrid[: tpf.shape[1], : tpf.shape[2]]
aper = (thumb / err > 3) & aper
aper = create_threshold_mask(thumb / err, 3) & aper
cxs, cys = [], []
for count in range(500):
w = np.random.normal(loc=thumb[aper], scale=err[aper])
Expand Down Expand Up @@ -445,6 +459,13 @@ def _make_plot(tpf, res):
ax = plt.subplot2grid((1, 4), (0, 0))
ax.set_title("Target ID: {}".format(tpf.targetid))

if tpf.mission.lower() == "tess":
pix = 27 * u.arcsec.to(u.deg)
elif tpf.mission.lower() in ["kepler", "ktwo", "k2"]:
pix = 4 * u.arcsec.to(u.deg)
else:
pix = 0

xlim = [1e10, -1e10]
ylim = [1e10, -1e10]
ra, dec = np.asarray(np.median(tpf.get_coordinates(), axis=1))
Expand All @@ -456,18 +477,18 @@ def _make_plot(tpf, res):
cmap="Greys_r",
shading="auto",
)
xlim[0] = np.min([np.percentile(ra, 1), xlim[0]])
xlim[1] = np.max([np.percentile(ra, 99), xlim[1]])
ylim[0] = np.min([np.percentile(dec, 1), ylim[0]])
ylim[1] = np.max([np.percentile(dec, 99), ylim[1]])
xlim[0] = np.min([np.percentile(ra, 1) - pix, xlim[0]])
xlim[1] = np.max([np.percentile(ra, 99) + pix, xlim[1]])
ylim[0] = np.min([np.percentile(dec, 1) - pix, ylim[0]])
ylim[1] = np.max([np.percentile(dec, 99) + pix, ylim[1]])
# import pdb;pdb.set_trace()
ax.scatter(
np.hstack(res["target_ra"]),
np.hstack(res["target_dec"]),
c="C0",
marker=".",
s=10,
label="Target",
label="Center of Pipeline Aperture",
zorder=11,
)
if "contaminator_ra" in res.keys():
Expand All @@ -477,10 +498,10 @@ def _make_plot(tpf, res):
c="r",
marker=".",
s=13,
label="Source Of Transit",
label="Center of Transit Pixels",
zorder=10,
)

ax.legend(frameon=True)
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
if tpf.mission.lower() == "tess":
Expand Down Expand Up @@ -524,3 +545,50 @@ def _make_plot(tpf, res):
markersize=2,
)
return fig


def create_threshold_mask(thumb, threshold=3, reference_pixel="max"):
"""Lifted from lightkurve.
Creates a contiguous region where a "thumnbnail" is greater than some threshold
----------
thumb : np.ndarray
2D image, in this case the transit depth in every pixel divided by the
error.
threshold : float
A value for the number of sigma by which a pixel needs to be
brighter than the median flux to be included in the aperture mask.
reference_pixel: (int, int) tuple, 'center', 'max', or None
(col, row) pixel coordinate closest to the desired region.
In this case we use the maximum of the thumbnail.
Returns
-------
aperture_mask : ndarray
2D boolean numpy array containing `True` for pixels above the
threshold.
"""
if reference_pixel == "center":
reference_pixel = (thumb.shape[2] / 2, thumb.shape[1] / 2)
if reference_pixel == "max":
reference_pixel = np.where(thumb == np.nanmax(thumb))
reference_pixel = (reference_pixel[1][0], reference_pixel[0][0])
vals = thumb[np.isfinite(thumb)].flatten()
# Create a mask containing the pixels above the threshold flux
threshold_mask = np.nan_to_num(thumb) >= threshold
if (reference_pixel is None) or (not threshold_mask.any()):
# return all regions above threshold
return threshold_mask
else:
# Return only the contiguous region closest to `region`.
# First, label all the regions:
labels = label(threshold_mask)[0]
# For all pixels above threshold, compute distance to reference pixel:
label_args = np.argwhere(labels > 0)
distances = [
np.hypot(crd[0], crd[1])
for crd in label_args - np.array([reference_pixel[1], reference_pixel[0]])
]
# Which label corresponds to the closest pixel?
closest_arg = label_args[np.argmin(distances)]
closest_label = labels[closest_arg[0], closest_arg[1]]
return labels == closest_label
2 changes: 1 addition & 1 deletion contaminante/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# It is important to store the version number in a separate file
# so that we can read it from setup.py without importing the package
__version__ = "0.5.1"
__version__ = "0.5.2"
545 changes: 138 additions & 407 deletions docs/.ipynb_checkpoints/tutorial-checkpoint.ipynb

Large diffs are not rendered by default.

534 changes: 104 additions & 430 deletions docs/tutorial.ipynb

Large diffs are not rendered by default.

0 comments on commit 3c3e5bb

Please sign in to comment.