Skip to content

Commit

Permalink
Get facetgrid working again
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 14, 2019
1 parent a12378c commit 1d939af
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 40 deletions.
52 changes: 13 additions & 39 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import numpy as np

from ..core.alignment import broadcast
from .facetgrid import FacetGrid
from .facetgrid import _easy_facetgrid
from .utils import (
_add_colorbar, _determine_cmap_params, _ensure_numeric,
_valid_other_type, get_axis, label_from_attrs)
_add_colorbar, _determine_cmap_params,
_ensure_numeric, _valid_other_type, get_axis, label_from_attrs)


def _infer_meta_data(ds, x, y, hue, hue_style, add_colorbar,
Expand Down Expand Up @@ -79,32 +79,6 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_colorbar,
'hue_values': hue_values}


def _easy_facetgrid(ds, plotfunc, x, y, row=None, col=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, **kwargs):
"""
Convenience method to call xarray.plot.FacetGrid from 2d plotting methods
kwargs are the arguments to 2d plotting method
"""
ax = kwargs.pop('ax', None)
figsize = kwargs.pop('figsize', None)
if ax is not None:
raise ValueError("Can't use axes when making faceted plots.")
if aspect is None:
aspect = 1
if size is None:
size = 3
elif figsize is not None:
raise ValueError('cannot provide both `figsize` and `size` arguments')

g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)

return g.map_dataset(plotfunc, x, y, **kwargs)


def _infer_scatter_data(ds, x, y, hue):

data = {'x': ds[x].values.flatten(),
Expand Down Expand Up @@ -231,18 +205,17 @@ def newplotfunc(ds, x=None, y=None, hue=None, hue_style=None,
if col or row:
allargs = locals().copy()
allargs['plotfunc'] = globals()[plotfunc.__name__]

allargs['data'] = ds
# TODO dcherian: why do I need to remove kwargs?
for arg in ['meta_data', 'kwargs']:
for arg in ['meta_data', 'kwargs', 'ds']:
del allargs[arg]

return _easy_facetgrid(**allargs)
return _easy_facetgrid(kind='dataset', **allargs)

figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)

kwargs = kwargs.copy()
# TODO dcherian: _meta_data should not be needed
# I'm trying to avoid calling _determine_cmap_params multiple times
_meta_data = kwargs.pop('_meta_data', None)

if hue_style == 'continuous' and hue is not None:
Expand Down Expand Up @@ -271,8 +244,10 @@ def newplotfunc(ds, x=None, y=None, hue=None, hue_style=None,
else:
cmap_params_subset = {}

primitive = plotfunc(ax, ds, x, y, hue, hue_style,
cmap_params=cmap_params_subset, **kwargs)
# TODO dcherian: hue, hue_style shouldn't be needed for all methods
# update signatures
primitive = plotfunc(ds=ds, x=x, y=y, hue=hue, hue_style=hue_style,
ax=ax, cmap_params=cmap_params_subset, **kwargs)

if _meta_data: # if this was called from Facetgrid.map_dataset,
return primitive # finish here. Else, make labels
Expand Down Expand Up @@ -325,9 +300,8 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, hue=None,


@_dsplot
def scatter(ax, ds, x, y, hue, hue_style, **kwargs):
def scatter(ds, x, y, hue, hue_style, ax, **kwargs):
""" Scatter Dataset data variables against each other. """

cmap_params = kwargs.pop('cmap_params')

if hue_style == 'discrete':
Expand Down
4 changes: 3 additions & 1 deletion xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def map_dataset(self, func, x=None, y=None, hue=None, hue_style=None,
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
maybe_mappable = func(subset, x=x, y=y, hue=hue,
maybe_mappable = func(ds=subset, x=x, y=y,
hue=hue, hue_style=hue_style,
ax=ax, **kwargs)
# TODO: this is needed to get legends to work.
# but maybe_mappable is a list in that case :/
Expand Down Expand Up @@ -597,4 +598,5 @@ def _easy_facetgrid(data, plotfunc, x=None, y=None, kind=None, row=None, col=Non
return g.map_dataarray(plotfunc, x, y, **kwargs)
elif kind == 'array line':
return g.map_dataarray_line(hue=kwargs.pop('hue'), **kwargs)
elif kind == 'dataset':
return g.map_dataset(plotfunc, x, y, **kwargs)

0 comments on commit 1d939af

Please sign in to comment.