From 458ed5ca3767ab482d6b3b7443f9f3bbf4aa94d8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 5 Feb 2021 12:19:26 -0700 Subject: [PATCH 1/6] Refactor _infer_line_data --- xarray/plot/plot.py | 59 +++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 8a57e17e5e8..8597dab9afa 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,6 +29,29 @@ ) +def _choose_x_y(darray, name, huename): + """Create x variable and y variable for line plots, appropriately transposed + based on huename.""" + xplt = darray[name] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(..., otherdim, huename, transpose_coords=False) + xplt = xplt.transpose(..., otherdim, huename, transpose_coords=False) + else: + raise ValueError( + f"For 2D inputs, hue must be a dimension i.e. one of {darray.dims!r}" + ) + + else: + (xdim,) = darray[name].dims + (huedim,) = darray[huename].dims + yplt = darray.transpose(..., xdim, huedim) + + return xplt, yplt + + def _infer_line_data(darray, x, y, hue): ndims = len(darray.dims) @@ -66,43 +89,11 @@ def _infer_line_data(darray, x, y, hue): if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename, transpose_coords=False) - xplt = xplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (xdim,) = darray[xname].dims - (huedim,) = darray[huename].dims - yplt = darray.transpose(xdim, huedim) + xplt, yplt = _choose_x_y(darray, xname, huename) else: yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - xplt = darray.transpose(otherdim, huename, transpose_coords=False) - yplt = yplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (ydim,) = darray[yname].dims - (huedim,) = darray[huename].dims - xplt = darray.transpose(ydim, huedim) + yplt, xplt = _choose_x_y(darray, yname, huename) huelabel = label_from_attrs(darray[huename]) hueplt = darray[huename] From 2b0867d01aea07f5b1a7371603873e5f4eae45f2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 5 Feb 2021 12:26:09 -0700 Subject: [PATCH 2/6] refactor out override_sigature --- xarray/plot/plot.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 8597dab9afa..ff01b457209 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -101,6 +101,25 @@ def _infer_line_data(darray, x, y, hue): return xplt, yplt, hueplt, huelabel +def override_signature(f): + def wrapper(func): + func.__wrapped__ = f + + return func + + return wrapper + + +# plotfunc and newplotfunc have different signatures: +# - plotfunc: (x, y, z, ax, **kwargs) +# - newplotfunc: (darray, x, y, **kwargs) +# where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray +# and variable names. newplotfunc also explicitly lists most kwargs, so we +# need to shorten it +def signature(darray, x, y, **kwargs): + pass + + def plot( darray, row=None, @@ -453,15 +472,6 @@ def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) -def override_signature(f): - def wrapper(func): - func.__wrapped__ = f - - return func - - return wrapper - - def _plot2d(plotfunc): """ Decorator for common 2d plotting logic @@ -571,15 +581,6 @@ def _plot2d(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" - # plotfunc and newplotfunc have different signatures: - # - plotfunc: (x, y, z, ax, **kwargs) - # - newplotfunc: (darray, x, y, **kwargs) - # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray - # and variable names. newplotfunc also explicitly lists most kwargs, so we - # need to shorten it - def signature(darray, x, y, **kwargs): - pass - @override_signature(signature) @functools.wraps(plotfunc) def newplotfunc( From f0711024369cbd3547bf9ed229b7fa372f078362 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 5 Feb 2021 12:36:24 -0700 Subject: [PATCH 3/6] Use _plot1d decorator for line --- xarray/plot/facetgrid.py | 4 +- xarray/plot/plot.py | 355 ++++++++++++++++++++++++--------------- 2 files changed, 222 insertions(+), 137 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 58b38251352..9d527e0f346 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -286,7 +286,7 @@ def map_dataarray(self, func, x, y, **kwargs): return self def map_dataarray_line( - self, func, x, y, hue, add_legend=True, _labels=None, **kwargs + self, func, x, y, hue, add_legend=True, add_labels=None, **kwargs ): from .plot import _infer_line_data @@ -301,7 +301,7 @@ def map_dataarray_line( ax=ax, hue=hue, add_legend=False, - _labels=False, + add_labels=False, **kwargs, ) self._mappables.append(mappable) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ff01b457209..c4e72e3ed01 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -207,137 +207,6 @@ def plot( return plotfunc(darray, **kwargs) -# This function signature should not change so that it can use -# matplotlib format strings -def line( - darray, - *args, - row=None, - col=None, - figsize=None, - aspect=None, - size=None, - ax=None, - hue=None, - x=None, - y=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - add_legend=True, - _labels=True, - **kwargs, -): - """ - Line plot of DataArray index against values - - Wraps :func:`matplotlib:matplotlib.pyplot.plot` - - Parameters - ---------- - darray : DataArray - Must be 1 dimensional - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axis on which to plot this figure. By default, use the current axis. - Mutually exclusive with ``size`` and ``figsize``. - hue : string, optional - Dimension or coordinate for which you want multiple lines plotted. - If plotting against a 2D coordinate, ``hue`` must be a dimension. - x, y : string, optional - Dimension, coordinate or MultiIndex level for x, y axis. - Only one of these may be specified. - The other coordinate plots values from the DataArray on which this - plot method is called. - xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional - Specifies scaling for the x- and y-axes respectively - xticks, yticks : Specify tick locations for x- and y-axes - xlim, ylim : Specify x- and y-axes limits - xincrease : None, True, or False, optional - Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function. - yincrease : None, True, or False, optional - Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function. - add_legend : bool, optional - Add legend with y axis coordinates (2D inputs only). - *args, **kwargs : optional - Additional arguments to matplotlib.pyplot.plot - """ - # Handle facetgrids first - if row or col: - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - return _easy_facetgrid(darray, line, kind="line", **allargs) - - ndims = len(darray.dims) - if ndims > 2: - raise ValueError( - "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) - ) - - # The allargs dict passed to _easy_facetgrid above contains args - if args == (): - args = kwargs.pop("args", ()) - else: - assert "args" not in kwargs - - ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) - - # Remove pd.Intervals if contained in xplt.values and/or yplt.values. - xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, kwargs - ) - xlabel = label_from_attrs(xplt, extra=x_suffix) - ylabel = label_from_attrs(yplt, extra=y_suffix) - - _ensure_plottable(xplt_val, yplt_val) - - primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) - - if _labels: - if xlabel is not None: - ax.set_xlabel(xlabel) - - if ylabel is not None: - ax.set_ylabel(ylabel) - - ax.set_title(darray._title_for_slice()) - - if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label) - - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): """ Step plot of DataArray index against values @@ -463,15 +332,231 @@ def __call__(self, **kwargs): def hist(self, ax=None, **kwargs): return hist(self._da, ax=ax, **kwargs) - @functools.wraps(line) - def line(self, *args, **kwargs): - return line(self._da, *args, **kwargs) - @functools.wraps(step) def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) +def _plot1d(plotfunc): + """ + Decorator for common 2d plotting logic + + Also adds the 2d plot method to class _PlotMethods + """ + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be 2 dimensional, unless creating faceted plots + x : string, optional + Coordinate for x axis. If None use darray.dims[1] + y : string, optional + Coordinate for y axis. If None use darray.dims[0] + hue : string, optional + Dimension or coordinate for which you want multiple lines plotted. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + ax : matplotlib.axes.Axes, optional + Axis on which to plot this figure. By default, use the current axis. + Mutually exclusive with ``size`` and ``figsize``. + row : string, optional + If passed, make row faceted plots on this dimension name + col : string, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_labels : bool, optional + Use xarray metadata to label axes + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only used + for 2D and FacetGrid plots. + **kwargs : optional + Additional arguments to wrapped matplotlib function + + Returns + ------- + artist : + The same type of primitive artist that the wrapped matplotlib + function returns + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @override_signature(signature) + @functools.wraps(plotfunc) + def newplotfunc( + darray, + *args, + x=None, + y=None, + hue=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend=True, + add_labels=True, + subplot_kws=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + **kwargs, + ): + # All 2d plots in xarray share this function signature. + # Method signature below should be consistent. + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + allargs.pop("plotfunc") + if plotfunc.__name__ == "line": + return _easy_facetgrid(darray, line, kind="line", **allargs) + else: + raise ValueError(f"Faceting not implemented for {plotfunc.__name__}") + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + else: + assert "args" not in kwargs + + ax = get_axis(figsize, size, aspect, ax) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) + + primitive = plotfunc(xplt, yplt, ax, *args, add_labels=add_labels, **kwargs) + + if add_labels: + ax.set_title(darray._title_for_slice()) + + if hueplt is not None and add_legend: + if plotfunc.__name__ == "hist": + handles = primitive[-1] + else: + handles = primitive + ax.legend( + handles=handles, + labels=list(hueplt.values), + title=label_from_attrs(hueplt), + ) + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + return primitive + + # For use as DataArray.plot.plotmethod + @functools.wraps(newplotfunc) + def plotmethod( + _PlotMethods_obj, + *args, + x=None, + y=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend=True, + add_labels=True, + subplot_kws=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + **kwargs, + ): + """ + The method should have the same signature as the function. + + This just makes the method work on Plotmethods objects, + and passes all the other arguments straight through. + """ + allargs = locals() + allargs["darray"] = _PlotMethods_obj._da + allargs.update(kwargs) + for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: + del allargs[arg] + return newplotfunc(**allargs) + + # Add to class _PlotMethods + setattr(_PlotMethods, plotmethod.__name__, plotmethod) + + return newplotfunc + + +# This function signature should not change so that it can use +# matplotlib format strings +@_plot1d +def line(xplt, yplt, ax, *args, add_labels=True, **kwargs): + """ + Line plot of DataArray index against values + + Wraps :func:`matplotlib:matplotlib.pyplot.plot` + """ + + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, kwargs + ) + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + + if add_labels: + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha("right") + + return primitive + + def _plot2d(plotfunc): """ Decorator for common 2d plotting logic From abe74712b84ed601e88654493b40c8e39e133361 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 6 Feb 2021 09:31:27 -0700 Subject: [PATCH 4/6] Fix signatures --- xarray/plot/plot.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index c4e72e3ed01..5befd98c0e3 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -110,16 +110,6 @@ def wrapper(func): return wrapper -# plotfunc and newplotfunc have different signatures: -# - plotfunc: (x, y, z, ax, **kwargs) -# - newplotfunc: (darray, x, y, **kwargs) -# where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray -# and variable names. newplotfunc also explicitly lists most kwargs, so we -# need to shorten it -def signature(darray, x, y, **kwargs): - pass - - def plot( darray, row=None, @@ -400,6 +390,15 @@ def _plot1d(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + # plotfunc and newplotfunc have different signatures: + # - plotfunc: (x, y, z, ax, **kwargs) + # - newplotfunc: (darray, *args, x, y, **kwargs) + # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray + # and variable names. newplotfunc also explicitly lists most kwargs, so we + # need to shorten it + def signature(darray, *args, x, y, **kwargs): + pass + @override_signature(signature) @functools.wraps(plotfunc) def newplotfunc( @@ -666,6 +665,15 @@ def _plot2d(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + # plotfunc and newplotfunc have different signatures: + # - plotfunc: (x, y, z, ax, **kwargs) + # - newplotfunc: (darray, x, y, **kwargs) + # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray + # and variable names. newplotfunc also explicitly lists most kwargs, so we + # need to shorten it + def signature(darray, x, y, **kwargs): + pass + @override_signature(signature) @functools.wraps(plotfunc) def newplotfunc( From 9fccf873236bb6771f058ddc762af019a4e8d465 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 9 Feb 2021 12:26:29 -0700 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/plot/plot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 5befd98c0e3..0a253cc2035 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -329,7 +329,7 @@ def step(self, *args, **kwargs): def _plot1d(plotfunc): """ - Decorator for common 2d plotting logic + Decorator for common 1d plotting logic. Also adds the 2d plot method to class _PlotMethods """ @@ -376,7 +376,7 @@ def _plot1d(plotfunc): Use xarray metadata to label axes subplot_kws : dict, optional Dictionary of keyword arguments for matplotlib subplots. Only used - for 2D and FacetGrid plots. + for FacetGrid plots. **kwargs : optional Additional arguments to wrapped matplotlib function @@ -427,7 +427,7 @@ def newplotfunc( ylim=None, **kwargs, ): - # All 2d plots in xarray share this function signature. + # All 1d plots in xarray share this function signature. # Method signature below should be consistent. # Handle facetgrids first From b3bbe8cbc0b2421a170ebeb2239ad142b9c11751 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 9 Feb 2021 12:26:46 -0700 Subject: [PATCH 6/6] Update xarray/plot/plot.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0a253cc2035..2a484464772 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -331,7 +331,7 @@ def _plot1d(plotfunc): """ Decorator for common 1d plotting logic. - Also adds the 2d plot method to class _PlotMethods + Also adds the 1d plot method to class _PlotMethods. """ commondoc = """ Parameters